add wust typr mpc and mutipule x

This commit is contained in:
cyy_mac
2026-03-27 03:41:42 +08:00
parent 2c64655fae
commit 7dcb53bb77
192 changed files with 29571 additions and 9 deletions

View File

@@ -0,0 +1,43 @@
add_library(tinympcstatic STATIC
admm.cpp
tiny_api.cpp
codegen.cpp
rho_benchmark.cpp
)
set_property(TARGET tinympcstatic PROPERTY POSITION_INDEPENDENT_CODE ON)
# target_link_libraries(tinympcstatic PUBLIC Eigen)
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include/Eigen)
if(USING_CODEGEN) # Defined in top-level CMakeLists.txt
# Files that are needed for embedded code generation
list( APPEND EMBEDDED_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/admm.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/admm.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/types.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api_constants.hpp" )
foreach( f ${EMBEDDED_FILES} )
get_filename_component( fname ${f} NAME )
set( dest_file "${EMBEDDED_BUILD_TINYMPC_DIR}/${fname}" )
list( APPEND EMBEDDED_BUILD_TINYMPC_FILES "${dest_file}" )
add_custom_command(OUTPUT ${dest_file}
COMMAND ${CMAKE_COMMAND} -E copy "${f}" "${dest_file}"
DEPENDS ${f}
COMMENT "Copying ${fname}")
endforeach()
add_custom_target( copy_codegen_tinympc_files DEPENDS ${EMBEDDED_BUILD_TINYMPC_FILES} )
add_dependencies( copy_codegen_files copy_codegen_tinympc_files )
endif(USING_CODEGEN)

View File

@@ -0,0 +1,399 @@
#include <iostream>
#include "admm.hpp"
#include "rho_benchmark.hpp"
#define DEBUG_MODULE "TINYALG"
extern "C" {
/**
* Update linear terms from Riccati backward pass
*/
void backward_pass_grad(TinySolver* solver) {
for (int i = solver->work->N - 2; i >= 0; i--) {
(solver->work->d.col(i)).noalias() = solver->cache->Quu_inv
* (solver->work->Bdyn.transpose() * solver->work->p.col(i + 1) + solver->work->r.col(i)
+ solver->cache->BPf);
(solver->work->p.col(i)).noalias() = solver->work->q.col(i)
+ solver->cache->AmBKt.lazyProduct(solver->work->p.col(i + 1))
- (solver->cache->Kinf.transpose()).lazyProduct(solver->work->r.col(i))
+ solver->cache->APf;
}
}
/**
* Use LQR feedback policy to roll out trajectory
*/
void forward_pass(TinySolver* solver) {
for (int i = 0; i < solver->work->N - 1; i++) {
(solver->work->u.col(i)).noalias() =
-solver->cache->Kinf.lazyProduct(solver->work->x.col(i)) - solver->work->d.col(i);
(solver->work->x.col(i + 1)).noalias() =
solver->work->Adyn.lazyProduct(solver->work->x.col(i))
+ solver->work->Bdyn.lazyProduct(solver->work->u.col(i)) + solver->work->fdyn;
}
}
/**
* Project a vector s onto the second order cone defined by mu
* @param s, mu
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
*/
tinyVector project_soc(tinyVector s, float mu) {
tinytype u0 = s(Eigen::placeholders::last) * mu;
tinyVector u1 = s.head(s.rows() - 1);
float a = u1.norm();
tinyVector cone_origin(s.rows());
cone_origin.setZero();
if (a <= -u0) { // below cone
return cone_origin;
} else if (a <= u0) { // in cone
return s;
} else if (a >= abs(u0)) { // outside cone
Matrix<tinytype, 3, 1> u2(u1.size() + 1);
u2 << u1, a / mu;
return 0.5 * (1 + u0 / a) * u2;
} else {
return cone_origin;
}
}
/**
* Project a vector z onto a hyperplane defined by a^T z = b
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ b)/||a||² * a
* @param z Vector to project
* @param a Normal vector of the hyperplane
* @param b Offset of the hyperplane
* @return Projection of z onto the hyperplane
*/
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b) {
tinytype dist = (a.dot(z) - b) / a.squaredNorm();
return z - dist * a;
}
/**
* Project slack (auxiliary) variables into their feasible domain, defined by
* projection functions related to each constraint
* TODO: pass in meta information with each constraint assigning it to a
* projection function
*/
void update_slack(TinySolver* solver) {
// Update bound constraint slack variables for state
solver->work->vnew = solver->work->x + solver->work->g;
// Update bound constraint slack variables for input
solver->work->znew = solver->work->u + solver->work->y;
// Box constraints on state
if (solver->settings->en_state_bound) {
solver->work->vnew =
solver->work->x_max.cwiseMin(solver->work->x_min.cwiseMax(solver->work->vnew));
}
// Box constraints on input
if (solver->settings->en_input_bound) {
solver->work->znew =
solver->work->u_max.cwiseMin(solver->work->u_min.cwiseMax(solver->work->znew));
}
// Update second order cone slack variables for state
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->vcnew = solver->work->x + solver->work->gc;
}
// Update second order cone slack variables for input
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->zcnew = solver->work->u + solver->work->yc;
}
// Cone constraints on state
if (solver->settings->en_state_soc) {
for (int i = 0; i < solver->work->N; i++) {
for (int k = 0; k < solver->work->numStateCones; k++) {
int start = solver->work->Acx(k);
int num_xs = solver->work->qcx(k);
tinytype mu = solver->work->cx(k);
tinyVector col = solver->work->vcnew.block(start, i, num_xs, 1);
solver->work->vcnew.block(start, i, num_xs, 1) = project_soc(col, mu);
}
}
}
// Cone constraints on input
if (solver->settings->en_input_soc) {
for (int i = 0; i < solver->work->N - 1; i++) {
for (int k = 0; k < solver->work->numInputCones; k++) {
int start = solver->work->Acu(k);
int num_us = solver->work->qcu(k);
tinytype mu = solver->work->cu(k);
tinyVector col = solver->work->zcnew.block(start, i, num_us, 1);
solver->work->zcnew.block(start, i, num_us, 1) = project_soc(col, mu);
}
}
}
// Update linear constraint slack variables for state
if (solver->settings->en_state_linear) {
solver->work->vlnew = solver->work->x + solver->work->gl;
}
// Update linear constraint slack variables for input
if (solver->settings->en_input_linear) {
solver->work->zlnew = solver->work->u + solver->work->yl;
}
// Linear constraints on state
if (solver->settings->en_state_linear) {
for (int i = 0; i < solver->work->N; i++) {
for (int k = 0; k < solver->work->numStateLinear; k++) {
tinyVector a = solver->work->Alin_x.row(k);
tinytype b = solver->work->blin_x(k);
tinytype constraint_value = a.dot(solver->work->vlnew.col(i));
if (constraint_value > b) { // Only project if constraint is violated
solver->work->vlnew.col(i) =
project_hyperplane(solver->work->vlnew.col(i), a, b);
}
}
}
}
// Linear constraints on input
if (solver->settings->en_input_linear) {
for (int i = 0; i < solver->work->N - 1; i++) {
for (int k = 0; k < solver->work->numInputLinear; k++) {
tinyVector a = solver->work->Alin_u.row(k);
tinytype b = solver->work->blin_u(k);
tinytype constraint_value = a.dot(solver->work->zlnew.col(i));
if (constraint_value > b) { // Only project if constraint is violated
solver->work->zlnew.col(i) =
project_hyperplane(solver->work->zlnew.col(i), a, b);
}
}
}
}
}
/**
* Update next iteration of dual variables by performing the augmented
* lagrangian multiplier update
*/
void update_dual(TinySolver* solver) {
// Update bound constraint dual variables for state
solver->work->g = solver->work->g + solver->work->x - solver->work->vnew;
// Update bound constraint dual variables for input
solver->work->y = solver->work->y + solver->work->u - solver->work->znew;
// Update second order cone dual variables for state
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->gc = solver->work->gc + solver->work->x - solver->work->vcnew;
}
// Update second order cone dual variables for input
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->yc = solver->work->yc + solver->work->u - solver->work->zcnew;
}
// Update linear constraint dual variables for state
if (solver->settings->en_state_linear) {
solver->work->gl = solver->work->gl + solver->work->x - solver->work->vlnew;
}
// Update linear constraint dual variables for input
if (solver->settings->en_input_linear) {
solver->work->yl = solver->work->yl + solver->work->u - solver->work->zlnew;
}
}
/**
* Update linear control cost terms in the Riccati feedback using the changing
* slack and dual variables from ADMM
*/
void update_linear_cost(TinySolver* solver) {
// Update state cost terms
solver->work->q = -(solver->work->Xref.array().colwise() * solver->work->Q.array());
(solver->work->q).noalias() -= solver->cache->rho * (solver->work->vnew - solver->work->g);
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
(solver->work->q).noalias() -=
solver->cache->rho * (solver->work->vcnew - solver->work->gc);
}
if (solver->settings->en_state_linear) {
(solver->work->q).noalias() -=
solver->cache->rho * (solver->work->vlnew - solver->work->gl);
}
// Update input cost terms
solver->work->r = -(solver->work->Uref.array().colwise() * solver->work->R.array());
(solver->work->r).noalias() -= solver->cache->rho * (solver->work->znew - solver->work->y);
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
(solver->work->r).noalias() -=
solver->cache->rho * (solver->work->zcnew - solver->work->yc);
}
if (solver->settings->en_input_linear) {
(solver->work->r).noalias() -=
solver->cache->rho * (solver->work->zlnew - solver->work->yl);
}
// Update terminal cost
solver->work->p.col(solver->work->N - 1) =
-(solver->work->Xref.col(solver->work->N - 1).transpose().lazyProduct(solver->cache->Pinf));
(solver->work->p.col(solver->work->N - 1)).noalias() -= solver->cache->rho
* (solver->work->vnew.col(solver->work->N - 1) - solver->work->g.col(solver->work->N - 1));
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
* (solver->work->vcnew.col(solver->work->N - 1)
- solver->work->gc.col(solver->work->N - 1));
}
if (solver->settings->en_state_linear) {
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
* (solver->work->vlnew.col(solver->work->N - 1)
- solver->work->gl.col(solver->work->N - 1));
}
}
/**
* Check for termination condition by evaluating whether the largest absolute
* primal and dual residuals for states and inputs are below threhold.
*/
bool termination_condition(TinySolver* solver) {
if (solver->work->iter % solver->settings->check_termination == 0) {
solver->work->primal_residual_state =
(solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
solver->work->dual_residual_state =
((solver->work->v - solver->work->vnew).cwiseAbs().maxCoeff()) * solver->cache->rho;
solver->work->primal_residual_input =
(solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
solver->work->dual_residual_input =
((solver->work->z - solver->work->znew).cwiseAbs().maxCoeff()) * solver->cache->rho;
if (solver->work->primal_residual_state < solver->settings->abs_pri_tol
&& solver->work->primal_residual_input < solver->settings->abs_pri_tol
&& solver->work->dual_residual_state < solver->settings->abs_dua_tol
&& solver->work->dual_residual_input < solver->settings->abs_dua_tol)
{
return true;
}
}
return false;
}
int solve(TinySolver* solver) {
// Initialize variables
solver->solution->solved = 0;
solver->solution->iter = 0;
solver->work->status = 11; // TINY_UNSOLVED
solver->work->iter = 0;
// Setup for adaptive rho
RhoAdapter adapter;
adapter.rho_min = solver->settings->adaptive_rho_min;
adapter.rho_max = solver->settings->adaptive_rho_max;
adapter.clip = solver->settings->adaptive_rho_enable_clipping;
RhoBenchmarkResult rho_result;
// Store previous values for residuals
tinyMatrix v_prev = solver->work->vnew;
tinyMatrix z_prev = solver->work->znew;
// Initialize SOC slack variables if needed
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
solver->work->vcnew = solver->work->x;
}
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
solver->work->zcnew = solver->work->u;
}
// Initialize linear constraint slack variables if needed
if (solver->settings->en_state_linear) {
solver->work->vlnew = solver->work->x;
}
if (solver->settings->en_input_linear) {
solver->work->zlnew = solver->work->u;
}
for (int i = 0; i < solver->settings->max_iter; i++) {
// Solve linear system with Riccati and roll out to get new trajectory
backward_pass_grad(solver);
forward_pass(solver);
// Project slack variables into feasible domain
update_slack(solver);
// Compute next iteration of dual variables
update_dual(solver);
// Update linear control cost terms using reference trajectory, duals, and slack variables
update_linear_cost(solver);
solver->work->iter += 1;
// Handle adaptive rho if enabled
if (solver->settings->adaptive_rho) {
// Calculate residuals for adaptive rho
tinytype pri_res_input = (solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
tinytype pri_res_state = (solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
tinytype dua_res_input =
solver->cache->rho * (solver->work->znew - z_prev).cwiseAbs().maxCoeff();
tinytype dua_res_state =
solver->cache->rho * (solver->work->vnew - v_prev).cwiseAbs().maxCoeff();
// Update rho every 5 iterations
if (i > 0 && i % 5 == 0) {
benchmark_rho_adaptation(
&adapter,
solver->work->x,
solver->work->u,
solver->work->vnew,
solver->work->znew,
solver->work->g,
solver->work->y,
solver->cache,
solver->work,
solver->work->N,
&rho_result
);
// Update matrices using Taylor expansion
update_matrices_with_derivatives(solver->cache, rho_result.final_rho);
}
}
// Store previous values for next iteration
z_prev = solver->work->znew;
v_prev = solver->work->vnew;
// Check for whether cost is minimized by calculating residuals
if (termination_condition(solver)) {
solver->work->status = 1; // TINY_SOLVED
// Save solution
solver->solution->iter = solver->work->iter;
solver->solution->solved = 1;
solver->solution->x = solver->work->vnew;
solver->solution->u = solver->work->znew;
// std::cout << "Solver converged in " << solver->work->iter << " iterations" << std::endl;
return 0;
}
// Save previous slack variables
solver->work->v = solver->work->vnew;
solver->work->z = solver->work->znew;
}
solver->solution->iter = solver->work->iter;
solver->solution->solved = 0;
solver->solution->x = solver->work->vnew;
solver->solution->u = solver->work->znew;
return 1;
}
} /* extern "C" */

View File

@@ -0,0 +1,37 @@
#pragma once
#include "types.hpp"
#ifdef __cplusplus
extern "C" {
#endif
int solve(TinySolver* solver);
void update_primal(TinySolver* solver);
void backward_pass_grad(TinySolver* solver);
void forward_pass(TinySolver* solver);
void update_slack(TinySolver* solver);
void update_dual(TinySolver* solver);
void update_linear_cost(TinySolver* solver);
bool termination_condition(TinySolver* solver);
/**
* Project a vector s onto the second order cone defined by mu
* @param s, mu
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
*/
tinyVector project_soc(tinyVector s, float mu);
/**
* Project a vector z onto a hyperplane defined by a^T z = b
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ b)/||a||² * a
* @param z Vector to project
* @param a Normal vector of the hyperplane
* @param b Offset of the hyperplane
* @return Projection of z onto the hyperplane
*/
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,466 @@
#include <ctype.h>
#include <dirent.h>
#include <stdio.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>
//#include <error.h>
#include "error.hpp"
#include <Eigen/Dense>
#include <iostream>
// #include "types.hpp"
#include "codegen.hpp"
#ifdef __MINGW32__
#include <direct.h>
inline int mkdir(const char* pathname, int flags) {
return _mkdir(pathname);
}
#endif
#ifdef __cplusplus
extern "C" {
#endif
/* Define the maximum allowed length of the path (directory + filename + extension) */
#define PATH_LENGTH 2048
using namespace Eigen;
static void print_matrix(FILE* f, MatrixXd mat, int num_elements) {
// Check if matrix is uninitialized or too small
if (mat.size() == 0 || mat.size() < num_elements) {
// Print zeros for all elements
for (int i = 0; i < num_elements; i++) {
fprintf(f, "(tinytype)0.0000000000000000");
if (i < num_elements - 1)
fprintf(f, ",");
}
return;
}
// Matrix is properly initialized and has enough elements
for (int i = 0; i < num_elements; i++) {
fprintf(f, "(tinytype)%.16f", mat.reshaped<RowMajor>()[i]);
if (i < num_elements - 1)
fprintf(f, ",");
}
}
static void create_directory(const char* dir, int verbose) {
// Attempt to create directory
if (mkdir(dir, S_IRWXU | S_IRWXG | S_IROTH)) {
if (errno == EEXIST) { // Skip if directory already exists
if (verbose)
std::cout << dir << " already exists, skipping." << std::endl;
} else {
ERROR_MSG(EXIT_FAILURE, "Failed to create directory %s", dir);
}
}
}
// TODO: Make this fail if tiny_setup has not already been called
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose) {
if (!solver) {
std::cout << "Error in tiny_codegen: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= codegen_create_directories(output_dir, verbose);
status |= codegen_data_header(output_dir, verbose);
status |= codegen_data_source(solver, output_dir, verbose);
status |= codegen_example(output_dir, verbose);
return status;
}
int tiny_codegen_with_sensitivity(
TinySolver* solver,
const char* output_dir,
tinyMatrix* dK,
tinyMatrix* dP,
tinyMatrix* dC1,
tinyMatrix* dC2,
int verbose
) {
if (!solver) {
std::cout << "Error in tiny_codegen_with_sensitivity: solver is nullptr" << std::endl;
return 1;
}
// Only store sensitivity matrices if adaptive rho is enabled
if (solver->settings->adaptive_rho) {
// Store the sensitivity matrices in the solver's cache
solver->cache->dKinf_drho = *dK;
solver->cache->dPinf_drho = *dP;
solver->cache->dC1_drho = *dC1;
solver->cache->dC2_drho = *dC2;
}
// Call the regular codegen function which will now include the sensitivity matrices if adaptive_rho is enabled
return tiny_codegen(solver, output_dir, verbose);
}
// Create code generation folder structure in whichever directory the executable calling tiny_codegen was called
int codegen_create_directories(const char* output_dir, int verbose) {
// Create output folder (root folder for code generation)
create_directory(output_dir, verbose);
// Create src folder
char src_dir[PATH_LENGTH];
sprintf(src_dir, "%s/src/", output_dir);
create_directory(src_dir, verbose);
// Create tinympc folder
char tinympc_dir[PATH_LENGTH];
sprintf(tinympc_dir, "%s/tinympc/", output_dir);
create_directory(tinympc_dir, verbose);
// // Create include folder
// char inc_dir[PATH_LENGTH];
// sprintf(inc_dir, "%s/include/", output_dir);
// create_directory(inc_dir, verbose);
return EXIT_SUCCESS;
}
// Create inc/tiny_data.hpp file
int codegen_data_header(const char* output_dir, int verbose) {
char data_hpp_fname[PATH_LENGTH];
FILE* data_hpp_f;
sprintf(data_hpp_fname, "%s/tinympc/tiny_data.hpp", output_dir);
// Open data header file
data_hpp_f = fopen(data_hpp_fname, "w+");
if (data_hpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_hpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(data_hpp_f, "/*\n");
fprintf(data_hpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(data_hpp_f, " */\n\n");
fprintf(data_hpp_f, "#pragma once\n\n");
fprintf(data_hpp_f, "#include \"types.hpp\"\n\n");
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
fprintf(data_hpp_f, "extern \"C\" {\n");
fprintf(data_hpp_f, "#endif\n\n");
fprintf(data_hpp_f, "extern TinySolver tiny_solver;\n\n");
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
fprintf(data_hpp_f, "}\n");
fprintf(data_hpp_f, "#endif\n");
// Close codegen data header file
fclose(data_hpp_f);
if (verbose) {
printf("Data header generated in %s\n", data_hpp_fname);
}
return 0;
}
// Create src/tiny_data.cpp file
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose) {
char data_cpp_fname[PATH_LENGTH];
FILE* data_cpp_f;
int nx = solver->work->nx;
int nu = solver->work->nu;
int N = solver->work->N;
sprintf(data_cpp_fname, "%s/src/tiny_data.cpp", output_dir);
// Open data source file
data_cpp_f = fopen(data_cpp_fname, "w+");
if (data_cpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_cpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(data_cpp_f, "/*\n");
fprintf(data_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(data_cpp_f, " */\n\n");
// Open extern C
fprintf(data_cpp_f, "#include \"tinympc/tiny_data.hpp\"\n\n");
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
fprintf(data_cpp_f, "extern \"C\" {\n");
fprintf(data_cpp_f, "#endif\n\n");
// Solution
fprintf(data_cpp_f, "/* Solution */\n");
fprintf(data_cpp_f, "TinySolution solution = {\n");
fprintf(data_cpp_f, "\t%d,\t\t// iter\n", solver->solution->iter);
fprintf(data_cpp_f, "\t%d,\t\t// solved\n", solver->solution->solved);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x solution
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// x\n"); // u solution
fprintf(data_cpp_f, "};\n\n");
// Cache
fprintf(data_cpp_f, "/* Matrices that must be recomputed with changes in time step, rho */\n");
fprintf(data_cpp_f, "TinyCache cache = {\n");
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// rho (step size/penalty)\n", solver->cache->rho);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
print_matrix(data_cpp_f, solver->cache->Kinf, nu * nx);
fprintf(data_cpp_f, ").finished(),\t// Kinf\n"); // Kinf
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->Pinf, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// Pinf\n"); // Pinf
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nu);
print_matrix(data_cpp_f, solver->cache->Quu_inv, nu * nu);
fprintf(data_cpp_f, ").finished(),\t// Quu_inv\n"); // Quu_inv
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->AmBKt, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// AmBKt\n"); // AmBKt
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->C1, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// C1\n"); // C1
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->C2, nx * nx);
fprintf(data_cpp_f, ").finished()"); // C2, no comma if no sensitivity matrices
// Only print sensitivity matrices if adaptive rho is enabled
if (solver->settings->adaptive_rho) {
fprintf(data_cpp_f, ",\t// C2\n"); // Add comma and comment for C2 if we have more matrices
// Add sensitivity matrices within the struct initialization
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
print_matrix(data_cpp_f, solver->cache->dKinf_drho, nu * nx);
fprintf(data_cpp_f, ").finished(),\t// dKinf_drho\n"); // dKinf_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dPinf_drho, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// dPinf_drho\n"); // dPinf_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dC1_drho, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// dC1_drho\n"); // dC1_drho
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->cache->dC2_drho, nx * nx);
fprintf(data_cpp_f, ").finished()\t// dC2_drho\n"); // dC2_drho
} else {
fprintf(data_cpp_f, "\t// C2\n"); // Just add comment for C2
}
fprintf(data_cpp_f, "};\n\n");
// Settings
fprintf(data_cpp_f, "/* User settings */\n");
fprintf(data_cpp_f, "TinySettings settings = {\n");
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// primal tolerance\n", solver->settings->abs_pri_tol);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// dual tolerance\n", solver->settings->abs_dua_tol);
fprintf(data_cpp_f, "\t%d,\t\t// max iterations\n", solver->settings->max_iter);
fprintf(
data_cpp_f,
"\t%d,\t\t// iterations per termination check\n",
solver->settings->check_termination
);
fprintf(data_cpp_f, "\t%d,\t\t// enable state constraints\n", solver->settings->en_state_bound);
fprintf(data_cpp_f, "\t%d\t\t// enable input constraints\n", solver->settings->en_input_bound);
fprintf(data_cpp_f, "};\n\n");
// Workspace
fprintf(data_cpp_f, "/* Problem variables */\n");
fprintf(data_cpp_f, "TinyWorkspace work = {\n");
fprintf(data_cpp_f, "\t%d,\t// Number of states\n", nx);
fprintf(data_cpp_f, "\t%d,\t// Number of control inputs\n", nu);
fprintf(data_cpp_f, "\t%d,\t// Number of knotpoints in the horizon\n", N);
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u\n"); // u
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// q\n"); // q
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// r\n"); // r
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// p\n"); // p
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// d\n"); // d
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// v\n"); // v
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// vnew\n"); // vnew
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// z\n"); // z
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// znew\n"); // znew
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// g\n"); // g
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// y\n"); // y
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nx);
print_matrix(data_cpp_f, solver->work->Q, nx);
fprintf(data_cpp_f, ").finished(),\t// Q\n"); // Q
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
print_matrix(data_cpp_f, solver->work->R, nu);
fprintf(data_cpp_f, ").finished(),\t// R\n"); // R
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
print_matrix(data_cpp_f, solver->work->Adyn, nx * nx);
fprintf(data_cpp_f, ").finished(),\t// Adyn\n"); // Adyn
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nu);
print_matrix(data_cpp_f, solver->work->Bdyn, nx * nu);
fprintf(data_cpp_f, ").finished(),\t// Bdyn\n"); // Bdyn
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, solver->work->x_min, nx * N);
fprintf(data_cpp_f, ").finished(),\t// x_min\n"); // x_min
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, solver->work->x_max, nx * N);
fprintf(data_cpp_f, ").finished(),\t// x_max\n"); // x_max
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, solver->work->u_min, nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u_min\n"); // u_min
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, solver->work->u_max, nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// u_max\n"); // u_max
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
fprintf(data_cpp_f, ").finished(),\t// Xref\n"); // Xref
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
fprintf(data_cpp_f, ").finished(),\t// Uref\n"); // Uref
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
print_matrix(data_cpp_f, MatrixXd::Zero(nu, 1), nu);
fprintf(data_cpp_f, ").finished(),\t// Qu\n"); // Qu
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state primal residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input primal residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state dual residual\n", 0.0);
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input dual residual\n", 0.0);
fprintf(data_cpp_f, "\t%d,\t// solve status\n", 0);
fprintf(data_cpp_f, "\t%d,\t// solve iteration\n", 0);
fprintf(data_cpp_f, "};\n\n");
// Write solver struct definition to workspace file
fprintf(data_cpp_f, "TinySolver tiny_solver = {&solution, &settings, &cache, &work};\n\n");
// Close extern C
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
fprintf(data_cpp_f, "}\n");
fprintf(data_cpp_f, "#endif\n\n");
// Close codegen data file
fclose(data_cpp_f);
if (verbose) {
printf("Data generated in %s\n", data_cpp_fname);
}
return 0;
}
int codegen_example(const char* output_dir, int verbose) {
char example_cpp_fname[PATH_LENGTH];
FILE* example_cpp_f;
sprintf(example_cpp_fname, "%s/src/tiny_main.cpp", output_dir);
// Open example file
example_cpp_f = fopen(example_cpp_fname, "w+");
if (example_cpp_f == NULL)
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", example_cpp_fname);
// Preamble
time_t start_time;
time(&start_time);
fprintf(example_cpp_f, "/*\n");
fprintf(example_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
fprintf(example_cpp_f, " */\n\n");
fprintf(example_cpp_f, "#include <iostream>\n\n");
fprintf(example_cpp_f, "#include <tinympc/tiny_api.hpp>\n");
fprintf(example_cpp_f, "#include <tinympc/tiny_data.hpp>\n\n");
fprintf(example_cpp_f, "using namespace Eigen;\n");
fprintf(example_cpp_f, "IOFormat TinyFmt(4, 0, \", \", \"\\n\", \"[\", \"]\");\n\n");
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
fprintf(example_cpp_f, "extern \"C\" {\n");
fprintf(example_cpp_f, "#endif\n\n");
fprintf(example_cpp_f, "int main()\n");
fprintf(example_cpp_f, "{\n");
fprintf(example_cpp_f, "\tint exitflag = 1;\n");
fprintf(example_cpp_f, "\t// Double check some data\n");
fprintf(example_cpp_f, "\tstd::cout << \"rho: \" << tiny_solver.cache->rho << std::endl;\n");
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nmax iters: \" << tiny_solver.settings->max_iter << std::endl;\n"
);
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nState transition matrix:\\n\" << tiny_solver.work->Adyn.format(TinyFmt) << std::endl;\n"
);
fprintf(
example_cpp_f,
"\tstd::cout << \"\\nInput/control matrix:\\n\" << tiny_solver.work->Bdyn.format(TinyFmt) << std::endl;\n\n"
);
fprintf(
example_cpp_f,
"\t// Visit https://tinympc.org/ to see how to set the initial condition and update the reference trajectory.\n\n"
);
fprintf(example_cpp_f, "\tstd::cout << \"\\nSolving...\\n\" << std::endl;\n\n");
fprintf(example_cpp_f, "\texitflag = tiny_solve(&tiny_solver);\n\n");
fprintf(example_cpp_f, "\tif (exitflag == 0) printf(\"Hooray! Solved with no error!\\n\");\n");
fprintf(example_cpp_f, "\telse printf(\"Oops! Something went wrong!\\n\");\n");
fprintf(example_cpp_f, "\treturn 0;\n");
fprintf(example_cpp_f, "}\n\n");
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
fprintf(example_cpp_f, "} /* extern \"C\" */\n");
fprintf(example_cpp_f, "#endif\n");
// Close codegen example main file
fclose(example_cpp_f);
if (verbose) {
printf("Example tinympc main generated in %s\n", example_cpp_fname);
}
return 0;
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,28 @@
#pragma once
#include "types.hpp"
#ifdef __cplusplus
extern "C" {
#endif
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose);
int tiny_codegen_with_sensitivity(
TinySolver* solver,
const char* output_dir,
tinyMatrix* dK,
tinyMatrix* dP,
tinyMatrix* dC1,
tinyMatrix* dC2,
int verbose
);
int codegen_create_directories(const char* output_dir, int verbose);
int codegen_data_header(const char* output_dir, int verbose);
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose);
int codegen_example(const char* output_dir, int verbose);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,29 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// #if defined(__linux__) || defined(__unix__)// Check if Linux
// #include <error.h>
// #define ERROR_MSG(exit_code, format, ...) error(exit_code, errno, format, ##__VA_ARGS__)
// #elif defined(__APPLE__) || defined(__MACH__) // Check if macOS
#define ERROR_MSG(exit_code, format, ...) \
{ \
fprintf(stderr, format ": %s\n", ##__VA_ARGS__, strerror(errno)); \
exit(exit_code); \
}
// #else
// #error "Unsupported operating system"
// #endif
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,252 @@
#include "rho_benchmark.hpp"
#include <algorithm>
#include <cmath>
#include <iostream>
#ifdef ARDUINO
#include <Arduino.h>
#else
// For non-Arduino platforms
uint32_t micros() {
return 0; // Replace with appropriate timing function
}
#endif
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N) {
// Calculate dimensions
int x_decision_size = nx * N + nu * (N - 1);
int constraint_rows = (nx + nu) * (N - 1);
// Pre-allocate matrices
adapter->A_matrix = tinyMatrix::Zero(constraint_rows, x_decision_size);
adapter->z_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->y_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->x_decision = tinyMatrix::Zero(x_decision_size, 1);
// Pre-compute P matrix structure
adapter->P_matrix = tinyMatrix::Zero(x_decision_size, x_decision_size);
adapter->q_vector = tinyMatrix::Zero(x_decision_size, 1);
// Pre-allocate residual computation matrices
adapter->Ax_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->r_prim_vector = tinyMatrix::Zero(constraint_rows, 1);
adapter->r_dual_vector = tinyMatrix::Zero(x_decision_size, 1);
adapter->Px_vector = tinyMatrix::Zero(x_decision_size, 1);
adapter->ATy_vector = tinyMatrix::Zero(x_decision_size, 1);
// Store dimensions
adapter->format_nx = nx;
adapter->format_nu = nu;
adapter->format_N = N;
adapter->matrices_initialized = true;
}
void format_matrices(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N
) {
if (!adapter->matrices_initialized) {
initialize_format_matrices(adapter, x_prev.rows(), u_prev.rows(), N);
}
int nx = adapter->format_nx;
int nu = adapter->format_nu;
// Fill x_decision
int x_idx = 0;
for (int i = 0; i < N; i++) {
adapter->x_decision.block(x_idx, 0, nx, 1) = x_prev.col(i);
x_idx += nx;
if (i < N - 1) {
adapter->x_decision.block(x_idx, 0, nu, 1) = u_prev.col(i);
x_idx += nu;
}
}
// Clear A matrix for reuse
adapter->A_matrix.setZero();
// Fill A matrix with dynamics and input constraints
for (int i = 0; i < N - 1; i++) {
// Input constraints
int row_start = i * nu;
int col_start = i * (nx + nu) + nx;
adapter->A_matrix.block(row_start, col_start, nu, nu) = tinyMatrix::Identity(nu, nu);
// Dynamics constraints
row_start = (N - 1) * nu + i * nx;
col_start = i * (nx + nu);
adapter->A_matrix.block(row_start, col_start, nx, nx) = work->Adyn;
adapter->A_matrix.block(row_start, col_start + nx, nx, nu) = work->Bdyn;
int next_state_idx = col_start + nx + nu;
if (next_state_idx < adapter->A_matrix.cols()) {
adapter->A_matrix.block(row_start, next_state_idx, nx, nx) =
-tinyMatrix::Identity(nx, nx);
}
}
// Fill z and y vectors
for (int i = 0; i < N - 1; i++) {
adapter->z_vector.block(i * nu, 0, nu, 1) = z_prev.col(i);
adapter->z_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = v_prev.col(i + 1);
adapter->y_vector.block(i * nu, 0, nu, 1) = y_prev.col(i);
adapter->y_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = g_prev.col(i + 1);
}
// Build P matrix (cost matrix)
adapter->P_matrix.setZero();
// Fill diagonal blocks
x_idx = 0;
for (int i = 0; i < N; i++) {
// State cost
if (i == N - 1) {
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = cache->Pinf;
} else {
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = work->Q.asDiagonal();
}
x_idx += nx;
// Input cost
if (i < N - 1) {
adapter->P_matrix.block(x_idx, x_idx, nu, nu) = work->R.asDiagonal();
x_idx += nu;
}
}
// Create q vector (linear cost vector)
x_idx = 0;
for (int i = 0; i < N; i++) {
// For simplicity, we'll use zero reference for now
// In a real implementation, you'd use your reference trajectory
tinyMatrix x_ref = tinyMatrix::Zero(nx, 1);
tinyMatrix delta_x = x_prev.col(i) - x_ref;
adapter->q_vector.block(x_idx, 0, nx, 1) = work->Q.asDiagonal() * delta_x;
x_idx += nx;
if (i < N - 1) {
// For simplicity, we'll use zero reference for now
tinyMatrix u_ref = tinyMatrix::Zero(nu, 1);
tinyMatrix delta_u = u_prev.col(i) - u_ref;
adapter->q_vector.block(x_idx, 0, nu, 1) = work->R.asDiagonal() * delta_u;
x_idx += nu;
}
}
}
void compute_residuals(
RhoAdapter* adapter,
tinytype* pri_res,
tinytype* dual_res,
tinytype* pri_norm,
tinytype* dual_norm
) {
// Compute Ax
adapter->Ax_vector = adapter->A_matrix * adapter->x_decision;
// Compute primal residual
adapter->r_prim_vector = adapter->Ax_vector - adapter->z_vector;
*pri_res = adapter->r_prim_vector.cwiseAbs().maxCoeff();
*pri_norm =
std::max(adapter->Ax_vector.cwiseAbs().maxCoeff(), adapter->z_vector.cwiseAbs().maxCoeff());
// Compute dual residual components
adapter->Px_vector = adapter->P_matrix * adapter->x_decision;
adapter->ATy_vector = adapter->A_matrix.transpose() * adapter->y_vector;
// Compute full dual residual
adapter->r_dual_vector = adapter->Px_vector + adapter->q_vector + adapter->ATy_vector;
*dual_res = adapter->r_dual_vector.cwiseAbs().maxCoeff();
// Compute normalization
*dual_norm = std::max(
std::max(
adapter->Px_vector.cwiseAbs().maxCoeff(),
adapter->ATy_vector.cwiseAbs().maxCoeff()
),
adapter->q_vector.cwiseAbs().maxCoeff()
);
}
tinytype predict_rho(
RhoAdapter* adapter,
tinytype pri_res,
tinytype dual_res,
tinytype pri_norm,
tinytype dual_norm,
tinytype current_rho
) {
const tinytype eps = 1e-10;
tinytype normalized_pri = pri_res / (pri_norm + eps);
tinytype normalized_dual = dual_res / (dual_norm + eps);
tinytype ratio = normalized_pri / (normalized_dual + eps);
tinytype new_rho = current_rho * std::sqrt(ratio);
if (adapter->clip) {
new_rho = std::min(std::max(new_rho, adapter->rho_min), adapter->rho_max);
}
return new_rho;
}
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho) {
tinytype delta_rho = new_rho - cache->rho;
cache->Kinf = cache->Kinf + delta_rho * cache->dKinf_drho;
cache->Pinf = cache->Pinf + delta_rho * cache->dPinf_drho;
cache->C1 = cache->C1 + delta_rho * cache->dC1_drho;
cache->C2 = cache->C2 + delta_rho * cache->dC2_drho;
cache->rho = new_rho;
}
void benchmark_rho_adaptation(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N,
RhoBenchmarkResult* result
) {
uint32_t start_time = micros();
// Format matrices
format_matrices(adapter, x_prev, u_prev, v_prev, z_prev, g_prev, y_prev, cache, work, N);
// Compute residuals
tinytype pri_res, dual_res, pri_norm, dual_norm;
compute_residuals(adapter, &pri_res, &dual_res, &pri_norm, &dual_norm);
// Predict new rho
tinytype new_rho = predict_rho(adapter, pri_res, dual_res, pri_norm, dual_norm, cache->rho);
// Update matrices
update_matrices_with_derivatives(cache, new_rho);
// Store results
result->time_us = micros() - start_time;
result->initial_rho = cache->rho;
result->final_rho = new_rho;
result->pri_res = pri_res;
result->dual_res = dual_res;
result->pri_norm = pri_norm;
result->dual_norm = dual_norm;
}

View File

@@ -0,0 +1,94 @@
#pragma once
#include "types.hpp"
#include <cstdint>
struct RhoAdapter {
tinytype rho_min;
tinytype rho_max;
bool clip;
bool matrices_initialized;
// Pre-allocated matrices for formatting
tinyMatrix A_matrix;
tinyMatrix z_vector;
tinyMatrix y_vector;
tinyMatrix x_decision;
tinyMatrix P_matrix;
tinyMatrix q_vector;
// Pre-allocated matrices for residual computation
tinyMatrix Ax_vector;
tinyMatrix r_prim_vector;
tinyMatrix r_dual_vector;
tinyMatrix Px_vector;
tinyMatrix ATy_vector;
// Dimensions
int format_nx;
int format_nu;
int format_N;
};
struct RhoBenchmarkResult {
uint32_t time_us;
tinytype initial_rho;
tinytype final_rho;
tinytype pri_res;
tinytype dual_res;
tinytype pri_norm;
tinytype dual_norm;
};
// Initialize matrices for formatting
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N);
// Format matrices for residual computation
void format_matrices(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N
);
// Compute residuals
void compute_residuals(
RhoAdapter* adapter,
tinytype* pri_res,
tinytype* dual_res,
tinytype* pri_norm,
tinytype* dual_norm
);
// Predict new rho value
tinytype predict_rho(
RhoAdapter* adapter,
tinytype pri_res,
tinytype dual_res,
tinytype pri_norm,
tinytype dual_norm,
tinytype current_rho
);
// Update matrices using derivatives
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho);
// Main benchmark function
void benchmark_rho_adaptation(
RhoAdapter* adapter,
const tinyMatrix& x_prev,
const tinyMatrix& u_prev,
const tinyMatrix& v_prev,
const tinyMatrix& z_prev,
const tinyMatrix& g_prev,
const tinyMatrix& y_prev,
TinyCache* cache,
TinyWorkspace* work,
int N,
RhoBenchmarkResult* result
);

View File

@@ -0,0 +1,876 @@
#include "tiny_api.hpp"
#include "tiny_api_constants.hpp"
#include <iostream>
#ifdef __cplusplus
extern "C" {
#endif
using namespace Eigen;
IOFormat TinyApiFmt(4, 0, ", ", "\n", "[", "]");
static int
check_dimension(std::string matrix_name, std::string rows_or_columns, int actual, int expected) {
if (actual != expected) {
std::cout << matrix_name << " has " << actual << " " << rows_or_columns << ". Expected "
<< expected << "." << std::endl;
return 1;
}
return 0;
}
int tiny_setup(
TinySolver** solverp,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
tinytype rho,
int nx,
int nu,
int N,
int verbose
) {
TinySolution* solution = new TinySolution();
TinyCache* cache = new TinyCache();
TinySettings* settings = new TinySettings();
TinyWorkspace* work = new TinyWorkspace();
TinySolver* solver = new TinySolver();
solver->solution = solution;
solver->cache = cache;
solver->settings = settings;
solver->work = work;
*solverp = solver;
// Initialize solution
solution->iter = 0;
solution->solved = 0;
solution->x = tinyMatrix::Zero(nx, N);
solution->u = tinyMatrix::Zero(nu, N - 1);
// Initialize settings
tiny_set_default_settings(settings);
// Initialize workspace
work->nx = nx;
work->nu = nu;
work->N = N;
// Make sure arguments are the correct shapes
int status = 0;
status |= check_dimension("State transition matrix (A)", "rows", Adyn.rows(), nx);
status |= check_dimension("State transition matrix (A)", "columns", Adyn.cols(), nx);
status |= check_dimension("Input matrix (B)", "rows", Bdyn.rows(), nx);
status |= check_dimension("Input matrix (B)", "columns", Bdyn.cols(), nu);
status |= check_dimension("Affine vector (f)", "rows", fdyn.rows(), nx);
status |= check_dimension("Affine vector (f)", "columns", fdyn.cols(), 1);
status |= check_dimension("State stage cost (Q)", "rows", Q.rows(), nx);
status |= check_dimension("State stage cost (Q)", "columns", Q.cols(), nx);
status |= check_dimension("State input cost (R)", "rows", R.rows(), nu);
status |= check_dimension("State input cost (R)", "columns", R.cols(), nu);
if (status) {
return status;
}
work->x = tinyMatrix::Zero(nx, N);
work->u = tinyMatrix::Zero(nu, N - 1);
work->q = tinyMatrix::Zero(nx, N);
work->r = tinyMatrix::Zero(nu, N - 1);
work->p = tinyMatrix::Zero(nx, N);
work->d = tinyMatrix::Zero(nu, N - 1);
// Bound constraint slack variables
work->v = tinyMatrix::Zero(nx, N);
work->vnew = tinyMatrix::Zero(nx, N);
work->z = tinyMatrix::Zero(nu, N - 1);
work->znew = tinyMatrix::Zero(nu, N - 1);
// Bound constraint dual variables
work->g = tinyMatrix::Zero(nx, N);
work->y = tinyMatrix::Zero(nu, N - 1);
// Cone constraint slack variables
work->vc = tinyMatrix::Zero(nx, N);
work->vcnew = tinyMatrix::Zero(nx, N);
work->zc = tinyMatrix::Zero(nu, N - 1);
work->zcnew = tinyMatrix::Zero(nu, N - 1);
// Cone constraint dual variables
work->gc = tinyMatrix::Zero(nx, N);
work->yc = tinyMatrix::Zero(nu, N - 1);
// Linear constraint slack variables
work->vl = tinyMatrix::Zero(nx, N);
work->vlnew = tinyMatrix::Zero(nx, N);
work->zl = tinyMatrix::Zero(nu, N - 1);
work->zlnew = tinyMatrix::Zero(nu, N - 1);
// Linear constraint dual variables
work->gl = tinyMatrix::Zero(nx, N);
work->yl = tinyMatrix::Zero(nu, N - 1);
work->Q = (Q + rho * tinyMatrix::Identity(nx, nx)).diagonal();
work->R = (R + rho * tinyMatrix::Identity(nu, nu)).diagonal();
work->Adyn = Adyn; // State transition matrix
work->Bdyn = Bdyn; // Input matrix
work->fdyn = fdyn; // Affine offset vector
work->Xref = tinyMatrix::Zero(nx, N);
work->Uref = tinyMatrix::Zero(nu, N - 1);
work->Qu = tinyVector::Zero(nu);
work->primal_residual_state = 0;
work->primal_residual_input = 0;
work->dual_residual_state = 0;
work->dual_residual_input = 0;
work->status = 0;
work->iter = 0;
// Initialize cache
status = tiny_precompute_and_set_cache(
cache,
Adyn,
Bdyn,
fdyn,
work->Q.asDiagonal(),
work->R.asDiagonal(),
nx,
nu,
rho,
verbose
);
if (status) {
return status;
}
// Initialize sensitivity matrices for adaptive rho
if (solver->settings->adaptive_rho) {
tiny_initialize_sensitivity_matrices(solver);
}
return 0;
}
int tiny_set_bound_constraints(
TinySolver* solver,
tinyMatrix x_min,
tinyMatrix x_max,
tinyMatrix u_min,
tinyMatrix u_max
) {
if (!solver) {
std::cout << "Error in tiny_set_bound_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all bound constraint matrix sizes are self-consistent
int status = 0;
status |= check_dimension("Lower state bounds (x_min)", "rows", x_min.rows(), solver->work->nx);
status |= check_dimension("Lower state bounds (x_min)", "cols", x_min.cols(), solver->work->N);
status |= check_dimension("Lower state bounds (x_max)", "rows", x_max.rows(), solver->work->nx);
status |= check_dimension("Lower state bounds (x_max)", "cols", x_max.cols(), solver->work->N);
status |= check_dimension("Lower input bounds (u_min)", "rows", u_min.rows(), solver->work->nu);
status |=
check_dimension("Lower input bounds (u_min)", "cols", u_min.cols(), solver->work->N - 1);
status |= check_dimension("Lower input bounds (u_max)", "rows", u_max.rows(), solver->work->nu);
status |=
check_dimension("Lower input bounds (u_max)", "cols", u_max.cols(), solver->work->N - 1);
solver->work->x_min = x_min;
solver->work->x_max = x_max;
solver->work->u_min = u_min;
solver->work->u_max = u_max;
return 0;
}
int tiny_set_cone_constraints(
TinySolver* solver,
VectorXi Acx,
VectorXi qcx,
tinyVector cx,
VectorXi Acu,
VectorXi qcu,
tinyVector cu
) {
if (!solver) {
std::cout << "Error in tiny_set_cone_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all cone constraint vector sizes are self-consistent
int num_state_cones = Acx.rows();
int num_input_cones = Acu.rows();
int status = 0;
status |= check_dimension("Cone state size (qcx)", "rows", qcx.rows(), num_state_cones);
status |= check_dimension("Cone mu value for state (cx)", "rows", cx.rows(), num_state_cones);
status |= check_dimension("Cone input size (qcu)", "rows", qcu.rows(), num_input_cones);
status |= check_dimension("Cone mu value for input (cu)", "rows", cu.rows(), num_input_cones);
if (status) {
return status;
}
solver->work->numStateCones = num_state_cones;
solver->work->numInputCones = num_input_cones;
solver->work->Acx = Acx;
solver->work->qcx = qcx;
solver->work->cx = cx;
solver->work->Acu = Acu;
solver->work->qcu = qcu;
solver->work->cu = cu;
return 0;
}
int tiny_set_linear_constraints(
TinySolver* solver,
tinyMatrix Alin_x,
tinyVector blin_x,
tinyMatrix Alin_u,
tinyVector blin_u
) {
if (!solver) {
std::cout << "Error in tiny_set_linear_constraints: solver is nullptr" << std::endl;
return 1;
}
// Make sure all linear constraint matrix sizes are self-consistent
int num_state_linear = Alin_x.rows();
int num_input_linear = Alin_u.rows();
int status = 0;
// Check state constraint dimensions
if (num_state_linear > 0) {
status |= check_dimension(
"State linear constraint matrix (Alin_x)",
"rows",
Alin_x.rows(),
num_state_linear
);
status |= check_dimension(
"State linear constraint matrix (Alin_x)",
"columns",
Alin_x.cols(),
solver->work->nx
);
status |= check_dimension(
"State linear constraint vector (blin_x)",
"rows",
blin_x.rows(),
num_state_linear
);
status |=
check_dimension("State linear constraint vector (blin_x)", "columns", blin_x.cols(), 1);
}
// Check input constraint dimensions
if (num_input_linear > 0) {
status |= check_dimension(
"Input linear constraint matrix (Alin_u)",
"rows",
Alin_u.rows(),
num_input_linear
);
status |= check_dimension(
"Input linear constraint matrix (Alin_u)",
"columns",
Alin_u.cols(),
solver->work->nu
);
status |= check_dimension(
"Input linear constraint vector (blin_u)",
"rows",
blin_u.rows(),
num_input_linear
);
status |=
check_dimension("Input linear constraint vector (blin_u)", "columns", blin_u.cols(), 1);
}
if (status) {
return status;
}
solver->work->numStateLinear = num_state_linear;
solver->work->numInputLinear = num_input_linear;
solver->work->Alin_x = Alin_x;
solver->work->blin_x = blin_x;
solver->work->Alin_u = Alin_u;
solver->work->blin_u = blin_u;
return 0;
}
int tiny_precompute_and_set_cache(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
) {
if (!cache) {
std::cout << "Error in tiny_precompute_and_set_cache: cache is nullptr" << std::endl;
return 1;
}
// Update by adding rho * identity matrix to Q, R
tinyMatrix Q1 = Q + rho * tinyMatrix::Identity(nx, nx);
tinyMatrix R1 = R + rho * tinyMatrix::Identity(nu, nu);
// Printing
if (verbose) {
std::cout << "A = " << Adyn.format(TinyApiFmt) << std::endl;
std::cout << "B = " << Bdyn.format(TinyApiFmt) << std::endl;
std::cout << "Q = " << Q1.format(TinyApiFmt) << std::endl;
std::cout << "R = " << R1.format(TinyApiFmt) << std::endl;
std::cout << "rho = " << rho << std::endl;
}
// Riccati recursion to get Kinf, Pinf
tinyMatrix Ktp1 = tinyMatrix::Zero(nu, nx);
tinyMatrix Ptp1 = rho * tinyMatrix::Ones(nx, 1).array().matrix().asDiagonal();
tinyMatrix Kinf = tinyMatrix::Zero(nu, nx);
tinyMatrix Pinf = tinyMatrix::Zero(nx, nx);
for (int i = 0; i < 1000; i++) {
Kinf = (R1 + Bdyn.transpose() * Ptp1 * Bdyn).inverse() * Bdyn.transpose() * Ptp1 * Adyn;
Pinf = Q1 + Adyn.transpose() * Ptp1 * (Adyn - Bdyn * Kinf);
// if Kinf converges, break
if ((Kinf - Ktp1).cwiseAbs().maxCoeff() < 1e-5) {
if (verbose) {
std::cout << "Kinf converged after " << i + 1 << " iterations" << std::endl;
}
break;
}
Ktp1 = Kinf;
Ptp1 = Pinf;
}
// Compute cached matrices
tinyMatrix Quu_inv = (R1 + Bdyn.transpose() * Pinf * Bdyn).inverse();
tinyMatrix AmBKt = (Adyn - Bdyn * Kinf).transpose();
// Precomputation for affine term
tinyVector APf = AmBKt * Pinf * fdyn;
tinyVector BPf = Bdyn.transpose() * Pinf * fdyn;
if (verbose) {
std::cout << "Kinf = " << Kinf.format(TinyApiFmt) << std::endl;
std::cout << "Pinf = " << Pinf.format(TinyApiFmt) << std::endl;
std::cout << "Quu_inv = " << Quu_inv.format(TinyApiFmt) << std::endl;
std::cout << "AmBKt = " << AmBKt.format(TinyApiFmt) << std::endl;
std::cout << "APf = " << APf.format(TinyApiFmt) << std::endl;
std::cout << "BPf = " << BPf.format(TinyApiFmt) << std::endl;
std::cout << "\nPrecomputation finished!\n" << std::endl;
}
cache->rho = rho;
cache->Kinf = Kinf;
cache->Pinf = Pinf;
cache->Quu_inv = Quu_inv;
cache->AmBKt = AmBKt;
cache->C1 = Quu_inv;
cache->C2 = AmBKt;
cache->APf = APf;
cache->BPf = BPf;
return 0; // return success
}
int tiny_solve(TinySolver* solver) {
return solve(solver);
}
int tiny_update_settings(
TinySettings* settings,
tinytype abs_pri_tol,
tinytype abs_dua_tol,
int max_iter,
int check_termination,
int en_state_bound,
int en_input_bound,
int en_state_soc,
int en_input_soc,
int en_state_linear,
int en_input_linear
) {
if (!settings) {
std::cout << "Error in tiny_update_settings: settings is nullptr" << std::endl;
return 1;
}
settings->abs_pri_tol = abs_pri_tol;
settings->abs_dua_tol = abs_dua_tol;
settings->max_iter = max_iter;
settings->check_termination = check_termination;
settings->en_state_bound = en_state_bound;
settings->en_input_bound = en_input_bound;
settings->en_state_soc = en_state_soc;
settings->en_input_soc = en_input_soc;
settings->en_state_linear = en_state_linear;
settings->en_input_linear = en_input_linear;
return 0;
}
int tiny_set_default_settings(TinySettings* settings) {
if (!settings) {
std::cout << "Error in tiny_set_default_settings: settings is nullptr" << std::endl;
return 1;
}
settings->abs_pri_tol = TINY_DEFAULT_ABS_PRI_TOL;
settings->abs_dua_tol = TINY_DEFAULT_ABS_DUA_TOL;
settings->max_iter = TINY_DEFAULT_MAX_ITER;
settings->check_termination = TINY_DEFAULT_CHECK_TERMINATION;
// Turn off constraints until they are set by tiny_set_bound_constraints or tiny_set_cone_constraints
settings->en_state_bound = TINY_DEFAULT_EN_STATE_BOUND;
settings->en_input_bound = TINY_DEFAULT_EN_INPUT_BOUND;
settings->en_state_soc = TINY_DEFAULT_EN_STATE_SOC;
settings->en_input_soc = TINY_DEFAULT_EN_INPUT_SOC;
settings->en_state_linear = TINY_DEFAULT_EN_STATE_LINEAR;
settings->en_input_linear = TINY_DEFAULT_EN_INPUT_LINEAR;
// Initialize adaptive rho settings
// NOTE : Adaptive rho currently supports only quadrotor system
settings->adaptive_rho = 0; // Disabled by default
settings->adaptive_rho_min = 1.0;
settings->adaptive_rho_max = 100.0;
settings->adaptive_rho_enable_clipping = 1;
return 0;
}
int tiny_set_x0(TinySolver* solver, tinyVector x0) {
if (!solver) {
std::cout << "Error in tiny_set_x0: solver is nullptr" << std::endl;
return 1;
}
if (x0.rows() != solver->work->nx) {
perror("Error in tiny_set_x0: x0 is not the correct length");
}
solver->work->x.col(0) = x0;
return 0;
}
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref) {
if (!solver) {
std::cout << "Error in tiny_set_x_ref: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= check_dimension(
"State reference trajectory (x_ref)",
"rows",
x_ref.rows(),
solver->work->nx
);
status |= check_dimension(
"State reference trajectory (x_ref)",
"columns",
x_ref.cols(),
solver->work->N
);
solver->work->Xref = x_ref;
return 0;
}
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref) {
if (!solver) {
std::cout << "Error in tiny_set_u_ref: solver is nullptr" << std::endl;
return 1;
}
int status = 0;
status |= check_dimension(
"Control/input reference trajectory (u_ref)",
"rows",
u_ref.rows(),
solver->work->nu
);
status |= check_dimension(
"Control/input reference trajectory (u_ref)",
"columns",
u_ref.cols(),
solver->work->N - 1
);
solver->work->Uref = u_ref;
return 0;
}
void tiny_initialize_sensitivity_matrices(TinySolver* solver) {
int nu = solver->work->nu;
int nx = solver->work->nx;
// Initialize matrices with zeros
solver->cache->dKinf_drho = tinyMatrix::Zero(nu, nx);
solver->cache->dPinf_drho = tinyMatrix::Zero(nx, nx);
solver->cache->dC1_drho = tinyMatrix::Zero(nu, nu);
solver->cache->dC2_drho = tinyMatrix::Zero(nx, nx);
const float dKinf_drho[4][12] = { { 0.0001,
-0.0001,
-0.0025,
0.0003,
0.0007,
0.0050,
0.0001,
-0.0001,
-0.0008,
0.0000,
0.0001,
0.0008 },
{ -0.0001,
-0.0000,
-0.0025,
-0.0001,
-0.0006,
-0.0050,
-0.0001,
0.0000,
-0.0008,
-0.0000,
-0.0001,
-0.0008 },
{ 0.0000,
0.0000,
-0.0025,
0.0001,
0.0004,
0.0050,
0.0000,
0.0000,
-0.0008,
0.0000,
0.0000,
0.0008 },
{ -0.0000,
0.0001,
-0.0025,
-0.0003,
-0.0004,
-0.0050,
-0.0000,
0.0001,
-0.0008,
-0.0000,
-0.0000,
-0.0008 } };
const float dPinf_drho[12][12] = { { 0.0494,
-0.0045,
-0.0000,
0.0110,
0.1300,
-0.0283,
0.0280,
-0.0026,
-0.0000,
0.0004,
0.0070,
-0.0094 },
{ -0.0045,
0.0491,
0.0000,
-0.1320,
-0.0111,
0.0114,
-0.0026,
0.0279,
0.0000,
-0.0076,
-0.0004,
0.0038 },
{ -0.0000,
0.0000,
2.4450,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
1.2593,
0.0000,
0.0000,
0.0000 },
{ 0.0110,
-0.1320,
0.0000,
0.3913,
0.0592,
0.3108,
0.0080,
-0.0776,
0.0000,
0.0254,
0.0068,
0.0750 },
{ 0.1300,
-0.0111,
-0.0000,
0.0592,
0.4420,
0.7771,
0.0797,
-0.0081,
-0.0000,
0.0068,
0.0350,
0.1875 },
{ -0.0283,
0.0114,
-0.0000,
0.3108,
0.7771,
10.0441,
0.0272,
-0.0109,
0.0000,
0.0655,
0.1639,
2.6362 },
{ 0.0280,
-0.0026,
-0.0000,
0.0080,
0.0797,
0.0272,
0.0163,
-0.0016,
-0.0000,
0.0005,
0.0047,
0.0032 },
{ -0.0026,
0.0279,
0.0000,
-0.0776,
-0.0081,
-0.0109,
-0.0016,
0.0161,
0.0000,
-0.0046,
-0.0005,
-0.0013 },
{ -0.0000,
0.0000,
1.2593,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.9232,
0.0000,
0.0000,
0.0000 },
{ 0.0004,
-0.0076,
0.0000,
0.0254,
0.0068,
0.0655,
0.0005,
-0.0046,
0.0000,
0.0022,
0.0017,
0.0244 },
{ 0.0070,
-0.0004,
0.0000,
0.0068,
0.0350,
0.1639,
0.0047,
-0.0005,
0.0000,
0.0017,
0.0054,
0.0610 },
{ -0.0094,
0.0038,
0.0000,
0.0750,
0.1875,
2.6362,
0.0032,
-0.0013,
0.0000,
0.0244,
0.0610,
0.9869 } };
const float dC1_drho[4][4] = { { -0.0000, 0.0000, -0.0000, 0.0000 },
{ 0.0000, -0.0000, 0.0000, -0.0000 },
{ -0.0000, 0.0000, -0.0000, 0.0000 },
{ 0.0000, -0.0000, 0.0000, -0.0000 } };
const float dC2_drho[12][12] = { { 0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000 },
{ -0.0000,
0.0000,
0.0001,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000 },
{ 0.0000,
-0.0000,
-0.0000,
0.0001,
0.0000,
-0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000 },
{ 0.0000,
-0.0000,
-0.0000,
0.0000,
0.0001,
-0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
0.0001,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000 },
{ 0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
0.0000,
-0.0000 },
{ -0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000,
-0.0000,
0.0000,
0.0000,
-0.0000,
-0.0000,
0.0000 },
{ -0.0000,
0.0000,
0.0021,
0.0000,
-0.0000,
-0.0000,
-0.0000,
0.0000,
0.0006,
0.0000,
-0.0000,
-0.0000 },
{ 0.0002,
-0.0027,
-0.0000,
0.0068,
0.0005,
-0.0005,
0.0001,
-0.0015,
-0.0000,
0.0004,
0.0000,
-0.0001 },
{ 0.0027,
-0.0002,
0.0000,
0.0005,
0.0066,
-0.0011,
0.0015,
-0.0001,
0.0000,
0.0000,
0.0004,
-0.0002 },
{ -0.0001,
0.0001,
0.0000,
-0.0000,
0.0000,
0.0041,
-0.0000,
0.0000,
0.0000,
0.0000,
0.0000,
0.0006 } };
// Map arrays to Eigen matrices
solver->cache->dKinf_drho = Map<const Matrix<float, 4, 12>>(dKinf_drho[0]).cast<tinytype>();
solver->cache->dPinf_drho = Map<const Matrix<float, 12, 12>>(dPinf_drho[0]).cast<tinytype>();
solver->cache->dC1_drho = Map<const Matrix<float, 4, 4>>(dC1_drho[0]).cast<tinytype>();
solver->cache->dC2_drho = Map<const Matrix<float, 12, 12>>(dC2_drho[0]).cast<tinytype>();
}
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,118 @@
#pragma once
#include "admm.hpp"
#include <iostream>
#ifdef __cplusplus
extern "C" {
#endif
int tiny_setup(
TinySolver** solverp,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
tinytype rho,
int nx,
int nu,
int N,
int verbose
);
int tiny_set_bound_constraints(
TinySolver* solver,
tinyMatrix x_min,
tinyMatrix x_max,
tinyMatrix u_min,
tinyMatrix u_max
);
int tiny_set_cone_constraints(
TinySolver* solver,
VectorXi Acu,
VectorXi qcu,
tinyVector cu,
VectorXi Acx,
VectorXi qcx,
tinyVector cx
);
int tiny_set_linear_constraints(
TinySolver* solver,
tinyMatrix Alin_x,
tinyVector blin_x,
tinyMatrix Alin_u,
tinyVector blin_u
);
int tiny_precompute_and_set_cache(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix fdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
);
void compute_sensitivity_matrices(
TinyCache* cache,
tinyMatrix Adyn,
tinyMatrix Bdyn,
tinyMatrix Q,
tinyMatrix R,
int nx,
int nu,
tinytype rho,
int verbose
);
int tiny_update_matrices_with_derivatives(TinyCache* cache, tinytype delta_rho);
int tiny_solve(TinySolver* solver);
int tiny_update_settings(
TinySettings* settings,
tinytype abs_pri_tol,
tinytype abs_dua_tol,
int max_iter,
int check_termination,
int en_state_bound,
int en_input_bound,
int en_state_soc,
int en_input_soc,
int en_state_linear,
int en_input_linear
);
int tiny_set_default_settings(TinySettings* settings);
int tiny_set_x0(TinySolver* solver, tinyVector x0);
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref);
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref);
/**
* Initialize sensitivity matrices for adaptive rho
*
* @param solver Pointer to solver
*/
void tiny_initialize_sensitivity_matrices(TinySolver* solver);
int tiny_setup_state_soc_constraints(
TinySolver* solver,
tinyVector Acx,
tinyVector qcx,
tinyVector cx,
int numStateCones
);
int tiny_setup_input_soc_constraints(
TinySolver* solver,
tinyVector Acu,
tinyVector qcu,
tinyVector cu,
int numInputCones
);
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,13 @@
#pragma once
// Default settings
#define TINY_DEFAULT_ABS_PRI_TOL (1e-03)
#define TINY_DEFAULT_ABS_DUA_TOL (1e-03)
#define TINY_DEFAULT_MAX_ITER (1000)
#define TINY_DEFAULT_CHECK_TERMINATION (1)
#define TINY_DEFAULT_EN_STATE_BOUND (1)
#define TINY_DEFAULT_EN_INPUT_BOUND (1)
#define TINY_DEFAULT_EN_STATE_SOC (0)
#define TINY_DEFAULT_EN_INPUT_SOC (0)
#define TINY_DEFAULT_EN_STATE_LINEAR (0)
#define TINY_DEFAULT_EN_INPUT_LINEAR (0)

View File

@@ -0,0 +1,197 @@
#pragma once
#include <Eigen/Eigen>
// #include <Eigen/Core>
// #include <Eigen/LU>
using namespace Eigen;
#ifdef __cplusplus
extern "C" {
#endif
typedef double tinytype; // should be double if you want to generate code
typedef Matrix<tinytype, Dynamic, Dynamic> tinyMatrix;
typedef Matrix<tinytype, Dynamic, 1> tinyVector;
// typedef Matrix<tinytype, NSTATES, 1> tiny_VectorNx;
// typedef Matrix<tinytype, NINPUTS, 1> tiny_VectorNu;
// typedef Matrix<tinytype, NSTATES, NSTATES> tiny_MatrixNxNx;
// typedef Matrix<tinytype, NSTATES, NINPUTS> tiny_MatrixNxNu;
// typedef Matrix<tinytype, NINPUTS, NSTATES> tiny_MatrixNuNx;
// typedef Matrix<tinytype, NINPUTS, NINPUTS> tiny_MatrixNuNu;
// typedef Matrix<tinytype, NSTATES, NHORIZON> tiny_MatrixNxNh; // Nu x Nh
// typedef Matrix<tinytype, NINPUTS, NHORIZON - 1> tiny_MatrixNuNhm1; // Nu x Nh-1
/**
* Solution
*/
typedef struct {
int iter;
int solved;
tinyMatrix x; // nx x N
tinyMatrix u; // nu x N-1
} TinySolution;
/**
* Matrices that must be recomputed with changes in time step, rho
*/
typedef struct {
tinytype rho;
tinyMatrix Kinf; // nu x nx
tinyMatrix Pinf; // nx x nx
tinyMatrix Quu_inv; // nu x nu
tinyMatrix AmBKt; // nx x nx
tinyVector APf; // nx x 1
tinyVector BPf; // nu x 1
tinyMatrix C1; // From adaptive rho
tinyMatrix C2; // From adaptive rho
// Sensitivity matrices for adaptive rho
tinyMatrix dKinf_drho;
tinyMatrix dPinf_drho;
tinyMatrix dC1_drho;
tinyMatrix dC2_drho;
} TinyCache;
/**
* User settings
*/
typedef struct {
tinytype abs_pri_tol;
tinytype abs_dua_tol;
int max_iter;
int check_termination;
int en_state_bound;
int en_input_bound;
int en_state_soc;
int en_input_soc;
int en_state_linear;
int en_input_linear;
// Add adaptive rho parameters
int adaptive_rho; // Enable/disable adaptive rho (1/0)
tinytype adaptive_rho_min; // Minimum value for rho
tinytype adaptive_rho_max; // Maximum value for rho
int adaptive_rho_enable_clipping; // Enable/disable clipping of rho (1/0)
} TinySettings;
/**
* Problem variables
*/
typedef struct {
int nx; // Number of states
int nu; // Number of control inputs
int N; // Number of knotpoints in the horizon
// State and input
tinyMatrix x; // nx x N
tinyMatrix u; // nu x N-1
// Linear control cost terms
tinyMatrix q; // nx x N
tinyMatrix r; // nu x N-1
// Linear Riccati backward pass terms
tinyMatrix p; // nx x N
tinyMatrix d; // nu x N-1
// Bound constraint variables
// Slack variables
tinyMatrix v; // nx x N
tinyMatrix vnew; // nx x N
tinyMatrix z; // nu x N-1
tinyMatrix znew; // nu x N-1
// Dual variables
tinyMatrix g; // nx x N
tinyMatrix y; // nu x N-1
// State and input bounds
tinyMatrix x_min; // nx x N
tinyMatrix x_max; // nx x N
tinyMatrix u_min; // nu x N-1
tinyMatrix u_max; // nu x N-1
// Cone constraint variables
// Variables to keep track of general cone information
int numStateCones; // Number of cone constraints on states at each time step
int numInputCones; // Number of cone constraints on inputs at each time step
tinyVector cx; // One coefficient for each state cone
tinyVector cu; // One coefficient for each input cone
VectorXi Acx; // Start indices for each state cone
VectorXi Acu; // Start indices for each input cone
VectorXi qcx; // Dimension for each state cone
VectorXi qcu; // Dimension for each input cone
// Slack variables
tinyMatrix vc; // nx x N
tinyMatrix vcnew; // nx x N
tinyMatrix zc; // nu x N-1
tinyMatrix zcnew; // nu x N-1
// Dual variables
tinyMatrix gc; // nx x N
tinyMatrix yc; // nu x N-1
// Linear constraint variables
// Variables to keep track of general linear constraint information
int numStateLinear; // Number of linear constraints on states at each time step
int numInputLinear; // Number of linear constraints on inputs at each time step
// Constraint matrices and vectors
tinyMatrix Alin_x; // Normal vectors for state linear constraints (numStateLinear x nx)
tinyVector blin_x; // Offset values for state linear constraints (numStateLinear x 1)
tinyMatrix Alin_u; // Normal vectors for input linear constraints (numInputLinear x nu)
tinyVector blin_u; // Offset values for input linear constraints (numInputLinear x 1)
// Slack variables for linear constraints
tinyMatrix vl; // nx x N
tinyMatrix vlnew; // nx x N
tinyMatrix zl; // nu x N-1
tinyMatrix zlnew; // nu x N-1
// Dual variables for linear constraints
tinyMatrix gl; // nx x N
tinyMatrix yl; // nu x N-1
// Q, R, A, B, f given by user
tinyVector Q; // nx x 1
tinyVector R; // nu x 1
tinyMatrix Adyn; // nx x nx (state transition matrix)
tinyMatrix Bdyn; // nx x nu (control matrix)
tinyVector fdyn; // nx x 1 (affine vector)
// Reference trajectory to track for one horizon
tinyMatrix Xref; // nx x N
tinyMatrix Uref; // nu x N-1
// Temporaries
tinyVector Qu; // nu x 1
// Variables for keeping track of solve status
tinytype primal_residual_state;
tinytype primal_residual_input;
tinytype dual_residual_state;
tinytype dual_residual_input;
int status;
int iter;
} TinyWorkspace;
/**
* Main TinyMPC solver structure that holds all information.
*/
typedef struct {
TinySolution* solution; // Solution
TinySettings* settings; // Problem settings
TinyCache* cache; // Problem cache
TinyWorkspace* work; // Solver workspace
} TinySolver;
// Add at the top with other definitions
#define BENCH_NX 12
#define BENCH_NU 4
#ifdef __cplusplus
}
#endif

View File

@@ -0,0 +1,111 @@
#pragma once
#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <vector>
namespace wust_vision {
namespace auto_aim {
template<typename T>
concept HasStaticLerp = requires(const T& a, const T& b, double t) {
{
T::lerp(a, b, t)
} -> std::same_as<T>;
};
template<HasStaticLerp PointT>
class Trajectory {
public:
void reserve(size_t n) {
cp_vec.reserve(n);
dt_vec.reserve(n > 0 ? n - 1 : 0);
prefix_time.reserve(n);
}
void clear() {
cp_vec.clear();
dt_vec.clear();
prefix_time.clear();
total_duration_ = 0.0;
}
void push_back(const PointT& p, double dt = 0.0) {
if (cp_vec.empty()) {
cp_vec.push_back(p);
prefix_time.push_back(0.0);
total_duration_ = 0.0;
return;
}
assert(dt >= 0.0);
cp_vec.push_back(p);
dt_vec.push_back(dt);
total_duration_ += dt;
prefix_time.push_back(total_duration_);
}
void set(const std::vector<PointT>& c, const std::vector<double>& t) {
assert(!c.empty());
assert(c.size() == t.size() + 1);
cp_vec = c;
dt_vec = t;
prefix_time.resize(cp_vec.size());
prefix_time[0] = 0.0;
for (size_t i = 0; i < dt_vec.size(); ++i)
prefix_time[i + 1] = prefix_time[i] + dt_vec[i];
total_duration_ = prefix_time.back();
}
double getPrefixTimeAtIdx(int i) const {
return prefix_time[i];
}
PointT getStateAtIdx(int i) const {
return cp_vec[i];
}
PointT getStateAtTime(double t) const {
if (cp_vec.empty())
return PointT {};
if (t <= 0.0)
return cp_vec.front();
if (t >= total_duration_)
return cp_vec.back();
auto it = std::lower_bound(prefix_time.begin(), prefix_time.end(), t);
size_t i1 = std::distance(prefix_time.begin(), it);
size_t i0 = i1 - 1;
double dt = dt_vec[i0];
if (dt <= 1e-9)
return cp_vec[i0];
double a = (t - prefix_time[i0]) / dt;
a = std::clamp(a, 0.0, 1.0);
return PointT::lerp(cp_vec[i0], cp_vec[i1], a);
}
double getTotalDuration() const {
return total_duration_;
}
size_t size() const {
return cp_vec.size();
}
std::vector<PointT> cp_vec;
std::vector<double> dt_vec;
std::vector<double> prefix_time;
double total_duration_ { 0.0 };
};
} // namespace auto_aim
} // namespace wust_vision

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,27 @@
#pragma once
#include <memory>
namespace wust_vl::common::utils {
class Parameter;
}
using wust_vlParamterPtr = std::shared_ptr<wust_vl::common::utils::Parameter>;
namespace wust_vision {
struct GimbalCmd;
}
namespace wust_vision::auto_aim {
enum class AutoAimFsm;
class Target;
class VeryAimer {
public:
using Ptr = std::shared_ptr<VeryAimer>;
VeryAimer(wust_vlParamterPtr auto_aim_config_parameter);
static Ptr create(wust_vlParamterPtr auto_aim_config_parameter) {
return std::make_shared<VeryAimer>(auto_aim_config_parameter);
};
~VeryAimer();
[[nodiscard]] GimbalCmd
veryAim(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm);
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,70 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/type.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_aim {
struct LightParams {
// width / height
double min_ratio;
double max_ratio;
// vertical angle
double max_angle;
// judge color
int color_diff_thresh;
double max_angle_diff;
int binary_thres;
void load(const YAML::Node& config) {
binary_thres = config["binary_thres"].as<int>();
min_ratio = config["min_ratio"].as<double>();
max_ratio = config["max_ratio"].as<double>();
max_angle = config["max_angle"].as<double>();
max_angle_diff = config["max_angle_diff"].as<double>();
color_diff_thresh = config["color_diff_thresh"].as<int>();
}
};
struct ArmorParams {
double min_light_ratio;
// light pairs distance
double min_small_center_distance;
double max_small_center_distance;
double min_large_center_distance;
double max_large_center_distance;
// horizontal angle
double max_angle;
void load(const YAML::Node& config) {
min_light_ratio = config["min_light_ratio"].as<double>();
min_small_center_distance = config["min_small_center_distance"].as<double>();
max_small_center_distance = config["max_small_center_distance"].as<double>();
min_large_center_distance = config["min_large_center_distance"].as<double>();
max_large_center_distance = config["max_large_center_distance"].as<double>();
max_angle = config["max_angle"].as<double>();
}
};
class ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorBase>;
virtual ~ArmorDetectorBase() = default;
virtual void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) = 0;
using DetectorCallback =
std::function<void(const std::vector<ArmorObject>&, const CommonFrame&)>;
virtual void setCallback(DetectorCallback cb) = 0;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,473 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
#include "tasks/auto_aim/armor_detect/number_classifier/factory.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/common/utils/timer.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorCommon::Impl {
public:
Impl(const YAML::Node& config) {
params_.load(config);
number_classifier_ = NumberClassifierFactory::createNumberClassifier(
params_.classify_backend,
params_.classify_model_path,
params_.classify_label_path
);
}
bool extractNetImage(const cv::Mat& src, ArmorObject& armor) const noexcept {
constexpr int light_length = 12;
constexpr int warp_height = 28;
constexpr int small_armor_width = 32;
constexpr int large_armor_width = 54;
const cv::Size roi_size(20, 28);
if (src.empty() || src.cols < 10 || src.rows < 10) {
std::cerr << "[extractNetImage] input src is empty or too small!" << std::endl;
return false;
}
const auto ordered = armor.sortCorners(armor.pts);
const cv::Point2f& p0 = ordered[0];
const cv::Point2f& p1 = ordered[1];
const cv::Point2f& p2 = ordered[2];
const cv::Point2f& p3 = ordered[3];
const float l1_len = cv::norm(p1 - p0);
const float l2_len = cv::norm(p2 - p3);
const cv::Point2f c1 = (p0 + p1) * 0.5f;
const cv::Point2f c2 = (p2 + p3) * 0.5f;
const float avg_light_len = 0.5f * (l1_len + l2_len);
const float center_dist =
avg_light_len > 1e-3f ? cv::norm(c1 - c2) / avg_light_len : 0.f;
const bool is_large = center_dist > params_.armor_params.min_large_center_distance;
const cv::Rect bbox = cv::boundingRect(armor.pts);
if (bbox.width <= 0 || bbox.height <= 0)
return false;
if (bbox.width > src.cols || bbox.height > src.rows)
return false;
const int dw = static_cast<int>(bbox.width * (params_.expand_ratio_w - 1.f));
const int dh = static_cast<int>(bbox.height * (params_.expand_ratio_h - 1.f));
int new_x = bbox.x - (dw >> 1);
int new_y = bbox.y - (dh >> 1);
new_x = std::max(new_x, 0);
new_y = std::max(new_y, 0);
int new_w = std::min(bbox.width + dw, src.cols - new_x);
int new_h = std::min(bbox.height + dh, src.rows - new_y);
if (new_w <= 0 || new_h <= 0)
return false;
const cv::Rect expanded_rect(new_x, new_y, new_w, new_h);
cv::Mat litroi_color = src(expanded_rect);
if (litroi_color.empty())
return false;
cv::Mat litroi_gray;
try {
cv::cvtColor(litroi_color, litroi_gray, cv::COLOR_BGR2GRAY);
} catch (...) {
return false;
}
armor.whole_gray_img = litroi_gray;
if (params_.enable_cv) {
cv::Mat litroi_binary;
try {
cv::threshold(
litroi_gray,
litroi_binary,
params_.light_params.binary_thres,
255,
cv::THRESH_BINARY
);
armor.whole_binary_img = litroi_binary;
} catch (...) {
return false;
}
}
const cv::Point2f offset(static_cast<float>(new_x), static_cast<float>(new_y));
if (params_.enable_classify) {
cv::Point2f src_vertices[4] = { armor.pts[1] - offset,
armor.pts[0] - offset,
armor.pts[3] - offset,
armor.pts[2] - offset };
const int warp_width = is_large ? large_armor_width : small_armor_width;
const int top_light_y = (warp_height - light_length) / 2 - 1;
const int bottom_light_y = top_light_y + light_length;
if (warp_width <= 0 || warp_height <= 0)
return false;
cv::Point2f dst_vertices[4] = {
{ 0.f, static_cast<float>(bottom_light_y) },
{ 0.f, static_cast<float>(top_light_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(top_light_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(bottom_light_y) }
};
const cv::Mat warp_mat = cv::getPerspectiveTransform(src_vertices, dst_vertices);
cv::Mat number_image;
cv::warpPerspective(
litroi_gray,
number_image,
warp_mat,
cv::Size(warp_width, warp_height),
cv::INTER_LINEAR,
cv::BORDER_CONSTANT,
0
);
const int roi_x = (warp_width - roi_size.width) >> 1;
const cv::Rect num_roi(roi_x, 0, roi_size.width, roi_size.height);
if ((num_roi & cv::Rect(0, 0, warp_width, warp_height)) != num_roi)
return false;
cv::Mat num_crop = number_image(num_roi);
cv::threshold(
num_crop,
armor.number_img,
0,
255,
cv::THRESH_BINARY | cv::THRESH_OTSU
);
}
armor.whole_rgb_img = litroi_color;
armor.local_offset = offset;
return true;
}
bool refineLightsFromArmorPts(ArmorObject& armor) const noexcept {
armor.center = (armor.pts[0] + armor.pts[1] + armor.pts[2] + armor.pts[3]) * 0.25f;
const int n_lights = static_cast<int>(armor.lights.size());
if (n_lights < 2)
return false;
const auto ordered = armor.sortCorners(armor.pts);
const cv::Point2f ref_centers[2] = { (ordered[0] + ordered[1]) * 0.5f,
(ordered[2] + ordered[3]) * 0.5f };
int best0 = -1, best1 = -1;
float best0_d2 = std::numeric_limits<float>::max();
float best1_d2 = std::numeric_limits<float>::max();
for (int i = 0; i < n_lights; ++i) {
const cv::Point2f& c = armor.lights[i].center;
const cv::Point2f d0 = c - ref_centers[0];
const float dist0 = d0.dot(d0);
if (dist0 < best0_d2) {
best0_d2 = dist0;
best0 = i;
}
const cv::Point2f d1 = c - ref_centers[1];
const float dist1 = d1.dot(d1);
if (dist1 < best1_d2) {
best1_d2 = dist1;
best1 = i;
}
}
if (best0 == best1) {
best1 = -1;
best1_d2 = std::numeric_limits<float>::max();
for (int i = 0; i < n_lights; ++i) {
if (i == best0)
continue;
const cv::Point2f d = armor.lights[i].center - ref_centers[1];
const float dist = d.dot(d);
if (dist < best1_d2) {
best1_d2 = dist;
best1 = i;
}
}
}
if (best0 < 0 || best1 < 0)
return false;
const auto& l0 = armor.lights[best0];
const auto& l1 = armor.lights[best1];
if (l0.center.x < l1.center.x) {
armor.lights[0] = l0;
armor.lights[1] = l1;
} else {
armor.lights[0] = l1;
armor.lights[1] = l0;
}
return true;
}
std::vector<Light>
findLights(const cv::Mat& color_img, const cv::Mat& binary_img, ArmorObject& armor)
const noexcept {
std::vector<std::vector<cv::Point>> contours;
contours.reserve(64);
cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
std::vector<Light> all_lights;
all_lights.reserve(contours.size());
for (const auto& contour: contours) {
const int n = static_cast<int>(contour.size());
if (n < 6)
continue;
Light light(contour);
if (!isLight(light))
continue;
int sum_r = 0;
int sum_b = 0;
for (const auto& pt: contour) {
const cv::Vec3b* row = color_img.ptr<cv::Vec3b>(pt.y);
const cv::Vec3b& pix = row[pt.x];
sum_r += pix[0];
sum_b += pix[2];
}
const int avg_diff = std::abs(sum_r - sum_b) / n;
if (avg_diff <= params_.light_params.color_diff_thresh)
continue;
light.color = (sum_r > sum_b) ? 0 : 1; // 0=红, 1=蓝
all_lights.emplace_back(std::move(light));
}
std::sort(all_lights.begin(), all_lights.end(), [](const Light& a, const Light& b) {
return a.center.x < b.center.x;
});
armor.lights = all_lights;
return all_lights;
}
bool isLight(const Light& light) const noexcept {
// width / length 比例
const float ratio = light.width / light.length;
if (ratio <= params_.light_params.min_ratio || ratio >= params_.light_params.max_ratio)
return false;
if (light.tilt_angle >= params_.light_params.max_angle)
return false;
return true;
}
bool isArmor(const Light& l1, const Light& l2) const noexcept {
const float len1 = l1.length;
const float len2 = l2.length;
if (len1 <= 1e-3f || len2 <= 1e-3f)
return false;
const float min_len = (len1 < len2) ? len1 : len2;
const float max_len = (len1 < len2) ? len2 : len1;
if (min_len / max_len <= params_.armor_params.min_light_ratio)
return false;
const cv::Point2f d = l1.center - l2.center;
const float dist2 = d.dot(d);
const float avg_len = 0.5f * (len1 + len2);
const float min_small = params_.armor_params.min_small_center_distance * avg_len;
const float max_small = params_.armor_params.max_small_center_distance * avg_len;
const float min_large = params_.armor_params.min_large_center_distance * avg_len;
const float max_large = params_.armor_params.max_large_center_distance * avg_len;
const float min_small2 = min_small * min_small;
const float max_small2 = max_small * max_small;
const float min_large2 = min_large * min_large;
const float max_large2 = max_large * max_large;
const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2);
const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2);
if (!(small_ok || large_ok))
return false;
static const float tan_max_angle =
std::tan(params_.armor_params.max_angle * CV_PI / 180.0f);
if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle)
return false;
if (l1.color != l2.color)
return false;
return true;
}
std::vector<ArmorObject> detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number
) const noexcept {
std::vector<ArmorObject> armors;
if (!src_img.data || src_img.empty()) {
std::cout << "img data nullptr or empty" << std::endl;
return armors;
}
if (objs_result.empty()) {
return armors;
}
for (auto& armor_in: objs_result) {
ArmorObject armor = armor_in;
if ((detect_color == 0 && armor.color == ArmorColor::BLUE)
|| (detect_color == 1 && armor.color == ArmorColor::RED))
{
continue;
}
if (params_.enable_classify || params_.enable_cv) {
bool ok = false;
ok = extractNetImage(src_img, armor);
if (!ok)
continue;
}
if (params_.enable_classify) {
number_classifier_->classifyNumber(armor);
if (armor.confidence < params_.classifier_threshold)
continue;
}
if (target_number.has_value()) {
if (!isSameTarget(target_number.value(), armor.number)) {
continue;
}
}
if (armor.color == ArmorColor::NONE || armor.color == ArmorColor::PURPLE) {
armor.is_ok = false;
armor.transform(transform_matrix);
armors.emplace_back(armor);
continue;
}
if (params_.enable_cv) {
findLights(armor.whole_rgb_img, armor.whole_binary_img, armor);
if (refineLightsFromArmorPts(armor)) {
if (isArmor(armor.lights[0], armor.lights[1])) {
armor.is_ok = true;
for (auto& light: armor.lights) {
light.addOffset(armor.local_offset);
}
}
}
if (armor.is_ok) {
armor.is_ok = armor.checkOkptsRight(params_.max_pts_error);
}
}
if (!armor.is_ok) {
auto ordered = armor.sortCorners(armor.pts);
Light l1, l2;
l1.length = cv::norm(ordered[1] - ordered[0]);
l1.center = (ordered[0] + ordered[1]) / 2.0;
l2.length = cv::norm(ordered[2] - ordered[3]);
l2.center = (ordered[2] + ordered[3]) / 2.0;
if (!isArmor(l1, l2)) {
continue;
}
}
armor.transform(transform_matrix);
armors.emplace_back(armor);
}
return armors;
}
std::unique_ptr<NumberClassifierBase> number_classifier_;
struct ArmorDetectCommonParams {
std::string classify_backend = "opencv";
std::string classify_model_path;
std::string classify_label_path;
double classifier_threshold = 0.5;
LightParams light_params;
ArmorParams armor_params;
float expand_ratio_w = 1.1f;
float expand_ratio_h = 1.1f;
double max_pts_error = 20.0;
bool enable_cv = false;
bool enable_classify = true;
void load(const YAML::Node& config) {
expand_ratio_w = config["cv"]["light"]["expand_ratio_w"].as<float>(1.1);
expand_ratio_h = config["cv"]["light"]["expand_ratio_h"].as<float>(1.1);
max_pts_error = config["cv"]["light"]["max_pts_error"].as<double>(20.0);
enable_cv = config["cv"]["enable"].as<bool>();
light_params.load(config["cv"]["light"]);
armor_params.load(config["cv"]["armor"]);
enable_classify = config["classify"]["enable"].as<bool>();
classify_model_path =
utils::expandEnv(config["classify"]["model_path"].as<std::string>());
classify_label_path =
utils::expandEnv(config["classify"]["label_path"].as<std::string>());
classify_backend = config["classify"]["backend"].as<std::string>();
classifier_threshold = config["classify"]["threshold"].as<double>();
}
} params_;
};
ArmorDetectorCommon::ArmorDetectorCommon(const YAML::Node& config) {
_impl = std::make_unique<Impl>(config);
}
ArmorDetectorCommon::~ArmorDetectorCommon() {
_impl.reset();
}
std::vector<ArmorObject> ArmorDetectorCommon::detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number
) {
return _impl
->detectNet(src_img, objs_result, transform_matrix, detect_color, target_number);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,41 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorCommon {
public:
using Ptr = std::unique_ptr<ArmorDetectorCommon>;
ArmorDetectorCommon(const YAML::Node& config);
static Ptr create(const YAML::Node& config) {
return std::make_unique<ArmorDetectorCommon>(config);
}
~ArmorDetectorCommon();
std::vector<ArmorObject> detectNet(
const cv::Mat& src_img,
std::vector<ArmorObject>& objs_result,
Eigen::Matrix3f transform_matrix,
int detect_color,
const std::optional<ArmorNumber>& target_number = std::nullopt
);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,134 @@
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "armor_detector_base.hpp"
#include "tasks/utils/config.hpp"
#include <string>
#include <yaml-cpp/yaml.h>
#ifdef USE_OPENVINO
#include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp"
#endif
#ifdef USE_TRT
#include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp"
#endif
#ifdef USE_NCNN
#include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp"
#endif
#ifdef USE_ORT
#include "tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp"
#endif
#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp"
namespace wust_vision {
namespace auto_aim {
class DetectorFactory {
public:
static ArmorDetectorBase::Ptr createArmorDetector(
const std::string& backend,
bool use_armor_detect_common,
std::string cv_config_path = OPENCV_CONFIG,
std::string ml_config_path = ML_CONFIG
) {
// 检查编译时是否支持
auto isBackendEnabled = [&backend]() -> bool {
#ifdef USE_OPENVINO
if (backend == "openvino")
return true;
#endif
#ifdef USE_TRT
if (backend == "tensorrt")
return true;
#endif
#ifdef USE_NCNN
if (backend == "ncnn")
return true;
#endif
#ifdef USE_ORT
if (backend == "onnxruntime")
return true;
#endif
if (backend == "opencv")
return true;
return false;
};
if (!isBackendEnabled()) {
std::cout << "Backend " << backend << " is not enabled at compile time."
<< std::endl;
throw std::runtime_error("Backend " + backend + " is not enabled at compile time.");
}
auto getConfigPath = [&](const std::string& backend) -> std::string {
if (backend == "opencv")
return cv_config_path;
else
return ml_config_path;
};
std::string config_path = getConfigPath(backend);
if (config_path.empty()) {
std::cout << "No config path for backend: " << backend << std::endl;
throw std::runtime_error("No config path for backend: " + backend);
}
YAML::Node armor_detect_config = YAML::LoadFile(config_path);
// 创建对应后端实例
#if defined(USE_OPENVINO)
if (backend == "openvino") {
return ArmorDetectorOpenVino::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_TRT)
if (backend == "tensorrt") {
return ArmorDetectorTrt::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_NCNN)
if (backend == "ncnn") {
return ArmorDetectorNCNN::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
#if defined(USE_ORT)
if (backend == "onnxruntime") {
return ArmorDetectorOnnxRuntime::create(
armor_detect_config["armor_detector"],
use_armor_detect_common
);
}
#endif
if (backend == "opencv") {
return ArmorDetectorOpenCV::create(armor_detect_config["armor_detector"]);
}
std::cout << "Unsupported armor detector backend (or not compiled): " << backend
<< std::endl;
throw std::runtime_error(
"Unsupported armor detector backend (or not compiled): " + backend
);
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,198 @@
#include "armor_infer.hpp"
namespace wust_vision::auto_aim::armor_infer {
struct GridAndStride {
int grid0;
int grid1;
int stride;
};
[[nodiscard]] static inline std::vector<GridAndStride>
generate_grids_and_stride(int target_w, int target_h, const std::vector<int>& strides) noexcept {
std::vector<GridAndStride> grid_strides;
for (int stride: strides) {
const int num_w = target_w / stride;
const int num_h = target_h / stride;
grid_strides.reserve(grid_strides.size() + num_w * num_h);
for (int gy = 0; gy < num_h; ++gy) {
for (int gx = 0; gx < num_w; ++gx) {
grid_strides.push_back(GridAndStride { gx, gy, stride });
}
}
}
return grid_strides;
}
std::vector<ArmorObject> ArmorInfer::postProcessTUP_impl(const cv::Mat& out) const {
static std::optional<std::vector<GridAndStride>> _grid_strides;
if (!_grid_strides) {
_grid_strides = generate_grids_and_stride(inputW(), inputH(), { 8, 16, 32 });
}
const auto& grid_strides = _grid_strides.value();
std::vector<ArmorObject> out_objs;
const int num_anchors =
static_cast<int>(std::min<size_t>(grid_strides.size(), static_cast<size_t>(out.rows)));
for (int a = 0; a < num_anchors; ++a) {
const float confidence = out.at<float>(a, 8);
if (confidence < conf_threshold_)
continue;
const auto& gs = grid_strides[a];
const int gx = gs.grid0, gy = gs.grid1, stride = gs.stride;
// color & class
const int color_offset = 9;
const int num_colors = ModelTraits<Mode::TUP>::NUM_COLORS;
const int num_classes = ModelTraits<Mode::TUP>::NUM_CLASSES;
cv::Mat color_scores = out.row(a).colRange(color_offset, color_offset + num_colors);
cv::Mat class_scores =
out.row(a).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
double max_color, max_class;
cv::Point color_id, class_id;
cv::minMaxLoc(color_scores, nullptr, &max_color, nullptr, &color_id);
cv::minMaxLoc(class_scores, nullptr, &max_class, nullptr, &class_id);
const float x1 = (out.at<float>(a, 0) + gx) * stride;
const float y1 = (out.at<float>(a, 1) + gy) * stride;
const float x2 = (out.at<float>(a, 2) + gx) * stride;
const float y2 = (out.at<float>(a, 3) + gy) * stride;
const float x3 = (out.at<float>(a, 4) + gx) * stride;
const float y3 = (out.at<float>(a, 5) + gy) * stride;
const float x4 = (out.at<float>(a, 6) + gx) * stride;
const float y4 = (out.at<float>(a, 7) + gy) * stride;
ArmorObject obj;
obj.pts = { cv::Point2f(x1, y1),
cv::Point2f(x2, y2),
cv::Point2f(x3, y3),
cv::Point2f(x4, y4) };
obj.box = cv::boundingRect(obj.pts);
obj.color = static_cast<ArmorColor>(color_id.x);
obj.number = static_cast<ArmorNumber>(class_id.x);
obj.confidence = confidence;
out_objs.push_back(std::move(obj));
}
return topKAndNms(out_objs, top_k_, nms_threshold_);
}
std::vector<ArmorObject> ArmorInfer::postProcessRP_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
const int rows = out.rows;
const int color_offset = 9;
const int num_colors = ModelTraits<Mode::RP>::NUM_COLORS;
const int num_classes = ModelTraits<Mode::RP>::NUM_CLASSES;
for (int r = 0; r < rows; ++r) {
float conf_raw = out.at<float>(r, 8);
const float confidence = static_cast<float>(sigmoid(conf_raw));
if (confidence < conf_threshold_)
continue;
cv::Mat color_scores = out.row(r).colRange(color_offset, color_offset + num_colors);
cv::Mat class_scores =
out.row(r).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
double max_color_score, max_class_score;
cv::Point color_id, class_id;
cv::minMaxLoc(color_scores, nullptr, &max_color_score, nullptr, &color_id);
cv::minMaxLoc(class_scores, nullptr, &max_class_score, nullptr, &class_id);
ArmorObject obj;
obj.pts.resize(4);
for (int k = 0; k < 4; ++k) {
const float x = out.at<float>(r, 0 + k * 2);
const float y = out.at<float>(r, 1 + k * 2);
obj.pts[k] = cv::Point2f(x, y);
}
obj.box = cv::boundingRect(obj.pts);
obj.color = static_cast<ArmorColor>(color_id.x);
obj.number = static_cast<ArmorNumber>(class_id.x);
obj.confidence = confidence;
out_objs.push_back(std::move(obj));
}
return topKAndNms(out_objs, top_k_, nms_threshold_);
}
std::vector<ArmorObject> ArmorInfer::postProcessAT_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
constexpr int nkpt = ModelTraits<Mode::AT>::NUM_KPTS;
constexpr int nk = nkpt * 2; // keypoints flattened
auto max_det = out.rows;
auto det_dim = out.cols;
auto output_ptr = out.ptr<float>();
for (int i = 0; i < max_det; ++i) {
const float* row = output_ptr + i * det_dim;
float conf = row[4];
if (!std::isfinite(conf) || conf < conf_threshold_)
continue;
float x1 = row[0];
float y1 = row[1];
float x2 = row[2];
float y2 = row[3];
int cls = static_cast<int>(row[5]);
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|| x2 <= x1 || y2 <= y1)
continue;
ArmorObject obj;
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
obj.confidence = conf;
auto color_num = ModelTraits<Mode::AT>::CLASSES[cls];
obj.color = color_num.first;
obj.number = color_num.second;
obj.pts.reserve(nkpt);
for (int k = 0; k < nkpt; ++k) {
float kx = row[6 + 2 * k];
float ky = row[6 + 2 * k + 1];
obj.pts.emplace_back(kx, ky);
}
out_objs.emplace_back(std::move(obj));
}
return out_objs;
}
std::vector<ArmorObject> ArmorInfer::postProcessBOX_impl(const cv::Mat& out) const {
std::vector<ArmorObject> out_objs;
auto max_det = out.rows;
auto det_dim = out.cols;
auto output_ptr = out.ptr<float>();
for (int i = 0; i < max_det; ++i) {
const float* row = output_ptr + i * det_dim;
float conf = row[4];
if (!std::isfinite(conf) || conf < conf_threshold_)
continue;
float x1 = row[0];
float y1 = row[1];
float x2 = row[2];
float y2 = row[3];
int cls = static_cast<int>(row[5]);
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|| x2 <= x1 || y2 <= y1)
continue;
ArmorObject obj;
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
obj.confidence = conf;
auto color_num = ModelTraits<Mode::BOX>::CLASSES[cls];
obj.color = color_num.first;
obj.number = color_num.second;
std::vector<cv::Point2f> pts;
pts.resize(4);
pts[0] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y + obj.box.height); // 右下
pts[1] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y); // 右上
pts[2] = cv::Point2f(obj.box.x, obj.box.y); // 左上
pts[3] = cv::Point2f(obj.box.x, obj.box.y + obj.box.height); // 左下
obj.pts = std::move(pts);
out_objs.emplace_back(std::move(obj));
}
return out_objs;
}
} // namespace wust_vision::auto_aim::armor_infer

View File

@@ -0,0 +1,324 @@
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision::auto_aim::armor_infer {
static constexpr float MERGE_CONF_ERROR = 0.15f;
static constexpr float MERGE_MIN_IOU = 0.9f;
enum class Mode { TUP, RP, AT, BOX416, BOX320, BOX };
inline Mode modeFromString(const std::string& m) {
if (m == "tup" || m == "TUP")
return Mode::TUP;
if (m == "rp" || m == "RP")
return Mode::RP;
if (m == "at" || m == "AT")
return Mode::AT;
if (m == "box416" || m == "BOX416")
return Mode::BOX416;
if (m == "box320" || m == "BOX320")
return Mode::BOX320;
return Mode::TUP;
}
// ------------------------- model traits -------------------------
template<Mode M>
struct ModelTraits; // declare
// TUP
template<>
struct ModelTraits<Mode::TUP> {
static constexpr int INPUT_W = 416;
static constexpr int INPUT_H = 416;
static constexpr int NUM_CLASSES = 8;
static constexpr int NUM_COLORS = 4;
static constexpr bool USE_NORM = false;
static constexpr bool INPUT_RGB = true;
};
// RP
template<>
struct ModelTraits<Mode::RP> {
static constexpr int INPUT_W = 640;
static constexpr int INPUT_H = 640;
static constexpr int NUM_CLASSES = 9;
static constexpr int NUM_COLORS = 4;
static constexpr bool USE_NORM = true;
static constexpr bool INPUT_RGB = false;
};
template<>
struct ModelTraits<Mode::AT> {
static constexpr int INPUT_W = 640;
static constexpr int INPUT_H = 640;
static constexpr int NUM_KPTS = 4;
static constexpr bool USE_NORM = true;
static constexpr bool INPUT_RGB = false;
static constexpr std::array<std::pair<ArmorColor, ArmorNumber>, 64> CLASSES = { {
{ ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 },
{ ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 },
{ ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 },
{ ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE },
{ ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 },
{ ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 },
{ ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 },
{ ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE },
{ ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 },
{ ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 },
{ ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 },
{ ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE },
{ ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 },
{ ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 },
{ ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 },
{ ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE },
{ ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 },
{ ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 },
{ ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 },
{ ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE },
{ ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 },
{ ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 },
{ ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 },
{ ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE },
{ ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 },
{ ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 },
{ ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 },
{ ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE },
{ ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 },
{ ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 },
{ ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 },
{ ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE },
} };
};
template<>
struct ModelTraits<Mode::BOX416> {
static constexpr int INPUT_W = 416;
static constexpr int INPUT_H = 416;
};
template<>
struct ModelTraits<Mode::BOX320> {
static constexpr int INPUT_W = 320;
static constexpr int INPUT_H = 320;
};
template<>
struct ModelTraits<Mode::BOX> {
static constexpr int INPUT_W = 416;
static constexpr int INPUT_H = 416;
static constexpr bool INPUT_RGB = false;
static constexpr bool USE_NORM = true;
static constexpr std::array<std::pair<ArmorColor, ArmorNumber>, 12> CLASSES = { {
{ ArmorColor::BLUE, ArmorNumber::NO1 },
{ ArmorColor::BLUE, ArmorNumber::NO2 },
{ ArmorColor::BLUE, ArmorNumber::NO3 },
{ ArmorColor::BLUE, ArmorNumber::NO4 },
{ ArmorColor::BLUE, ArmorNumber::NO5 },
{ ArmorColor::BLUE, ArmorNumber::SENTRY },
{ ArmorColor::RED, ArmorNumber::NO1 },
{ ArmorColor::RED, ArmorNumber::NO2 },
{ ArmorColor::RED, ArmorNumber::NO3 },
{ ArmorColor::RED, ArmorNumber::NO4 },
{ ArmorColor::RED, ArmorNumber::NO5 },
{ ArmorColor::RED, ArmorNumber::SENTRY },
} };
};
[[nodiscard]] inline double sigmoid(double x) noexcept {
return x >= 0 ? 1.0 / (1.0 + std::exp(-x)) : std::exp(x) / (1.0 + std::exp(x));
}
[[nodiscard]] inline float rectIoU(const cv::Rect2f& a, const cv::Rect2f& b) noexcept {
const cv::Rect2f inter = a & b;
const float inter_area = inter.area();
const float union_area = a.area() + b.area() - inter_area;
if (union_area <= 0.f || std::isnan(union_area))
return 0.f;
return inter_area / union_area;
}
// Merge / NMS helpers that mimic original intent but clearer
inline void nms_merge_sorted_bboxes(
std::vector<ArmorObject>& objs,
std::vector<int>& out_indices,
float nms_threshold
) {
out_indices.clear();
const size_t n = objs.size();
std::vector<float> areas(n);
for (size_t i = 0; i < n; ++i)
areas[i] = objs[i].box.area();
for (size_t i = 0; i < n; ++i) {
ArmorObject& a = objs[i];
bool keep = true;
for (int idx: out_indices) {
ArmorObject& b = objs[idx];
const float iou = rectIoU(a.box, b.box);
if (std::isnan(iou) || iou > nms_threshold) {
keep = false;
if (a.number == b.number && a.color == b.color && iou > MERGE_MIN_IOU
&& std::abs(a.confidence - b.confidence) < MERGE_CONF_ERROR)
{
// accumulate points for later averaging
for (const auto& pt: a.pts)
b.pts.push_back(pt);
}
break;
}
}
if (keep)
out_indices.push_back(static_cast<int>(i));
}
}
inline std::vector<ArmorObject>
topKAndNms(std::vector<ArmorObject>& objs, int top_k, float nms_threshold) {
std::sort(objs.begin(), objs.end(), [](const ArmorObject& a, const ArmorObject& b) {
return a.confidence > b.confidence;
});
if (static_cast<int>(objs.size()) > top_k)
objs.resize(static_cast<size_t>(top_k));
std::vector<int> indices;
nms_merge_sorted_bboxes(objs, indices, nms_threshold);
std::vector<ArmorObject> result;
result.reserve(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
result.push_back(std::move(objs[indices[i]]));
// average merged extra points if any
auto& ro = result.back();
if (ro.pts.size() >= 8) {
const size_t npts = ro.pts.size();
std::array<cv::Point2f, 4> accum { { { 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 0 } } };
for (size_t j = 0; j < npts; ++j)
accum[j % 4] += ro.pts[j];
ro.pts.resize(4);
for (int k = 0; k < 4; ++k) {
float denom = static_cast<float>(npts) / 4.0f;
ro.pts[k].x = accum[k].x / denom;
ro.pts[k].y = accum[k].y / denom;
}
}
}
return result;
}
// However providing full modern unified class below that delegates using templated helpers.
class ArmorInfer {
public:
ArmorInfer(
Mode mode = Mode::TUP,
float conf_threshold = 0.25f,
float nms_threshold = 0.45f,
int top_k = 100
) noexcept:
mode_(mode),
conf_threshold_(conf_threshold),
nms_threshold_(nms_threshold),
top_k_(top_k) {
setMode(mode_);
}
void setMode(Mode m) noexcept {
mode_ = m;
switch (mode_) {
case Mode::TUP: {
input_w_ = ModelTraits<Mode::TUP>::INPUT_W;
input_h_ = ModelTraits<Mode::TUP>::INPUT_H;
use_norm_ = ModelTraits<Mode::TUP>::USE_NORM;
input_rgb_ = ModelTraits<Mode::TUP>::INPUT_RGB;
break;
}
case Mode::RP: {
input_w_ = ModelTraits<Mode::RP>::INPUT_W;
input_h_ = ModelTraits<Mode::RP>::INPUT_H;
use_norm_ = ModelTraits<Mode::RP>::USE_NORM;
input_rgb_ = ModelTraits<Mode::RP>::INPUT_RGB;
break;
}
case Mode::AT: {
input_w_ = ModelTraits<Mode::AT>::INPUT_W;
input_h_ = ModelTraits<Mode::AT>::INPUT_H;
use_norm_ = ModelTraits<Mode::AT>::USE_NORM;
input_rgb_ = ModelTraits<Mode::AT>::INPUT_RGB;
break;
}
case Mode::BOX416: {
input_w_ = ModelTraits<Mode::BOX416>::INPUT_W;
input_h_ = ModelTraits<Mode::BOX416>::INPUT_H;
use_norm_ = ModelTraits<Mode::BOX>::USE_NORM;
input_rgb_ = ModelTraits<Mode::BOX>::INPUT_RGB;
break;
}
case Mode::BOX320: {
input_w_ = ModelTraits<Mode::BOX320>::INPUT_W;
input_h_ = ModelTraits<Mode::BOX320>::INPUT_H;
use_norm_ = ModelTraits<Mode::BOX>::USE_NORM;
input_rgb_ = ModelTraits<Mode::BOX>::INPUT_RGB;
break;
} break;
}
}
void setConfThreshold(float t) noexcept {
conf_threshold_ = t;
}
void setNmsThreshold(float t) noexcept {
nms_threshold_ = t;
}
void setTopK(int k) noexcept {
top_k_ = k;
}
int inputW() const noexcept {
return input_w_;
}
int inputH() const noexcept {
return input_h_;
}
bool useNorm() const noexcept {
return use_norm_;
}
bool inputRGB() const noexcept {
return input_rgb_;
}
// main dispatching entry (keeps original signature)
[[nodiscard]] std::vector<ArmorObject> postProcess(const cv::Mat& output_buffer) const {
switch (mode_) {
case Mode::TUP:
return postProcessTUP_impl(output_buffer);
case Mode::RP:
return postProcessRP_impl(output_buffer);
case Mode::AT:
return postProcessAT_impl(output_buffer);
case Mode::BOX416:
return postProcessBOX_impl(output_buffer);
case Mode::BOX320:
return postProcessBOX_impl(output_buffer);
}
return {};
}
private:
std::vector<ArmorObject> postProcessTUP_impl(const cv::Mat& out) const;
std::vector<ArmorObject> postProcessRP_impl(const cv::Mat& out) const;
std::vector<ArmorObject> postProcessAT_impl(const cv::Mat& out) const;
std::vector<ArmorObject> postProcessBOX_impl(const cv::Mat& out) const;
private:
Mode mode_;
int input_w_ { 0 };
int input_h_ { 0 };
float conf_threshold_ { 0.25f };
float nms_threshold_ { 0.45f };
int top_k_ { 100 };
bool use_norm_ { false };
bool input_rgb_ { false };
};
} // namespace wust_vision::auto_aim::armor_infer

View File

@@ -0,0 +1,222 @@
#ifdef USE_NCNN
#include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "wust_vl/ml_net/ncnn/ncnn_net.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorNCNN::Impl {
public:
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = ArmorDetectorCommon::create(config);
}
std::string model_type = config["ncnn"]["model_type"].as<std::string>();
auto model = armor_infer::modeFromString(model_type);
float conf_threshold = config["ncnn"]["conf_threshold"].as<float>();
int top_k = config["ncnn"]["top_k"].as<int>();
float nms_threshold = config["ncnn"]["nms_threshold"].as<float>();
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
std::string model_path_param =
utils::expandEnv(config["ncnn"]["model_path_param"].as<std::string>());
std::string model_path_bin =
utils::expandEnv(config["ncnn"]["model_path_bin"].as<std::string>());
bool use_gpu = config["ncnn"]["use_gpu"].as<bool>();
int cpu_threads = config["ncnn"]["cpu_threads"].as<int>();
bool use_lightmode = config["ncnn"]["use_lightmode"].as<bool>();
auto input_name = config["ncnn"]["input_name"].as<std::string>();
auto output_name = config["ncnn"]["output_name"].as<std::string>();
int device_id = config["ncnn"]["device_id"].as<int>();
wust_vl::ml_net::NCNNNet::Params params;
params.model_path_param = model_path_param;
params.model_path_bin = model_path_bin;
params.input_name = input_name;
params.output_name = output_name;
params.use_vulkan = use_gpu;
params.device_id = device_id;
params.use_light_mode = use_lightmode;
params.cpu_threads = cpu_threads;
ncnn_net_ = std::make_unique<wust_vl::ml_net::NCNNNet>();
ncnn_net_->init(params);
}
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorNCNN>(config, use_armor_detect_common);
}
~Impl() {
armor_detect_common_.reset();
ncnn_net_.reset();
}
cv::Mat ncnnMatToCvMat(const ncnn::Mat& m) {
cv::Mat img(m.h, m.w, CV_8UC3);
m.to_pixels(img.data, ncnn::Mat::PIXEL_RGB2BGR);
return img;
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
static ncnn::Mat letterbox_to_ncnn(
const cv::Mat& img,
Eigen::Matrix3f& transform_matrix,
int out_w,
int out_h,
float norm,
bool swap_rb = true
) {
const int img_w = img.cols;
const int img_h = img.rows;
float scale = std::min(out_w * 1.0f / img_w, out_h * 1.0f / img_h);
int resize_w = static_cast<int>(round(img_w * scale));
int resize_h = static_cast<int>(round(img_h * scale));
int pad_w = out_w - resize_w;
int pad_h = out_h - resize_h;
int pad_left = static_cast<int>(round(pad_w / 2.0f - 0.1f));
int pad_top = static_cast<int>(round(pad_h / 2.0f - 0.1f));
transform_matrix << 1.0f / scale, 0, -pad_left / scale, 0, 1.0f / scale,
-pad_top / scale, 0, 0, 1;
ncnn::Mat out;
if (swap_rb) {
out = ncnn::Mat::from_pixels_resize(
img.data,
ncnn::Mat::PIXEL_BGR2RGB,
img_w,
img_h,
resize_w,
resize_h
);
} else {
out = ncnn::Mat::from_pixels_resize(
img.data,
ncnn::Mat::PIXEL_RGB,
img_w,
img_h,
resize_w,
resize_h
);
}
int pad_right = out_w - resize_w - pad_left;
int pad_bottom = out_h - resize_h - pad_top;
ncnn::Mat padded;
ncnn::copy_make_border(
out,
padded,
pad_top,
pad_bottom,
pad_left,
pad_right,
ncnn::BORDER_CONSTANT,
114.f
);
std::array<float, 3> mean_vals;
std::array<float, 3> norm_vals;
mean_vals = { 0.f, 0.f, 0.f };
norm_vals = { norm, norm, norm };
padded.substract_mean_normalize(mean_vals.data(), norm_vals.data());
return padded;
}
void
processCallback(const CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
// Eigen::Matrix3f transform_matrix;
// cv::Mat resized_img = letterbox(frame.src_img, transform_matrix);
// ncnn::Mat in =
// ncnn::Mat::from_pixels(resized_img.data, ncnn::Mat::PIXEL_BGR2RGB, INPUT_W, INPUT_H);
Eigen::Matrix3f transform_matrix;
auto roi = frame.img_frame.src_img(frame.expanded);
const bool swap_rb = armor_infer_->inputRGB()
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
ncnn::Mat in = letterbox_to_ncnn(
roi.clone(),
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH(),
scale,
swap_rb
);
cv::Mat resized_img = ncnnMatToCvMat(in);
auto out = ncnn_net_->infer(in);
cv::Mat output_buffer(out.h, out.w, CV_32F, out.data);
// Parse YOLO output
auto objs_result = armor_infer_->postProcess(output_buffer);
std::vector<ArmorObject> armors;
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
}
return;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
frame.id = current_id_++;
processCallback(frame, target_number);
}
private:
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
std::unique_ptr<wust_vl::ml_net::NCNNNet> ncnn_net_;
};
ArmorDetectorNCNN::ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorNCNN::~ArmorDetectorNCNN() {
_impl.reset();
}
void ArmorDetectorNCNN::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorNCNN::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,36 @@
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorNCNN: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorNCNN>;
explicit ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorNCNN>(config, use_armor_detect_common);
}
~ArmorDetectorNCNN();
void setCallback(DetectorCallback callback) override;
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,12 @@
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class NumberClassifierBase {
public:
virtual ~NumberClassifierBase() = default;
virtual bool classifyNumber(ArmorObject& armor) = 0;
virtual void initNumberClassifier() = 0;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,34 @@
#pragma once
#include "number_classifier.hpp"
#ifdef USE_TRT
#include "number_classifier_trt.hpp"
#endif
namespace wust_vision {
namespace auto_aim {
class NumberClassifierFactory {
public:
static std::unique_ptr<NumberClassifierBase> createNumberClassifier(
const std::string& backend,
const std::string& classify_model_path,
const std::string& classify_label_path
) {
#if defined(USE_TRT)
if (backend == "tensorrt") {
return std::make_unique<NumberClassifierTRT>(
classify_model_path,
classify_label_path
);
}
#endif
if (backend == "opencv") {
return std::make_unique<NumberClassifier>(classify_model_path, classify_label_path);
}
throw std::runtime_error(
"Unsupported number classifier backend (or not compiled): " + backend
);
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,116 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "number_classifier.hpp"
#include <wust_vl/common/utils/logger.hpp>
namespace wust_vision {
namespace auto_aim {
NumberClassifier::NumberClassifier(
const std::string& classify_model_path,
const std::string& classify_label_path
):
classify_model_path_(classify_model_path),
classify_label_path_(classify_label_path) {
initNumberClassifier();
}
void NumberClassifier::initNumberClassifier() {
const std::string model_path = classify_model_path_;
std::unique_ptr<cv::dnn::Net> number_net_ =
std::make_unique<cv::dnn::Net>(cv::dnn::readNetFromONNX(model_path));
if (number_net_->empty()) {
WUST_ERROR("number_classifier")
<< "Failed to load number classifier model from " << model_path;
std::exit(EXIT_FAILURE);
} else {
WUST_INFO("number_classifier")
<< "Successfully loaded number classifier model from " << model_path;
}
const std::string label_path = classify_label_path_;
std::ifstream label_file(label_path);
std::string line;
class_names_.clear();
while (std::getline(label_file, line)) {
class_names_.push_back(line);
}
if (class_names_.empty()) {
WUST_ERROR("number_classifier") << "Failed to load labels from " << label_path;
std::exit(EXIT_FAILURE);
} else {
WUST_INFO("number_classifier")
<< "Successfully loaded " << class_names_.size() << " labels from " << label_path;
}
number_net_.reset();
}
bool NumberClassifier::classifyNumber(ArmorObject& armor) {
static thread_local std::unique_ptr<cv::dnn::Net> thread_net;
if (armor.number_img.empty()) {
return false;
}
if (!thread_net) {
thread_net =
std::make_unique<cv::dnn::Net>(cv::dnn::readNetFromONNX(classify_model_path_));
WUST_DEBUG("number_classifier") << "Loaded number classifier model for this thread";
if (thread_net->empty()) {
WUST_ERROR("number_classifier")
<< "Failed to load thread-local number classifier model.";
return false;
}
}
const cv::Mat image = armor.number_img;
cv::Mat blob;
cv::dnn::blobFromImage(image, blob, 1.0 / 255.0);
thread_net->setInput(blob);
cv::Mat outputs = thread_net->forward();
double max_val;
cv::minMaxLoc(outputs, nullptr, &max_val);
cv::Mat prob;
cv::exp(outputs - max_val, prob);
prob /= cv::sum(prob)[0];
double confidence;
cv::Point class_id;
cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id);
const int label_id = class_id.x;
const double raw_conf = armor.confidence;
armor.confidence = confidence;
static const std::map<int, ArmorNumber> label_to_armor_number = {
{ 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 },
{ 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST },
{ 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE }
};
if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) {
armor.number = label_to_armor_number.at(label_id);
return true;
} else {
armor.confidence = raw_conf;
return false;
}
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,35 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base.hpp"
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class NumberClassifier: public NumberClassifierBase {
public:
NumberClassifier(
const std::string& classify_model_path,
const std::string& classify_label_path
);
void initNumberClassifier() override;
bool classifyNumber(ArmorObject& armor) override;
private:
std::vector<std::string> class_names_;
std::string classify_model_path_;
std::string classify_label_path_;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,118 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef USE_TRT
#include "number_classifier_trt.hpp"
#include <wust_vl/common/utils/logger.hpp>
namespace wust_vision {
namespace auto_aim {
NumberClassifierTRT::NumberClassifierTRT(
const std::string& classify_model_path,
const std::string& classify_label_path
):
classify_model_path_(classify_model_path),
classify_label_path_(classify_label_path) {
initNumberClassifier();
}
void NumberClassifierTRT::initNumberClassifier() {
const std::string model_path = classify_model_path_;
trt_net_ = std::make_unique<wust_vl::ml_net::TensorRTNet>();
wust_vl::ml_net::TensorRTNet::Params trt_params;
trt_params.model_path = model_path;
trt_params.input_dims = nvinfer1::Dims4 { 1, 1, 20, 28 };
trt_net_->init(trt_params);
const auto input_output_dims = trt_net_->getInputOutputDims();
input_dims_ = std::get<0>(input_output_dims);
output_dims_ = std::get<1>(input_output_dims);
const std::string label_path = classify_label_path_;
std::ifstream label_file(label_path);
std::string line;
class_names_.clear();
while (std::getline(label_file, line)) {
class_names_.push_back(line);
}
if (class_names_.empty()) {
WUST_ERROR("number_classifier_trt") << "Failed to load labels from " << label_path;
std::exit(EXIT_FAILURE);
} else {
WUST_INFO("number_classifier_trt")
<< "Successfully loaded " << class_names_.size() << " labels from " << label_path;
}
}
bool NumberClassifierTRT::classifyNumber(ArmorObject& armor) {
static thread_local std::unique_ptr<nvinfer1::IExecutionContext> ctx;
if (armor.number_img.empty()) {
return false;
}
if (!ctx) {
auto c = trt_net_->getAContext();
ctx = std::unique_ptr<nvinfer1::IExecutionContext>(c);
WUST_DEBUG("number_classifier_trt") << "Loaded number classifier model for this thread";
if (!ctx) {
WUST_ERROR("number_classifier_trt")
<< "Failed to load thread-local number classifier model.";
return false;
}
}
const cv::Mat image = armor.number_img;
cv::Mat blob;
cv::dnn::blobFromImage(image, blob, 1.0 / 255.0, cv::Size(28, 20));
trt_net_->input2Device(blob.ptr<float>());
void* input_tensor_ptr = trt_net_->getInputTensorPtr();
trt_net_->infer(input_tensor_ptr, ctx.get());
const float* out = static_cast<float*>(trt_net_->output2Host());
cv::Mat outputs(1, 9, CV_32F);
std::memcpy(outputs.data, out, 9 * sizeof(float));
double max_val;
cv::minMaxLoc(outputs, nullptr, &max_val);
cv::Mat prob;
cv::exp(outputs - max_val, prob);
prob /= cv::sum(prob)[0];
double confidence;
cv::Point class_id;
cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id);
const int label_id = class_id.x;
const double raw_conf = armor.confidence;
armor.confidence = confidence;
static const std::map<int, ArmorNumber> label_to_armor_number = {
{ 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 },
{ 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST },
{ 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE }
};
if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) {
armor.number = label_to_armor_number.at(label_id);
return true;
} else {
armor.confidence = raw_conf;
return false;
}
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,39 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base.hpp"
#include "tasks/auto_aim/type.hpp"
#include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp"
namespace wust_vision {
namespace auto_aim {
class NumberClassifierTRT: public NumberClassifierBase {
public:
NumberClassifierTRT(
const std::string& classify_model_path,
const std::string& classify_label_path
);
void initNumberClassifier() override;
bool classifyNumber(ArmorObject& armor) override;
private:
std::vector<std::string> class_names_;
std::string classify_model_path_;
std::string classify_label_path_;
std::unique_ptr<wust_vl::ml_net::TensorRTNet> trt_net_;
nvinfer1::Dims input_dims_;
nvinfer1::Dims output_dims_;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,141 @@
#ifdef USE_ORT
#include "armor_detector_onnxruntime.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/ml_net/onnxruntime/onnxruntime_net.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorOnnxRuntime::Impl {
public:
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
}
std::string model_type = config["onnxruntime"]["model_type"].as<std::string>();
auto model = armor_infer::modeFromString(model_type);
float conf_threshold = config["onnxruntime"]["conf_threshold"].as<float>();
int top_k = config["onnxruntime"]["top_k"].as<int>();
float nms_threshold = config["onnxruntime"]["nms_threshold"].as<float>();
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
std::string provider = config["onnxruntime"]["provider"].as<std::string>("CPU");
provider_ = wust_vl::ml_net::string2OrtProvider(provider);
onnxruntime_net_ = std::make_unique<wust_vl::ml_net::OnnxRuntimeNet>();
wust_vl::ml_net::OnnxRuntimeNet::Params params;
std::string model_path =
utils::expandEnv(config["onnxruntime"]["model_path"].as<std::string>());
params.model_path = model_path;
params.provider = provider_;
onnxruntime_net_->init(params);
}
~Impl() {
onnxruntime_net_.reset();
armor_detect_common_.reset();
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
void
processCallback(const CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
Eigen::Matrix3f transform_matrix;
auto roi = frame.img_frame.src_img(frame.expanded);
cv::Mat resized_img = utils::letterbox(
roi,
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH()
);
const bool swap_rb = armor_infer_->inputRGB()
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
cv::Mat blob = cv::dnn::blobFromImage(
resized_img,
scale,
cv::Size(armor_infer_->inputW(), armor_infer_->inputH()),
cv::Scalar(0, 0, 0),
swap_rb
);
auto output_data = onnxruntime_net_->infer(blob.ptr<float>(), blob.total());
auto output_shape = onnxruntime_net_->getOutputShape();
int rows = static_cast<int>(output_shape[1]);
int cols = static_cast<int>(output_shape[2]);
cv::Mat output_buffer(rows, cols, CV_32F, output_data);
// Parsed variable
std::vector<ArmorObject> objs_result;
objs_result = armor_infer_->postProcess(output_buffer);
std::vector<ArmorObject> armors;
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
}
return;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
frame.id = current_id_++;
processCallback(frame, target_number);
}
wust_vl::ml_net::OrtProvider provider_ = wust_vl::ml_net::OrtProvider::CPU;
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
std::unique_ptr<wust_vl::ml_net::OnnxRuntimeNet> onnxruntime_net_;
};
ArmorDetectorOnnxRuntime::ArmorDetectorOnnxRuntime(
const YAML::Node& config,
bool use_armor_detect_common
) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorOnnxRuntime::~ArmorDetectorOnnxRuntime() {
_impl.reset();
}
void ArmorDetectorOnnxRuntime::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorOnnxRuntime::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,36 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorOnnxRuntime: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorOnnxRuntime>;
explicit ArmorDetectorOnnxRuntime(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorOnnxRuntime>(config, use_armor_detect_common);
}
~ArmorDetectorOnnxRuntime();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,382 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
//
// Additional modifications and features by Chengfu Zou, Labor. Licensed under
// Apache License 2.0.
//
// Copyright (C) FYT Vision Group. All rights reserved.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp"
#include "tasks/auto_aim/armor_detect/number_classifier/number_classifier.hpp"
#include "tasks/utils/utils.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorOpenCV::Impl {
public:
Impl(const YAML::Node& config) {
auto classify_model_path =
utils::expandEnv(config["classify"]["model_path"].as<std::string>());
auto classify_label_path =
utils::expandEnv(config["classify"]["label_path"].as<std::string>());
double classify_threshold = config["classify"]["threshold"].as<double>();
number_classifier_ =
std::make_unique<NumberClassifier>(classify_model_path, classify_label_path);
light_params_.load(config["light"]);
armor_params_.load(config["armor"]);
}
std::vector<ArmorObject> detect(
const cv::Mat& input,
int detect_color,
const std::optional<ArmorNumber>& target_number
) noexcept {
if (input.empty())
return {};
std::vector<Light> lights_;
cv::Mat binary_img, gray_img;
std::tie(binary_img, gray_img) = preprocessImage(input);
lights_ = findLights(input, binary_img);
std::vector<ArmorObject> armors = matchLights(lights_, detect_color);
std::vector<ArmorObject> valid_armors;
for (auto& armor: armors) {
try {
armor.number_img = extractNumber(gray_img, armor);
if (armor.number_img.empty())
continue;
if (!number_classifier_->classifyNumber(armor))
continue;
if (target_number.has_value()) {
if (!isSameTarget(target_number.value(), armor.number)) {
continue;
}
}
if (armor.confidence < classifier_threshold_)
continue;
if (armor.number != ArmorNumber::NO1 && armor.number != ArmorNumber::BASE
&& armor.type == ArmorType::LARGE)
{
continue;
}
valid_armors.push_back(armor);
} catch (const std::exception& e) {
std::cerr << "[detect] Exception: " << e.what() << std::endl;
}
}
return valid_armors;
}
std::tuple<cv::Mat, cv::Mat> preprocessImage(const cv::Mat& img) noexcept {
cv::Mat gray_img;
cv::Mat binary_img;
if (img.empty()) {
return { binary_img, gray_img }; // 空图直接返回空
}
if (img.channels() == 3) {
cv::cvtColor(img, gray_img, cv::COLOR_RGB2GRAY);
} else if (img.channels() == 1) {
cv::cvtColor(img, gray_img, cv::COLOR_BayerRG2GRAY);
} else {
return { binary_img, gray_img };
}
cv::threshold(gray_img, binary_img, light_params_.binary_thres, 255, cv::THRESH_BINARY);
return { binary_img, gray_img };
}
std::vector<Light> findLights(const cv::Mat& img, const cv::Mat& binary_img) noexcept {
std::vector<std::vector<cv::Point>> contours;
contours.reserve(64);
cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
cv::Mat color_img;
if (img.channels() == 3) {
color_img = img;
} else if (img.channels() == 1) {
cv::cvtColor(img, color_img, cv::COLOR_BayerRG2BGR);
} else {
return {};
}
std::vector<Light> lights;
lights.reserve(contours.size());
for (const auto& contour: contours) {
const int n = static_cast<int>(contour.size());
if (n < 6)
continue;
Light light(contour);
if (!isLight(light))
continue;
int sum_r = 0;
int sum_b = 0;
for (const auto& pt: contour) {
const cv::Vec3b& pix = color_img.at<cv::Vec3b>(pt);
sum_r += pix[0];
sum_b += pix[2];
}
const int avg_diff = std::abs(sum_r - sum_b) / n;
if (avg_diff <= light_params_.color_diff_thresh)
continue;
light.color = (sum_r > sum_b) ? 0 : 1;
lights.emplace_back(std::move(light));
}
std::sort(lights.begin(), lights.end(), [](const Light& a, const Light& b) {
return a.center.x < b.center.x;
});
return lights;
}
bool isLight(const Light& light) noexcept {
// width / length 比例
const float ratio = light.width / light.length;
if (ratio <= light_params_.min_ratio || ratio >= light_params_.max_ratio)
return false;
if (light.tilt_angle >= light_params_.max_angle)
return false;
return true;
}
std::vector<ArmorObject>
matchLights(const std::vector<Light>& lights, int detect_color) noexcept {
const int n = static_cast<int>(lights.size());
std::vector<ArmorObject> armors;
armors.reserve(n);
for (int i = 0; i < n; ++i) {
const Light& l1 = lights[i];
if (l1.color != detect_color)
continue;
const float max_dx = l1.length * armor_params_.max_large_center_distance;
for (int j = i + 1; j < n; ++j) {
const Light& l2 = lights[j];
if (l2.color != detect_color)
continue;
const float dx = l2.center.x - l1.center.x;
if (dx > max_dx)
break;
ArmorType type = isArmor(l1, l2);
if (type == ArmorType::INVALID)
continue;
if (containLight(i, j, lights))
continue;
ArmorObject armor(l1, l2);
armor.type = type;
armor.color = (detect_color == 0) ? ArmorColor::RED : ArmorColor::BLUE;
armors.emplace_back(std::move(armor));
}
}
return armors;
}
// Check if there is another light in the boundingRect formed by the 2 lights
bool containLight(const int i, const int j, const std::vector<Light>& lights) noexcept {
const Light& l1 = lights[i];
const Light& l2 = lights[j];
float min_x = std::min({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x });
float max_x = std::max({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x });
float min_y = std::min({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y });
float max_y = std::max({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y });
const float avg_len = 0.5f * (l1.length + l2.length);
const float avg_wid = 0.5f * (l1.width + l2.width);
for (int k = i + 1; k < j; ++k) {
const Light& t = lights[k];
if (t.width > 2.0f * avg_wid)
continue;
if (t.length < 0.5f * avg_len)
continue;
const cv::Point2f& c = t.center;
if (c.x >= min_x && c.x <= max_x && c.y >= min_y && c.y <= max_y) {
return true;
}
}
return false;
}
ArmorType isArmor(const Light& l1, const Light& l2) noexcept {
const float len1 = l1.length;
const float len2 = l2.length;
if (len1 <= 1e-3f || len2 <= 1e-3f)
return ArmorType::INVALID;
const float min_len = (len1 < len2) ? len1 : len2;
const float max_len = (len1 < len2) ? len2 : len1;
if (min_len / max_len <= armor_params_.min_light_ratio)
return ArmorType::INVALID;
const cv::Point2f d = l1.center - l2.center;
const float dist2 = d.dot(d);
const float avg_len = 0.5f * (len1 + len2);
const float min_small = armor_params_.min_small_center_distance * avg_len;
const float max_small = armor_params_.max_small_center_distance * avg_len;
const float min_large = armor_params_.min_large_center_distance * avg_len;
const float max_large = armor_params_.max_large_center_distance * avg_len;
const float min_small2 = min_small * min_small;
const float max_small2 = max_small * max_small;
const float min_large2 = min_large * min_large;
const float max_large2 = max_large * max_large;
const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2);
const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2);
if (!(small_ok || large_ok))
return ArmorType::INVALID;
static const float tan_max_angle = std::tan(armor_params_.max_angle * CV_PI / 180.0f);
if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle)
return ArmorType::INVALID;
float delta_angle = std::fabs(l1.angle - l2.angle);
if (delta_angle > 90.0f)
delta_angle = 180.0f - delta_angle;
if (delta_angle >= light_params_.max_angle_diff)
return ArmorType::INVALID;
return large_ok ? ArmorType::LARGE : ArmorType::SMALL;
}
cv::Mat extractNumber(const cv::Mat& src, const ArmorObject& armor) const noexcept {
constexpr int light_length = 12;
constexpr int warp_height = 28;
constexpr int small_armor_width = 32;
constexpr int large_armor_width = 54;
const cv::Size roi_size(20, 28);
cv::Point2f src_pts[4] = { armor.lights[0].bottom,
armor.lights[0].top,
armor.lights[1].top,
armor.lights[1].bottom };
const int warp_width =
(armor.type == ArmorType::SMALL) ? small_armor_width : large_armor_width;
const int top_y = (warp_height - light_length) / 2 - 1;
const int bottom_y = top_y + light_length;
cv::Point2f dst_pts[4] = {
{ 0.f, static_cast<float>(bottom_y) },
{ 0.f, static_cast<float>(top_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(top_y) },
{ static_cast<float>(warp_width - 1), static_cast<float>(bottom_y) }
};
cv::Mat warp_mat = cv::getPerspectiveTransform(src_pts, dst_pts);
cv::Mat warped;
cv::warpPerspective(
src,
warped,
warp_mat,
cv::Size(warp_width, warp_height),
cv::INTER_LINEAR,
cv::BORDER_CONSTANT,
0
);
const int roi_x = (warp_width - roi_size.width) >> 1;
if (roi_x < 0 || roi_x + roi_size.width > warp_width)
return cv::Mat();
cv::Mat number = warped(cv::Rect(roi_x, 0, roi_size.width, roi_size.height));
cv::threshold(number, number, 0, 255, cv::THRESH_BINARY | cv::THRESH_OTSU);
return number;
}
void setCallback(DetectorCallback callback) {
this->infer_callback_ = callback;
}
void toPts(ArmorObject& armor) {
if (armor.lights.size() != 2) {
armor.is_ok = false;
return;
}
armor.pts[0] = armor.lights[0].top;
armor.pts[1] = armor.lights[0].bottom;
armor.pts[2] = armor.lights[1].bottom;
armor.pts[3] = armor.lights[1].top;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
frame.id = current_id_++;
std::vector<ArmorObject> objs_result;
auto roi = frame.img_frame.src_img(frame.expanded);
objs_result = detect(roi, frame.detect_color, target_number);
if (this->infer_callback_) {
this->infer_callback_(objs_result, frame);
return;
}
return;
}
LightParams light_params_;
ArmorParams armor_params_;
double classifier_threshold_ = 0.5;
std::unique_ptr<NumberClassifier> number_classifier_;
DetectorCallback infer_callback_;
int current_id_ = 0;
};
ArmorDetectorOpenCV::ArmorDetectorOpenCV(const YAML::Node& config) {
_impl = std::make_unique<Impl>(config);
}
ArmorDetectorOpenCV::~ArmorDetectorOpenCV() {
_impl.reset();
}
void ArmorDetectorOpenCV::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorOpenCV::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,42 @@
// Copyright Chen Jun 2023. Licensed under the MIT License.
//
// Additional modifications and features by Chengfu Zou, Labor. Licensed under
// Apache License 2.0.
//
// Copyright (C) FYT Vision Group. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorOpenCV: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorOpenCV>;
explicit ArmorDetectorOpenCV(const YAML::Node& config);
static Ptr create(const YAML::Node& config) {
return std::make_unique<ArmorDetectorOpenCV>(config);
}
~ArmorDetectorOpenCV();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,206 @@
// Copyright 2025 Zikang Xie
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef USE_OPENVINO
#include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/ml_net/openvino/openvino_net.hpp"
namespace wust_vision {
namespace auto_aim {
struct ArmorDetectorOpenVino::Impl {
public:
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
}
std::string model_type = config["openvino"]["model_type"].as<std::string>();
auto model = armor_infer::modeFromString(model_type);
float conf_threshold = config["openvino"]["conf_threshold"].as<float>();
int top_k = config["openvino"]["top_k"].as<int>();
float nms_threshold = config["openvino"]["nms_threshold"].as<float>();
bool use_throughputmode = config["openvino"]["use_throughputmode"].as<bool>();
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
std::string model_path =
utils::expandEnv(config["openvino"]["model_path"].as<std::string>());
auto device_name = config["openvino"]["device_name"].as<std::string>();
ov_params_.model_path = model_path;
ov_params_.device_name = device_name;
ov_params_.mode = use_throughputmode ? ov::hint::PerformanceMode::THROUGHPUT
: ov::hint::PerformanceMode::LATENCY;
initOpenVINO();
}
void
initOpenVINO(wust_vl::video::PixelFormat pixel_format = wust_vl::video::PixelFormat::BGR) {
openvino_net_.reset();
openvino_net_ = std::make_unique<wust_vl::ml_net::OpenvinoNet>();
const auto ppp_init_fun = [this, pixel_format](ov::preprocess::PrePostProcessor& ppp) {
if (pixel_format == wust_vl::video::PixelFormat::RGB) {
ppp.input()
.tensor()
.set_element_type(ov::element::u8)
.set_layout("NHWC")
.set_color_format(ov::preprocess::ColorFormat::RGB);
} else {
ppp.input()
.tensor()
.set_element_type(ov::element::u8)
.set_layout("NHWC")
.set_color_format(ov::preprocess::ColorFormat::BGR);
}
pixel_format_ = pixel_format;
const bool RGB = armor_infer_->inputRGB();
const float scale = armor_infer_->useNorm() ? 255.0f : 1.0f;
if (RGB) {
ppp.input()
.preprocess()
.convert_element_type(ov::element::f32)
.convert_color(ov::preprocess::ColorFormat::RGB)
.scale(scale);
} else {
ppp.input()
.preprocess()
.convert_element_type(ov::element::f32)
.convert_color(ov::preprocess::ColorFormat::BGR)
.scale(scale);
}
ppp.input().model().set_layout("NCHW");
ppp.output().tensor().set_element_type(ov::element::f32);
};
openvino_net_->init(ov_params_, ppp_init_fun);
}
~Impl() {
openvino_net_.reset();
armor_detect_common_.reset();
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
bool processCallback(
const CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) const {
const auto start = std::chrono::steady_clock::now();
Eigen::Matrix3f transform_matrix;
const auto roi = frame.img_frame.src_img(frame.expanded);
cv::Mat resized_img = utils::letterbox(
roi,
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH()
);
const auto input_info = openvino_net_->getInputInfo();
const auto input_tensor =
ov::Tensor(input_info.first, input_info.second, resized_img.data);
const auto output = openvino_net_->infer_thread_local(input_tensor);
// Process output data
const auto output_shape = output.get_shape();
const float* ptr = output.data<const float>();
cv::Mat
output_buffer(output_shape[1], output_shape[2], CV_32F, const_cast<float*>(ptr));
// Parsed variable
auto objs_result = armor_infer_->postProcess(output_buffer);
std::vector<ArmorObject> armors;
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return true;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return true;
}
}
return false;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
if (resetting_) {
return;
}
frame.id = current_id_++;
if (frame.img_frame.pixel_format != pixel_format_) {
resetting_ = true;
initOpenVINO(frame.img_frame.pixel_format);
resetting_ = false;
}
processCallback(frame, target_number);
}
private:
wust_vl::video::PixelFormat pixel_format_ = wust_vl::video::PixelFormat::BGR;
std::unique_ptr<wust_vl::ml_net::OpenvinoNet> openvino_net_;
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
wust_vl::ml_net::OpenvinoNet::Params ov_params_;
bool resetting_ = false;
};
ArmorDetectorOpenVino::ArmorDetectorOpenVino(
const YAML::Node& config,
bool use_armor_detect_common
) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorOpenVino::~ArmorDetectorOpenVino() {
_impl.reset();
}
void ArmorDetectorOpenVino::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorOpenVino::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,38 @@
// Copyright 2023 Yunlong Feng
// Copyright 2025 Lihan Chen
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorOpenVino: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorOpenVino>;
explicit ArmorDetectorOpenVino(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorOpenVino>(config, use_armor_detect_common);
}
~ArmorDetectorOpenVino();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,320 @@
// Copyright 2025 Zikang Xie
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef USE_TRT
#include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp"
#include "cuda_infer/armor_infer.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/common/concurrency/adaptive_resource_pool.hpp"
#include "wust_vl/common/utils/logger.hpp"
#include "wust_vl/common/utils/timer.hpp"
#include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp"
namespace wust_vision {
namespace auto_aim {
static constexpr int MAX_SRC_IMG_W = 1920;
static constexpr int MAX_SRC_IMG_H = 1440;
struct ArmorDetectorTrt::Impl {
public:
struct Infer {
std::unique_ptr<nvinfer1::IExecutionContext> context;
std::unique_ptr<armor_cuda_infer::CudaInfer> cuda_infer;
};
Impl(const YAML::Node& config, bool use_armor_detect_common) {
if (use_armor_detect_common) {
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
}
const double conf_threshold = config["tensorrt"]["conf_threshold"].as<float>();
const double nms_threshold = config["tensorrt"]["nms_threshold"].as<float>();
const int top_k = config["tensorrt"]["top_k"].as<int>();
const int max_infer_running = config["tensorrt"]["max_infer_running"].as<int>();
const double min_free_mem_ratio = config["tensorrt"]["min_free_mem_ratio"].as<double>();
use_cuda_pre_ = config["tensorrt"]["use_cuda_pre"].as<bool>();
log_time_ = config["tensorrt"]["log_time"].as<bool>();
const std::string model_type = config["tensorrt"]["model_type"].as<std::string>();
const std::string model_path =
utils::expandEnv(config["tensorrt"]["model_path"].as<std::string>());
const int device_id = config["tensorrt"]["device_id"].as<int>();
cudaSetDevice(device_id);
const auto model = armor_infer::modeFromString(model_type);
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
model,
conf_threshold,
nms_threshold,
top_k
);
trt_net_ = std::make_unique<wust_vl::ml_net::TensorRTNet>();
wust_vl::ml_net::TensorRTNet::Params trt_params;
trt_params.model_path = model_path;
trt_params.input_dims =
nvinfer1::Dims4 { 1, 3, armor_infer_->inputH(), armor_infer_->inputW() };
trt_net_->init(trt_params);
const auto input_output_dims = trt_net_->getInputOutputDims();
input_dims_ = std::get<0>(input_output_dims);
output_dims_ = std::get<1>(input_output_dims);
wust_vl::common::concurrency::AdaptiveResourcePool<Infer>::Params pool_params;
pool_params.resource_initializer = [&]() {
std::vector<std::unique_ptr<Infer>> infers;
for (int i = 0; i < max_infer_running; ++i) {
auto infer = std::make_unique<Infer>();
auto ctx = trt_net_->getAContext();
infer->context = std::unique_ptr<nvinfer1::IExecutionContext>(ctx);
if (use_cuda_pre_) {
infer->cuda_infer = std::make_unique<armor_cuda_infer::CudaInfer>();
infer->cuda_infer->init(
MAX_SRC_IMG_W,
MAX_SRC_IMG_H,
armor_infer_->inputW(),
armor_infer_->inputH()
);
}
if (!infer->context) {
WUST_ERROR("TRT") << "create infer failed, missing context"
<< " index:" << i;
continue;
}
if (use_cuda_pre_ && !infer->cuda_infer) {
WUST_ERROR("TRT") << "create infer failed, missing cuda_infer"
<< " index:" << i;
continue;
}
size_t free_mem, total_mem;
cudaMemGetInfo(&free_mem, &total_mem);
WUST_DEBUG("TRT") << "Free GPU memory:" << free_mem / 1024.0 / 1024.0 << "MB"
<< "Total GPU memory:" << total_mem / 1024.0 / 1024.0 << "MB";
double free_mem_ratio =
static_cast<double>(free_mem) / static_cast<double>(total_mem);
if (free_mem_ratio < min_free_mem_ratio && i > 0) {
WUST_WARN("TRT") << "GPU memory is not enough!"
<< "Free GPU memory:" << free_mem_ratio * 100 << "%";
WUST_INFO("TRT") << "Cut remaining infer";
break;
}
infers.emplace_back(std::move(infer));
WUST_INFO("TRT") << "create execution context success"
<< "index:" << i;
}
return infers;
};
auto release_func = [&](std::unique_ptr<Infer>& resource) {
if (resource) {
if (resource->cuda_infer) {
resource->cuda_infer.reset();
}
}
};
auto restore_func = [&](size_t idx) -> std::unique_ptr<Infer> {
auto infer = std::make_unique<Infer>();
auto ctx = trt_net_->getAContext();
infer->context = std::unique_ptr<nvinfer1::IExecutionContext>(ctx);
if (use_cuda_pre_) {
infer->cuda_infer = std::make_unique<armor_cuda_infer::CudaInfer>();
infer->cuda_infer->init(
MAX_SRC_IMG_W,
MAX_SRC_IMG_H,
armor_infer_->inputW(),
armor_infer_->inputH()
);
}
if (!infer->context) {
WUST_ERROR("TRT") << "create infer failed, missing context";
return nullptr;
}
if ((use_cuda_pre_) && !infer->cuda_infer) {
WUST_ERROR("TRT") << "create infer failed, missing cuda_infer";
return nullptr;
}
return infer;
};
pool_params.restore_func = restore_func;
pool_params.release_func = release_func;
pool_params.can_restore = [&](size_t active_count) { return false; };
pool_params.should_release = [&](size_t active_count) { return false; };
pool_params.logger = [](const std::string& msg) {
WUST_INFO("ArmorDetectorTrt:infer pool") << msg;
};
infer_pool_ =
std::make_unique<wust_vl::common::concurrency::AdaptiveResourcePool<Infer>>(
pool_params
);
}
~Impl() {
if (infer_pool_) {
infer_pool_.reset();
}
trt_net_.reset();
armor_detect_common_.reset();
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
struct Tag {};
void processCallback(
const CommonFrame& frame,
Infer* infer,
const std::optional<ArmorNumber>& target_number
) const {
std::vector<ArmorObject> armors;
const auto t0 = wust_vl::common::utils::time_utils::now();
Eigen::Matrix3f transform_matrix;
std::vector<ArmorObject> objs_result;
void* input_tensor_ptr;
const cv::Mat roi = frame.img_frame.src_img(frame.expanded);
cv::Mat resized_img;
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
const bool swap_rb = armor_infer_->inputRGB()
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
if (infer->cuda_infer && use_cuda_pre_) {
input_tensor_ptr =
infer->cuda_infer
->preprocess_pitched( //支持不连续内存,无需拷贝后输入可直接传roi的指针
roi.data,
roi.cols,
roi.rows,
roi.step,
scale,
swap_rb,
transform_matrix,
trt_net_->getStream()
);
resized_img = infer->cuda_infer->tensorToMat( //nchw_float_to_hwc_uchar
static_cast<float*>(input_tensor_ptr),
armor_infer_->inputW(),
armor_infer_->inputH(),
scale,
trt_net_->getStream()
);
} else {
resized_img = utils::letterbox(
roi,
transform_matrix,
armor_infer_->inputW(),
armor_infer_->inputH()
);
const cv::Mat blob = cv::dnn::blobFromImage(
resized_img,
scale,
cv::Size(armor_infer_->inputW(), armor_infer_->inputH()),
cv::Scalar(0, 0, 0),
swap_rb
);
trt_net_->input2Device(blob.ptr<float>());
input_tensor_ptr = trt_net_->getInputTensorPtr();
}
const auto t1 = wust_vl::common::utils::time_utils::now();
if (infer->context && input_tensor_ptr) {
trt_net_->infer(input_tensor_ptr, infer->context.get());
}
const auto t2 = wust_vl::common::utils::time_utils::now();
const cv::Mat
output_mat(output_dims_.d[1], output_dims_.d[2], CV_32F, trt_net_->output2Host());
cudaStreamSynchronize(trt_net_->getStream());
objs_result = armor_infer_->postProcess(output_mat);
const auto t3 = wust_vl::common::utils::time_utils::now();
if (log_time_) {
WUST_INFO("TRT") << std::fixed << std::setprecision(3) << "pre "
<< wust_vl::common::utils::time_utils::durationMs(t0, t1) << " "
<< "infer "
<< wust_vl::common::utils::time_utils::durationMs(t1, t2) << " "
<< "post "
<< wust_vl::common::utils::time_utils::durationMs(t2, t3) << " "
<< "total "
<< wust_vl::common::utils::time_utils::durationMs(t0, t3);
}
infer_pool_->release(infer);
if (armor_detect_common_) {
armors = armor_detect_common_->detectNet(
resized_img,
objs_result,
transform_matrix,
frame.detect_color,
target_number
);
// Call callback function
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
} else {
for (auto obj: objs_result) {
auto detect_color = frame.detect_color;
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
continue;
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
continue;
}
obj.transform(transform_matrix);
armors.push_back(obj);
}
if (this->infer_callback_) {
this->infer_callback_(armors, frame);
return;
}
}
return;
}
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
if (infer_pool_) {
auto infer_ptr = infer_pool_->acquire();
if (infer_ptr != nullptr) {
frame.id = current_id_++;
this->processCallback(frame, infer_ptr, target_number);
}
}
}
private:
bool use_cuda_pre_;
bool log_time_;
nvinfer1::Dims input_dims_;
nvinfer1::Dims output_dims_;
DetectorCallback infer_callback_;
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
std::unique_ptr<wust_vl::common::concurrency::AdaptiveResourcePool<Infer>> infer_pool_;
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
int current_id_ = 0;
std::unique_ptr<wust_vl::ml_net::TensorRTNet> trt_net_;
};
ArmorDetectorTrt::ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common) {
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
}
ArmorDetectorTrt::~ArmorDetectorTrt() {
_impl.reset();
}
void ArmorDetectorTrt::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
void ArmorDetectorTrt::pushInput(
CommonFrame& frame,
const std::optional<ArmorNumber>& target_number
) {
_impl->pushInput(frame, target_number);
}
} // namespace auto_aim
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,40 @@
// Copyright 2025 Zikang Xie
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorDetectorTrt: public ArmorDetectorBase {
public:
using Ptr = std::unique_ptr<ArmorDetectorTrt>;
explicit ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common);
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
return std::make_unique<ArmorDetectorTrt>(config, use_armor_detect_common);
}
~ArmorDetectorTrt();
void
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,442 @@
#include "armor_omni.hpp"
#include "3rdparty/angles.h"
#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp"
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/auto_aim/auto_aim_fsm.hpp"
#include "tasks/auto_aim/type.hpp"
#include "wust_vl/common/concurrency/ThreadPool.h"
#include "wust_vl/common/utils/timer.hpp"
#include "wust_vl/video/camera.hpp"
// clang-format off
#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp"
// clang-format on
#include "tasks/auto_aim/armor_tracker/trackerv3.hpp"
#include "tasks/auto_aim/armor_where/armor_where.hpp"
namespace wust_vision::auto_aim {
struct ArmorOmni::Impl {
struct One {
using Ptr = std::shared_ptr<One>;
One(int id) {
self_id = id;
total_score = 0;
}
static Ptr create(int id) {
return std::make_shared<One>(id);
}
void load(
const YAML::Node& config,
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter
) noexcept {
auto yaw_in_big_yaw_deg = config["yaw_in_big_yaw_deg"].as<double>();
yaw_in_big_yaw = yaw_in_big_yaw_deg / 180.0 * M_PI;
camera = std::make_shared<wust_vl::video::Camera>();
camera->init(config);
std::string camera_info_path =
utils::expandEnv(config["camera_info_path"].as<std::string>());
YAML::Node config_camera_info = YAML::LoadFile(camera_info_path);
std::vector<double> camera_k =
config_camera_info["camera_matrix"]["data"].as<std::vector<double>>();
std::vector<double> camera_d =
config_camera_info["distortion_coefficients"]["data"].as<std::vector<double>>();
assert(camera_k.size() == 9);
assert(camera_d.size() == 5);
cv::Mat K(3, 3, CV_64F);
std::memcpy(K.data, camera_k.data(), 9 * sizeof(double));
cv::Mat D(1, 5, CV_64F);
std::memcpy(D.data, camera_d.data(), 5 * sizeof(double));
camera_info = std::make_pair(K.clone(), D.clone());
auto gobal_config = auto_aim_config_parameter->getConfig();
armor_where = ArmorWhere::create(gobal_config["armor_where"], camera_info);
tracker = Tracker::create(auto_aim_config_parameter);
}
void start() noexcept {
if (camera)
camera->start();
}
int self_id;
double total_score;
std::shared_ptr<wust_vl::video::Camera> camera;
ArmorWhere::Ptr armor_where;
Tracker::Ptr tracker;
std::pair<cv::Mat, cv::Mat> camera_info;
double yaw_in_big_yaw;
Target target;
};
struct Obj {
ArmorObject armor;
double score = 0;
One::Ptr one;
std::chrono::steady_clock::time_point timestamp;
Obj(const ArmorObject& a,
double s,
const One::Ptr& o,
std::chrono::steady_clock::time_point ts):
armor(a),
score(s),
one(o),
timestamp(ts) {
if (one)
one->total_score += score;
}
~Obj() {
if (one)
one->total_score -= score;
}
};
static constexpr const char* _ML_CONFIG = "config/omni/detect_ml.yaml";
static constexpr const char* _OPENCV_CONFIG = "config/omni/detect_opencv.yaml";
~Impl() {
run_flag_ = false;
}
Impl(bool detect_color_init, const Ctx& ctx) {
ctx_ = ctx;
detect_color_ = detect_color_init;
config_ = YAML::LoadFile(OMNI_CONFIG);
auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create();
auto_aim_config_parameter_->loadFromFile(OMNI_CONFIG);
auto cameras = config_["cameras"].as<std::vector<std::string>>();
for (size_t i = 0; i < cameras.size(); ++i) {
auto real_path = utils::expandEnv(cameras[i]);
One::Ptr one = One::create(i);
one->load(YAML::LoadFile(real_path), auto_aim_config_parameter_);
ones_.emplace_back(one);
}
auto_aim_config_parameter_->reloadFromOldPath();
fps_ = config_["fps"].as<int>(30);
active_time_ = config_["active_time"].as<double>(0.5);
max_infer_running_ = config_["max_infer_running"].as<int>(0);
min_score_ = config_["min_score"].as<double>();
const std::string armor_detect_backend =
config_["armor_detect_backend"].as<std::string>("");
armor_detector_ = DetectorFactory::createArmorDetector(
armor_detect_backend,
false,
_OPENCV_CONFIG,
_ML_CONFIG
);
armor_detector_->setCallback(std::bind(
&ArmorOmni::Impl::ArmorDetectCallback,
this,
std::placeholders::_1,
std::placeholders::_2
));
thread_pool_ =
std::make_unique<wust_vl::common::concurrency::ThreadPool>(max_infer_running_);
timer_ = std::make_unique<wust_vl::common::utils::Timer>("omni");
latency_averager_ = std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
}
void start() noexcept {
run_flag_ = true;
for (auto& one: ones_) {
one->start();
}
if (timer_) {
const auto timercallback =
std::bind(&ArmorOmni::Impl::timerCallback, this, std::placeholders::_1);
const double rate_hz = fps_;
timer_->start(rate_hz, timercallback);
}
}
int getOneId() const {
static int one_id = 0;
int id = one_id;
one_id = (one_id + 1) % ones_.size();
return id;
}
void timerCallback(double dt_ms) noexcept {
if (!run_flag_ || main_tracking_)
return;
int one_id = getOneId();
auto& one = ones_[one_id];
auto frame = one->camera->readImage();
if (frame.src_img.empty())
return;
CommonFrame common_frame;
common_frame.img_frame = frame;
common_frame.id = one_id;
common_frame.detect_color = detect_color_;
common_frame.expanded = cv::Rect(0, 0, frame.src_img.cols, frame.src_img.rows);
common_frame.offset = cv::Point2f(0, 0);
common_frame.any_ctx = one;
detect(common_frame);
}
void detect(CommonFrame& common_frame) {
if (infer_running_count_ >= max_infer_running_ || !thread_pool_ || !run_flag_) {
return;
}
infer_running_count_++;
if (common_frame.img_frame.src_img.empty()) {
infer_running_count_--;
return;
}
if (armor_detector_) {
armor_detector_->pushInput(common_frame, std::nullopt);
}
infer_running_count_--;
}
void setDetectColor(bool flag) noexcept {
detect_color_ = flag;
}
void updateMainTracking(bool flag) noexcept {
main_tracking_ = flag;
}
int getBestTarget() noexcept {
update();
return best_target_;
}
void
ArmorDetectCallback(const std::vector<ArmorObject>& objs, const CommonFrame& frame) noexcept {
auto one = std::any_cast<One::Ptr>(frame.any_ctx);
std::vector<ArmorObject> sorted_objs;
for (const auto& obj: objs) {
if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) {
continue;
}
sorted_objs.push_back(obj);
std::lock_guard<std::mutex> lock(active_results_mutex_);
active_results_.emplace_back(obj, obj.confidence, one, frame.img_frame.timestamp);
}
update();
Armors armors;
armors.timestamp = frame.img_frame.timestamp;
Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity();
auto& car_b = ctx_.car_motion_buffer;
auto& big_yaw_b = ctx_.big_yaw_motion_buffer;
if (car_b && big_yaw_b) {
const auto t_query = armors.timestamp;
auto apply_motion = [&](const auto& att, const auto& att2) {
R_gimbal2odom =
Eigen::AngleAxisd(
angles::normalize_angle(att2.data.big_yaw + one->yaw_in_big_yaw),
Eigen::Vector3d::UnitZ()
)
* Eigen::AngleAxisd(0.0, Eigen::Vector3d::UnitY())
* Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX());
};
auto car_past_att = car_b->get_interpolated(t_query);
auto big_yaw_past_att = big_yaw_b->get_interpolated(t_query);
if (car_past_att && big_yaw_past_att) {
apply_motion(*car_past_att, *big_yaw_past_att);
} else {
auto last_att = car_b->get_last();
auto last_big_yaw = big_yaw_b->get_last();
if (last_att && last_big_yaw) {
apply_motion(*last_att, *last_big_yaw);
}
}
}
Eigen::Matrix3d R_camera2gimbal;
R_camera2gimbal << 0.0, 0.0, 1.0, -1.0, -0.0, 0.0, 0.0, -1.0, 0.0;
Eigen::Matrix4d T_camera_to_odom = utils::computeCameraToOdomTransform(
R_gimbal2odom,
R_camera2gimbal,
Eigen::Vector3d::Zero()
);
armors.armors = one->armor_where->where(sorted_objs, T_camera_to_odom);
for (auto& armor: armors.armors) {
armor.timestamp = armors.timestamp;
}
auto& target = one->target;
target = one->tracker->track(armors);
const auto now = std::chrono::steady_clock::now();
const auto latency_ms =
wust_vl::common::utils::time_utils::durationMs(frame.img_frame.timestamp, now);
latency_averager_->add(latency_ms);
latency_ms_ = latency_averager_->average();
detect_count_++;
printStats();
}
void update() noexcept {
std::lock_guard<std::mutex> lock(active_results_mutex_);
while (!active_results_.empty()) {
auto& obj = active_results_.front();
if (std::abs(wust_vl::common::utils::time_utils::durationSec(
obj.timestamp,
wust_vl::common::utils::time_utils::now()
))
> active_time_)
{
active_results_.pop_front();
} else {
break;
}
}
if (active_results_.empty()) {
best_target_ = -1;
return;
}
double max_score = min_score_;
best_target_ = -1;
for (size_t i = 0; i < ones_.size(); ++i) {
if (ones_[i]->total_score > max_score) {
max_score = ones_[i]->total_score;
best_target_ = ones_[i]->self_id;
}
}
}
GimbalCmd solve(double bullet_speed) {
GimbalCmd gimbal_cmd;
std::optional<Target> target;
int best_target = getBestTarget();
if (best_target < 0) {
target = std::nullopt;
} else {
target = ones_[best_target]->target;
}
auto& very_aimer = ctx_.very_aimer;
if (!very_aimer) {
return gimbal_cmd;
}
if (target.has_value() && target->checkTargetAppear()) {
try {
gimbal_cmd = very_aimer->veryAim(
target.value(),
bullet_speed,
AutoAimFsm::AIM_WHOLE_CAR_CENTER
);
gimbal_cmd.enable_pitch_diff = 0.0;
gimbal_cmd.enable_yaw_diff = 0.0;
gimbal_cmd.fire_advice = false;
} catch (...) {
WUST_ERROR("omni") << "VeryAim error";
}
} else {
gimbal_cmd.appear = false;
}
return gimbal_cmd;
}
void printStats() {
utils::XSecOnce(
[&] {
WUST_INFO("armor_omni") << "det: " << detect_count_ << " best: " << best_target_
<< " lat: " << latency_ms_;
detect_count_ = 0;
},
1.0
);
}
int fps_;
int max_infer_running_ = 0;
std::atomic<int> infer_running_count_ { 0 };
bool detect_color_;
bool main_tracking_ = false;
bool run_flag_ = false;
double active_time_ = 0;
std::deque<Obj> active_results_;
mutable std::mutex active_results_mutex_;
std::vector<One::Ptr> ones_;
YAML::Node config_;
std::unique_ptr<wust_vl::common::concurrency::ThreadPool> thread_pool_;
std::unique_ptr<wust_vl::common::utils::Timer> timer_;
ArmorDetectorBase::Ptr armor_detector_;
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_;
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
int best_target_ = -1;
int detect_count_ = 0;
double latency_ms_;
double min_score_ = 10.0;
Ctx ctx_;
};
ArmorOmni::ArmorOmni(bool detect_color_init, const Ctx& ctx):
_impl(std::make_unique<Impl>(detect_color_init, ctx)) {}
ArmorOmni::~ArmorOmni() {
_impl.reset();
}
void ArmorOmni::start() {
_impl->start();
}
void ArmorOmni::setDetectColor(bool flag) {
_impl->setDetectColor(flag);
}
void ArmorOmni::updateMainTracking(bool flag) {
_impl->updateMainTracking(flag);
}
int ArmorOmni::getBestTarget() {
return _impl->getBestTarget();
}
GimbalCmd ArmorOmni::solve(double bullet_speed) {
return _impl->solve(bullet_speed);
}
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,36 @@
#pragma once
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
#include "tasks/type_common.hpp"
#include <memory>
#include <wust_vl/common/utils/motion_buffer.hpp>
#include <yaml-cpp/node/node.h>
namespace wust_vision {
namespace auto_aim {
class ArmorOmni {
public:
struct Ctx {
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<CarMotion, 1024>>
car_motion_buffer;
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<BigYaw, 1024>>
big_yaw_motion_buffer;
VeryAimer::Ptr very_aimer;
};
static constexpr const char* OMNI_CONFIG = "config/omni/omni.yaml";
using Ptr = std::unique_ptr<ArmorOmni>;
ArmorOmni(bool detect_color_init, const Ctx& ctx);
static Ptr create(bool detect_color_init, const Ctx& ctx) {
return std::make_unique<ArmorOmni>(detect_color_init, ctx);
}
~ArmorOmni();
void start();
void setDetectColor(bool flag);
void updateMainTracking(bool flag);
int getBestTarget();
GimbalCmd solve(double bullet_speed);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,286 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// ceres
#include <ceres/ceres.h>
#include <opencv2/calib3d.hpp>
#include <opencv2/core/mat.hpp>
// project
#include "KalmanHyLib/kalman_hybird_lib.hpp"
namespace crazy {
enum class MotionModel {
CONSTANT_VELOCITY = 0, // Constant velocity
CONSTANT_ROTATION = 1, // Constant rotation velocity
CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity
};
constexpr int X_N = 11, Z_N = 8;
using VecZ = Eigen::Matrix<double, Z_N, 1>;
using VecX = Eigen::Matrix<double, X_N, 1>;
enum class Mean : uint8_t {
PLBX = 0,
PLBY = 1,
PLTX = 2,
PLTY = 3,
PRTX = 4,
PRTY = 5,
PRBX = 6,
PRBY = 7,
// IMUY = 4,
// IMUP = 5,
// IMUR = 6,
Z_N = 8
};
enum class State : uint8_t {
CX = 0,
VCX = 1,
CY = 2,
VCY = 3,
CZ = 4,
VCZ = 5,
YAW = 6,
VYAW = 7,
R = 8,
L = 9,
H = 10,
outpost01DZ = 9,
outpost02DZ = 10,
X_N = 11
};
struct Predict {
Predict() = default;
explicit Predict(
double dt,
MotionModel model = MotionModel::CONSTANT_VEL_ROT,
double vrx = 0.0,
double vry = 0.0,
double vrz = 0.0
):
dt(dt),
model(model),
vrx(vrx),
vry(vry),
vrz(vrz) {}
template<typename T>
void operator()(const T x0[X_N], T x1[X_N]) const {
for (int i = 0; i < X_N; i++) {
x1[i] = x0[i];
}
// v_xyz
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) {
// linear velocity
x1[(int)State::CX] += x0[(int)State::VCX] * T(dt);
x1[(int)State::CY] += x0[(int)State::VCY] * T(dt);
x1[(int)State::CZ] += x0[(int)State::VCZ] * T(dt);
} else {
// no velocity
x1[(int)State::VCX] *= T(0.);
x1[(int)State::VCY] *= T(0.);
x1[(int)State::VCZ] *= T(0.);
}
x1[(int)State::CX] -= T(vrx) * T(dt);
x1[(int)State::CY] -= T(vry) * T(dt);
x1[(int)State::CZ] -= T(vrz) * T(dt);
x1[(int)State::VCZ] *= T(0.);
// v_yaw
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) {
// angular velocity
x1[(int)State::YAW] += x0[(int)State::VYAW] * T(dt);
} else {
// no rotation
x1[(int)State::VYAW] *= T(0.);
}
}
double dt;
MotionModel model;
double vrx, vry, vrz;
};
constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135
constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0;
constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180;
template<typename T>
T normalize_angle_t(T angle) {
T two_pi = T(2.0 * M_PI);
return angle - two_pi * floor((angle + T(M_PI)) / two_pi);
}
template<typename T>
Eigen::Quaternion<T>
eulerToQuat(const Eigen::Vector<T, 3>& euler, int axis0, int axis1, int axis2, bool extrinsic) {
T rz = euler[0];
T ry = euler[1];
T rx = euler[2];
Eigen::Quaternion<T> qx(Eigen::AngleAxis<T>(rx, Eigen::Vector3<T>::UnitX()));
Eigen::Quaternion<T> qy(Eigen::AngleAxis<T>(ry, Eigen::Vector3<T>::UnitY()));
Eigen::Quaternion<T> qz(Eigen::AngleAxis<T>(rz, Eigen::Vector3<T>::UnitZ()));
if (!extrinsic)
std::swap(axis0, axis2);
Eigen::Quaternion<T> q;
if (axis0 == 0 && axis1 == 1 && axis2 == 2)
q = qx * qy * qz;
else if (axis0 == 0 && axis1 == 2 && axis2 == 1)
q = qx * qz * qy;
else if (axis0 == 1 && axis1 == 0 && axis2 == 2)
q = qy * qx * qz;
else if (axis0 == 1 && axis1 == 2 && axis2 == 0)
q = qy * qz * qx;
else if (axis0 == 2 && axis1 == 0 && axis2 == 1)
q = qz * qx * qy;
else if (axis0 == 2 && axis1 == 1 && axis2 == 0)
q = qz * qy * qx;
else
throw std::invalid_argument("Unsupported axis order");
return q;
}
template<typename PointType, typename U>
std::vector<PointType> buildObjectPoints(const U& w, const U& h) noexcept {
using T = U;
T half2 = T(2.0);
return { PointType(T(0), w / half2, -h / half2),
PointType(T(0), w / half2, h / half2),
PointType(T(0), -w / half2, h / half2),
PointType(T(0), -w / half2, -h / half2) };
}
struct Measure {
struct MeasureCtx {
int armor_num = 4;
int id = 0;
Eigen::Matrix4d T_odom_to_camera_d;
cv::Mat camera_intrinsic;
cv::Mat camera_distortion;
bool is_big;
} ctx;
Measure() = default;
explicit Measure(const MeasureCtx& c): ctx(c) {}
template<typename T>
void operator()(const T x[X_N], T z[Z_N]) const {
T id_t = T(ctx.id);
T num_t = T(ctx.armor_num);
T two = T(2.0);
T angle = normalize_angle_t(x[(int)State::YAW] + id_t * two * T(M_PI) / num_t);
bool outpost = (ctx.armor_num == 3);
bool use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3);
T r = use_l_h ? x[(int)State::R] + x[(int)State::L] : x[(int)State::R];
T armor_x = x[(int)State::CX] - ceres::cos(angle) * r;
T armor_y = x[(int)State::CY] - ceres::sin(angle) * r;
T armor_z = outpost ? getoutpost_armor_z(x)
: use_l_h ? x[(int)State::CZ] + x[(int)State::H]
: x[(int)State::CZ];
Eigen::Vector3<T> euler_odom;
euler_odom[0] = angle; //yaw
euler_odom[1] = outpost ? T(-FIFTTEN_DEGREE_RAD) : T(FIFTTEN_DEGREE_RAD); //pitch
euler_odom[2] = T(M_PI / 2.0); //roll
Eigen::Quaternion<T> q_odom = eulerToQuat(euler_odom, 2, 1, 0, true);
Eigen::Matrix4<T> T_odom_to_camera = ctx.T_odom_to_camera_d.cast<T>();
Eigen::Vector4<T> pos_odom4(armor_x, armor_y, armor_z, T(1.0));
Eigen::Vector4<T> pos_camera4 = T_odom_to_camera * pos_odom4;
Eigen::Vector3<T> pos_camera = pos_camera4.template head<3>();
Eigen::Matrix3<T> R_odom_to_camera = T_odom_to_camera.block(0, 0, 3, 3).template cast<T>();
Eigen::Matrix3<T> R_ori_odom = q_odom.normalized().toRotationMatrix();
Eigen::Matrix3<T> R_camera = R_odom_to_camera * R_ori_odom;
Eigen::Quaternion<T> q_camera(R_camera);
q_camera.normalize();
T w3 = ctx.is_big ? T(LARGE_ARMOR_WIDTH) : T(SMALL_ARMOR_WIDTH);
T h3 = ctx.is_big ? T(LARGE_ARMOR_HEIGHT) : T(SMALL_ARMOR_HEIGHT);
auto objPts = buildObjectPoints<Eigen::Matrix<T, 3, 1>>(w3, h3);
Eigen::Matrix3<T> R = q_camera.toRotationMatrix();
Eigen::Matrix<T, 3, 1> t = pos_camera;
std::vector<Eigen::Matrix<T, 3, 1>> Pc;
Pc.reserve(objPts.size());
for (const auto& p: objPts) {
Eigen::Matrix<T, 3, 1> v = p;
Pc.push_back(R * v + t);
}
const cv::Mat& K = ctx.camera_intrinsic;
T fx = T(K.at<double>(0, 0));
T fy = T(K.at<double>(1, 1));
T cx = T(K.at<double>(0, 2));
T cy = T(K.at<double>(1, 2));
std::array<T, 4> u, v;
for (int i = 0; i < 4; i++) {
T Xc = Pc[i][0];
T Yc = Pc[i][1];
T Zc = Pc[i][2];
u[i] = fx * (Xc / Zc) + cx;
v[i] = fy * (Yc / Zc) + cy;
}
z[0] = u[0];
z[1] = v[0];
z[2] = u[1];
z[3] = v[1];
z[4] = u[2];
z[5] = v[2];
z[6] = u[3];
z[7] = v[3];
}
template<typename T>
T getoutpost_armor_z(const T x[X_N]) const {
if (ctx.id == 0)
return x[(int)State::CZ];
if (ctx.id == 1)
return x[(int)State::CZ] + x[(int)State::outpost01DZ];
if (ctx.id == 2)
return x[(int)State::CZ] + x[(int)State::outpost02DZ];
return x[(int)State::CZ];
}
using VecX = Eigen::Matrix<double, X_N, 1>;
using VecZ = Eigen::Matrix<double, Z_N, 1>;
void h(const VecX& x, VecZ& z) const {
operator()(x.data(), z.data());
}
};
using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
} // namespace crazy

View File

@@ -0,0 +1,226 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// ceres
#include <algorithm>
#include <ceres/ceres.h>
// project
#include "KalmanHyLib/kalman_hybird_lib.hpp"
namespace ypdv2armor_motion_model {
enum class MotionModel {
CONSTANT_VELOCITY = 0, // Constant velocity
CONSTANT_ROTATION = 1, // Constant rotation velocity
CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity
};
// X_N: state dimension, Z_N: measurement dimension
constexpr int X_N = 11, Z_N = 4;
using VecZ = Eigen::Matrix<double, Z_N, 1>;
using VecX = Eigen::Matrix<double, X_N, 1>;
enum class MeasureID : uint8_t { YPD_Y = 0, YPD_P = 1, YPD_D = 2, ORI_YAW = 3, Z_N = 4 };
enum class StateID : uint8_t {
CX = 0,
VCX = 1,
CY = 2,
VCY = 3,
CZ = 4,
VCZ = 5,
YAW = 6,
VYAW = 7,
R = 8,
L = 9,
H = 10,
outpost01DZ = 9,
outpost02DZ = 10,
X_N = 11
};
struct State {
VecX x;
[[nodiscard]] inline double cx() const noexcept {
return x((int)StateID::CX);
}
[[nodiscard]] inline double cy() const noexcept {
return x((int)StateID::CY);
}
[[nodiscard]] inline double cz() const noexcept {
return x((int)StateID::CZ);
}
[[nodiscard]] inline Eigen::Vector3d pos() const noexcept {
return Eigen::Vector3d(cx(), cy(), cz());
}
[[nodiscard]] inline double vcx() const noexcept {
return x((int)StateID::VCX);
}
[[nodiscard]] inline double vcy() const noexcept {
return x((int)StateID::VCY);
}
[[nodiscard]] inline double vcz() const noexcept {
return x((int)StateID::VCZ);
}
[[nodiscard]] inline Eigen::Vector3d vel() const noexcept {
return Eigen::Vector3d(vcx(), vcy(), vcz());
}
[[nodiscard]] inline double vyaw() const noexcept {
return x((int)StateID::VYAW);
}
[[nodiscard]] inline double yaw() const noexcept {
return x((int)StateID::YAW);
}
[[nodiscard]] inline double r() const noexcept {
return x((int)StateID::R);
}
[[nodiscard]] inline double l() const noexcept {
return x((int)StateID::L);
}
[[nodiscard]] inline double h() const noexcept {
return x((int)StateID::H);
}
[[nodiscard]] inline double outpost01DZ() const noexcept {
return x((int)StateID::outpost01DZ);
}
[[nodiscard]] inline double outpost02DZ() const noexcept {
return x((int)StateID::outpost02DZ);
}
};
struct Predict {
Predict() = default;
explicit Predict(double dt, MotionModel model = MotionModel::CONSTANT_VEL_ROT):
dt(dt),
model(model) {}
template<typename T>
void operator()(const T x0[X_N], T x1[X_N]) const {
for (int i = 0; i < X_N; i++) {
x1[i] = x0[i];
}
// v_xyz
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) {
// linear velocity
x1[(int)StateID::CX] += x0[(int)StateID::VCX] * T(dt);
x1[(int)StateID::CY] += x0[(int)StateID::VCY] * T(dt);
x1[(int)StateID::CZ] += x0[(int)StateID::VCZ] * T(dt);
} else {
// no velocity
x1[(int)StateID::VCX] *= T(0.);
x1[(int)StateID::VCY] *= T(0.);
x1[(int)StateID::VCZ] *= T(0.);
}
// v_yaw
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) {
// angular velocity
x1[(int)StateID::YAW] += x0[(int)StateID::VYAW] * T(dt);
} else {
// no rotation
x1[(int)StateID::VYAW] *= T(0.);
}
clampState(x1);
}
template<typename T>
void clampState(T x1[X_N]) const {
auto& r = x1[(int)StateID::R];
auto& l = x1[(int)StateID::L];
r = std::clamp(r, T(0.1), T(0.5));
if (r < T(0.1) || r > T(0.5)) {
r = T(0.25);
l = T(0);
}
T sum = r + l;
if (sum < T(0.1) || sum > T(0.5)) {
r = T(0.25);
l = T(0);
}
auto& h = x1[(int)StateID::H];
h = std::clamp(h, T(-0.5), T(0.5));
}
void f(const VecX& x0, VecX& x1) const {
assert(x0.size() == X_N);
assert(x1.size() == X_N);
operator()(x0.data(), x1.data());
}
double dt;
MotionModel model;
};
template<typename T>
T normalize_angle_t(T angle) {
T two_pi = T(2.0 * M_PI);
return angle - two_pi * floor((angle + T(M_PI)) / two_pi);
}
struct Measure {
struct MeasureCtx {
MeasureCtx() = default;
MeasureCtx(int id, int armor_num): armor_num(armor_num), id(id) {}
int armor_num = 4;
int id = 0;
} ctx;
Measure() = default;
explicit Measure(MeasureCtx c): ctx(c) {}
template<typename T>
void operator()(const T x[X_N], T z[Z_N]) const {
// Compute armor position
auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x);
T xy_dist = ceres::sqrt(armor_x * armor_x + armor_y * armor_y);
T dist = ceres::sqrt(xy_dist * xy_dist + armor_z * armor_z);
// Observation model
z[(int)MeasureID::YPD_Y] = ceres::atan2(armor_y, armor_x); // yaw
z[(int)MeasureID::YPD_P] = ceres::atan2(armor_z, xy_dist); // pitch
z[(int)MeasureID::YPD_D] = dist; // distance
z[(int)MeasureID::ORI_YAW] = angle; // orientation_yaw
}
template<typename T>
T get_angle(const T x[X_N]) const {
return normalize_angle_t(x[(int)StateID::YAW] + ctx.id * 2 * M_PI / ctx.armor_num);
}
template<typename T>
std::tuple<T, T, T, T> h_armor_xyza(const T x[X_N]) const {
T angle = get_angle(x);
auto outpost = ctx.armor_num == 3;
auto use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3);
T r = (use_l_h) ? x[(int)StateID::R] + x[(int)StateID::L] : x[(int)StateID::R];
T armor_x = x[(int)StateID::CX] - ceres::cos(angle) * r;
T armor_y = x[(int)StateID::CY] - ceres::sin(angle) * r;
T armor_z = (outpost) ? getoutpost_armor_z(x)
: (use_l_h) ? x[(int)StateID::CZ] + x[(int)StateID::H]
: x[(int)StateID::CZ];
return { armor_x, armor_y, armor_z, angle };
}
Eigen::Vector4d h_armor_xyza(const VecX& x) const {
assert(x.size() == X_N);
auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x.data());
return { armor_x, armor_y, armor_z, angle };
}
template<typename T>
T getoutpost_armor_z(const T x[X_N]) const {
return (ctx.id == 0) ? x[(int)StateID::CZ]
: (ctx.id == 1) ? x[(int)StateID::CZ] + x[(int)StateID::outpost01DZ]
: (ctx.id == 2) ? x[(int)StateID::CZ] + x[(int)StateID::outpost02DZ]
: x[(int)StateID::CZ];
}
void h(const VecX& x, VecZ& z) const {
assert(x.size() == X_N);
assert(z.size() == Z_N);
operator()(x.data(), z.data());
}
};
using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
} // namespace ypdv2armor_motion_model

View File

@@ -0,0 +1,379 @@
#include "target.hpp"
namespace wust_vision {
namespace auto_aim {
Target::Target() {
target_state_.x = Eigen::VectorXd::Zero(MModel::X_N);
}
Target::Target(const Armor& a, TargetConfig::Ptr target_config) {
Eigen::DiagonalMatrix<double, ypdv2armor_motion_model::X_N> p0;
if (a.number == ArmorNumber::OUTPOST) {
p0.diagonal() << 1, 64, 1, 64, 1, 81, 0.4, 100, 1e-4, 0.1, 0.1;
armor_num_ = 3;
radius_pre_ = 0.2765;
} else if (a.number == ArmorNumber::BASE) {
p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1e-4, 0, 0;
armor_num_ = 3;
radius_pre_ = 0.3205;
} else {
p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1, 1, 1;
armor_num_ = 4;
radius_pre_ = 0.2;
}
target_config_ = target_config;
const auto yfv2 = MModel::Predict(0.005);
ctx_.armor_num = armor_num_;
ctx_.id = 0;
const auto yhv2 = MModel::Measure(ctx_);
const auto yu_qv2 = [this]() {
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
return q;
};
const auto yu_rv2 = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
return r;
};
esekf_ypd_ = MModel::RobotStateESEKF(yfv2, yhv2, yu_qv2, yu_rv2, p0);
esekf_ypd_.setResidualFunc([this](
const Eigen::Matrix<double, MModel::Z_N, 1>& z_pred,
const Eigen::Matrix<double, MModel::Z_N, 1>& z
) {
Eigen::Matrix<double, MModel::Z_N, 1> r = z - z_pred;
r[0] = angles::shortest_angular_distance(
z_pred[(int)MModel::MeasureID::YPD_Y],
z[(int)MModel::MeasureID::YPD_Y]
); // yaw
r[3] = angles::shortest_angular_distance(
z_pred[(int)MModel::MeasureID::ORI_YAW],
z[(int)MModel::MeasureID::ORI_YAW]
); // ori_yaw
return r;
});
esekf_ypd_.setIterationNum(target_config_->esekf_iter_num_param.get());
esekf_ypd_.setInjectFunc([this](
const Eigen::Matrix<double, MModel::X_N, 1>& delta,
Eigen::Matrix<double, MModel::X_N, 1>& nominal
) {
for (int i = 0; i < MModel::X_N; i++) {
if (i == (int)MModel::StateID::YAW)
continue;
nominal[i] += delta[i];
}
nominal[(int)MModel::StateID::YAW] = angles::normalize_angle(
nominal[(int)MModel::StateID::YAW] + delta[(int)MModel::StateID::YAW]
);
});
const double xa = a.target_pos.x();
const double ya = a.target_pos.y();
const double za = a.target_pos.z();
const double yaw = utils::orientationToYaw<Target>(a.target_ori);
target_state_.x = Eigen::VectorXd::Zero(MModel::X_N);
const double r = radius_pre_;
const double xc = xa + r * cos(yaw);
const double yc = ya + r * sin(yaw);
const double zc = za;
target_state_.x << xc, 0, yc, 0, zc, 0, yaw, 0, r, 0, 0;
esekf_ypd_.setState(target_state_.x);
tracked_id_ = a.number;
type_ = a.type;
last_t_ = a.timestamp;
timestamp_ = a.timestamp;
is_inited = true;
}
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
Target::computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z
) const noexcept {
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
const double delta_angle = angles::normalize_angle(z[3] - z[0]);
const double abs_delta = std::abs(delta_angle);
// sin插值函数小值慢、大值快
const auto sinInterp = [](double x, double x0, double x1, double y0, double y1) -> double {
double t = (x - x0) / (x1 - x0);
if (t < 0)
t = 0;
if (t > 1)
t = 1;
double s = std::sin(t * M_PI / 2.0);
return y0 + s * (y1 - y0);
};
// clang-format off
r <<target_config_->yp_r_param.get(), 0, 0, 0,
0, target_config_->yp_r_param.get() , 0, 0,
0, 0, sinInterp(abs_delta, 0.0, M_PI/2.0, target_config_->dis_r_front_param.get(), target_config_->dis_r_side_param.get())+z[2]*z[2]*target_config_->dis2_r_ratio_param.get(), 0,
0, 0, 0,log(std::abs(z[2]) + 1) *target_config_->yaw_r_log_ratio_param.get() + sinInterp(M_PI/2.0-abs_delta, 0.0, M_PI/2.0, target_config_->yaw_r_base_side_param.get(), target_config_->yaw_r_base_front_param.get());
// clang-format on
return r;
}
Eigen::Matrix<double, MModel::X_N, MModel::X_N> Target::computeProcessNoise(double dt
) const noexcept {
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
Eigen::Vector3d q_xyz;
double q_yaw;
double q_l, q_h;
if (tracked_id_ == ArmorNumber::OUTPOST) {
q_xyz = target_config_->qxyz_output; // 前哨站加速度方差
q_yaw = target_config_->qyaw_output_param.get(); // 前哨站角加速度方差
q_l = target_config_->q_outpost_dz_param.get();
q_h = target_config_->q_outpost_dz_param.get();
} else {
q_xyz = target_config_->qxyz_common; // 加速度方差
q_yaw = target_config_->qyaw_common_param.get(); // 角加速度方差
q_l = target_config_->q_l_param.get();
q_h = target_config_->q_h_param.get();
}
const double t = dt;
const double q_x_x = pow(t, 4) / 4 * q_xyz.x(), q_x_vx = pow(t, 3) / 2 * q_xyz.x(),
q_vx_vx = pow(t, 2) * q_xyz.x();
const double q_y_y = pow(t, 4) / 4 * q_xyz.y(), q_y_vy = pow(t, 3) / 2 * q_xyz.y(),
q_vy_vy = pow(t, 2) * q_xyz.y();
const double q_z_z = pow(t, 4) / 4 * q_xyz.z(), q_z_vz = pow(t, 3) / 2 * q_xyz.z(),
q_vz_vz = pow(t, 2) * q_xyz.z();
const double q_yaw_yaw = pow(t, 4) / 4 * q_yaw, q_yaw_vyaw = pow(t, 3) / 2 * q_yaw,
q_vyaw_vyaw = pow(t, 2) * q_yaw;
const double q_r = target_config_->q_r_param.get();
// clang-format off
// xc v_xc yc v_yc zc v_zc yaw v_yaw r l h
q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0, 0,
q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0, 0,
0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0, 0,
0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0, 0,
0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, q_r, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, q_l, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, q_h;
// clang-format on
return q;
}
MModel::Predict Target::getPredictFunc(double dt) const noexcept {
MModel::Predict predict_func;
if (tracked_id_ == ArmorNumber::OUTPOST) {
predict_func = MModel::Predict {
dt,
MModel::MotionModel::CONSTANT_ROTATION,
};
} else {
predict_func = MModel::Predict {
dt,
MModel::MotionModel::CONSTANT_VEL_ROT,
};
}
return predict_func;
}
void Target::predict(std::chrono::steady_clock::time_point t) noexcept {
const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
predict(dt);
last_t_ = t;
}
void Target::predict(double dt) noexcept {
MModel::Predict predict_func = getPredictFunc(dt);
esekf_ypd_.setPredictFunc(predict_func);
const auto yu_qv2 = [dt, this]() { return computeProcessNoise(dt); };
esekf_ypd_.setUpdateQ(yu_qv2);
target_state_.x = esekf_ypd_.predict();
if (target_state_.pos().norm() < 0.5) {
is_tracking = false;
}
}
void Target::predictSimple(std::chrono::steady_clock::time_point t) noexcept {
const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
predictSimple(dt);
last_t_ = t;
}
void Target::predictSimple(double dt) noexcept {
MModel::Predict predict_func = getPredictFunc(dt);
predict_func.f(target_state_.x, target_state_.x);
if (target_state_.pos().norm() < 0.5) {
is_tracking = false;
}
}
bool Target::update(const std::pair<int, Armor>& a) noexcept {
const auto armor = a.second;
const auto id = a.first;
const auto yu_rv2 = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
return this->computeMeasurementCovariance(z);
};
esekf_ypd_.setUpdateR(yu_rv2);
measurement_ = getMeasure(armor);
if (id != 0)
jumped = true;
ctx_.id = id;
esekf_ypd_.setMeasureFunc(MModel::Measure { ctx_ });
target_state_.x = esekf_ypd_.update(measurement_);
timestamp_ = armor.timestamp;
last_t_ = timestamp_;
return true;
}
cv::Rect Target::expanded(
Eigen::Matrix4d T_camera_to_odom,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const cv::Size& image_size
) const noexcept {
const double dt = wust_vl::common::utils::time_utils::durationSec(
timestamp_,
wust_vl::common::utils::time_utils::now()
);
if (!is_inited || dt > target_config_->lost_time_thres_param.get()) {
return cv::Rect(0, 0, 0, 0);
}
const float car_box_half =
std::max(target_state_.r(), target_state_.r() + target_state_.l()) + 0.15;
static std::vector<cv::Point3f> CAR_BOX;
CAR_BOX = { { 0, car_box_half, -car_box_half },
{ 0, -car_box_half, -car_box_half },
{ 0, -car_box_half, car_box_half },
{ 0, car_box_half, car_box_half } };
const Eigen::Matrix4d T_odom_to_camera = T_camera_to_odom.inverse();
const Eigen::Vector4d
pos_odom(target_state_.cx(), target_state_.cy(), target_state_.cz(), 1.0);
const Eigen::Vector4d pos_cam = T_odom_to_camera * pos_odom;
if (pos_cam.z() <= 0.2) {
return cv::Rect(0, 0, 0, 0);
}
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << pos_cam.x(), pos_cam.y(), pos_cam.z());
Eigen::Vector3d euler;
euler.z() = M_PI / 2.0;
euler.y() = 0;
euler.x() = std::atan2(pos_odom.y(), pos_odom.x());
const Eigen::Quaterniond ori = utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
const auto target_ori = utils::transformOrientation(ori, T_odom_to_camera);
const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix();
const cv::Mat rot_mat =
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
tf_rot(0, 1),
tf_rot(0, 2),
tf_rot(1, 0),
tf_rot(1, 1),
tf_rot(1, 2),
tf_rot(2, 0),
tf_rot(2, 1),
tf_rot(2, 2));
cv::Mat rvec;
cv::Rodrigues(rot_mat, rvec);
std::vector<cv::Point2f> pts_2d;
cv::projectPoints(CAR_BOX, rvec, tvec, camera_intrinsic, camera_distortion, pts_2d);
const cv::Rect rect = cv::boundingRect(pts_2d);
const cv::Rect img_rect(0, 0, image_size.width, image_size.height);
if ((rect & img_rect).area() <= 0) {
return cv::Rect(0, 0, 0, 0);
}
const int base_side = std::max(rect.width, rect.height);
const int max_side = std::max(image_size.width, image_size.height);
const double lost_dt = target_config_->lost_time_thres_param.get();
const double dt_clamped = std::max(0.0, std::min(dt, lost_dt));
int side = static_cast<int>(base_side + (max_side - base_side) * (dt_clamped / lost_dt));
if (dt >= lost_dt) {
side = max_side;
}
const int cx = rect.x + rect.width / 2;
const int cy = rect.y + rect.height / 2;
cv::Rect square(cx - side / 2, cy - side / 2, side, side);
square &= img_rect;
return square;
}
std::vector<std::pair<int, Armor>> Target::match(const std::vector<Armor>& armors) noexcept {
std::vector<std::pair<int, Armor>> result;
const int n_obs = static_cast<int>(armors.size());
const int armors_num = armor_num_;
const double GATE = target_config_->match_gate_param.get();
const double max_cost = 1e9;
std::vector<std::vector<double>> cost(n_obs, std::vector<double>(armors_num, max_cost + 1));
std::vector<MModel::VecZ> meas_list(n_obs);
for (int j = 0; j < n_obs; ++j) {
meas_list[j] = getMeasure(armors[j]);
}
for (int j = 0; j < n_obs; ++j) {
for (int id = 0; id < armors_num; ++id) {
MModel::Measure::MeasureCtx tmp_ctx(id, armors_num);
MModel::Measure measure(tmp_ctx);
MModel::VecZ z_pred;
measure.h(target_state_.x, z_pred);
MModel::VecZ nu = meas_list[j] - z_pred;
nu[(int)MModel::MeasureID::YPD_Y] =
angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_Y]);
nu[(int)MModel::MeasureID::YPD_P] =
angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_P]);
nu[(int)MModel::MeasureID::ORI_YAW] =
angles::normalize_angle(nu[(int)MModel::MeasureID::ORI_YAW]);
auto R = computeMeasurementCovariance(z_pred);
double d2 = nu.transpose() * R.ldlt().solve(nu);
// 门控
if (std::isfinite(d2) && d2 < GATE) {
cost[j][id] = d2;
}
}
}
std::vector<bool> used_obs(n_obs, false);
std::vector<bool> used_id(armors_num, false);
while (true) {
double best = max_cost;
int best_j = -1;
int best_id = -1;
for (int j = 0; j < n_obs; ++j) {
if (used_obs[j])
continue;
for (int id = 0; id < armors_num; ++id) {
if (used_id[id])
continue;
if (cost[j][id] < best) {
best = cost[j][id];
best_j = j;
best_id = id;
}
}
}
if (best_j < 0 || best_id < 0) {
break;
}
used_obs[best_j] = true;
used_id[best_id] = true;
result.push_back(std::make_pair(best_id, armors[best_j]));
}
return result;
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,185 @@
#pragma once
#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp"
#include "tasks/auto_aim/type.hpp"
#include "wust_vl/common/utils/parameter.hpp"
#include <wust_vl/common/utils/timer.hpp>
namespace wust_vision {
namespace auto_aim {
namespace MModel = ypdv2armor_motion_model;
struct TargetConfig: wust_vl::common::utils::ParamGroup {
static constexpr const char* kKey = "armor_tracker";
const char* key() const override {
return kKey;
}
GEN_PARAM(int, esekf_iter_num);
GEN_PARAM(double, lost_time_thres);
GEN_PARAM(int, tracking_thres);
GEN_PARAM(double, max_yaw_diff_deg);
GEN_PARAM(double, max_dis_diff);
GEN_PARAM(double, match_gate);
GEN_PARAM(double, qyaw_common);
GEN_PARAM(double, qyaw_output);
GEN_PARAM(double, q_r);
GEN_PARAM(double, q_l);
GEN_PARAM(double, q_h);
GEN_PARAM(double, q_outpost_dz);
GEN_PARAM(double, yp_r);
GEN_PARAM(double, dis_r_front);
GEN_PARAM(double, dis_r_side);
GEN_PARAM(double, dis2_r_ratio);
GEN_PARAM(double, yaw_r_base_front);
GEN_PARAM(double, yaw_r_base_side);
GEN_PARAM(double, yaw_r_log_ratio);
GEN_PARAM(std::vector<double>, qxyz_common);
GEN_PARAM(std::vector<double>, qxyz_output);
Eigen::Vector3d qxyz_common = { 100, 100, 100 };
Eigen::Vector3d qxyz_output = { 10, 10, 10 };
using Ptr = std::shared_ptr<TargetConfig>;
TargetConfig() {
qxyz_output_param.onChange([this](auto o, auto n) {
qxyz_common = Eigen::Vector3d(n[0], n[1], n[2]);
});
qxyz_output_param.onChange([this](auto o, auto n) {
qxyz_output = Eigen::Vector3d(n[0], n[1], n[2]);
});
}
static Ptr create() {
return std::make_shared<TargetConfig>();
}
void loadSelf(const YAML::Node& node) override {
esekf_iter_num_param.load(node);
lost_time_thres_param.load(node);
tracking_thres_param.load(node);
max_yaw_diff_deg_param.load(node);
max_dis_diff_param.load(node);
match_gate_param.load(node);
qyaw_common_param.load(node);
qyaw_output_param.load(node);
qxyz_common_param.load(node);
qxyz_output_param.load(node);
q_r_param.load(node);
q_l_param.load(node);
q_h_param.load(node);
q_outpost_dz_param.load(node);
yp_r_param.load(node);
dis_r_front_param.load(node);
dis_r_side_param.load(node);
yaw_r_base_front_param.load(node);
yaw_r_base_side_param.load(node);
yaw_r_log_ratio_param.load(node);
}
};
class Target {
public:
Target();
Target(const Armor& armor, TargetConfig::Ptr target_config);
MModel::Measure::MeasureCtx ctx_;
ArmorNumber tracked_id_;
std::string type_;
MModel::VecZ measurement_ = Eigen::Matrix<double, MModel::Z_N, 1>::Zero();
MModel::State target_state_ = MModel::State();
double radius_pre_;
int armor_num_ = 4;
bool jumped = false;
bool is_inited = false;
bool is_tracking = false;
std::chrono::steady_clock::time_point last_t_;
std::chrono::steady_clock::time_point timestamp_;
MModel::RobotStateESEKF esekf_ypd_;
TargetConfig::Ptr target_config_;
[[nodiscard]] cv::Rect expanded(
Eigen::Matrix4d T_camera_to_odom,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const cv::Size& image_size
) const noexcept;
void predict(std::chrono::steady_clock::time_point t) noexcept;
void predict(double dt) noexcept;
void predictSimple(std::chrono::steady_clock::time_point t) noexcept;
void predictSimple(double dt) noexcept;
[[nodiscard]] MModel::Predict getPredictFunc(double dt) const noexcept;
bool update(const std::pair<int, Armor>& armor) noexcept;
[[nodiscard]] Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z) const noexcept;
[[nodiscard]] Eigen::Matrix<double, MModel::X_N, MModel::X_N> computeProcessNoise(double dt
) const noexcept;
[[nodiscard]] std::optional<ArmorNumber> getArmorNumber() const noexcept {
if (!checkTargetAppear()) {
return std::nullopt;
}
return tracked_id_;
}
[[nodiscard]] std::vector<double> getArmorYaws() const noexcept {
std::vector<double> yaw_list;
yaw_list.reserve(armor_num_);
for (int i = 0; i < armor_num_; i++) {
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
MModel::Measure measure(_ctx);
yaw_list.push_back(measure.get_angle(target_state_.x.data()));
}
return yaw_list;
}
[[nodiscard]] std::vector<Eigen::Vector3d> getArmorPositions() const noexcept {
std::vector<Eigen::Vector3d> armor_positions;
armor_positions.reserve(armor_num_);
for (int i = 0; i < armor_num_; i++) {
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
MModel::Measure measure(_ctx);
const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x);
armor_positions.push_back(xyza.head<3>());
}
return armor_positions;
}
[[nodiscard]] std::vector<Eigen::Vector4d> getArmorPosAndYaw() const noexcept {
std::vector<Eigen::Vector4d> pos_yaw;
pos_yaw.reserve(armor_num_);
for (int i = 0; i < armor_num_; ++i) {
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
MModel::Measure measure(_ctx);
const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x);
pos_yaw.push_back(xyza);
}
return pos_yaw;
}
[[nodiscard]] double getMeanZ() const noexcept {
double mean = 0;
for (const auto& p: getArmorPositions()) {
mean += p.z();
}
return mean / armor_num_;
}
[[nodiscard]] double getArmor2CenterXYDis(int id) const noexcept {
const auto use_l_h = (armor_num_ == 4) && (id == 1 || id == 3);
const auto r = (use_l_h) ? target_state_.r() + target_state_.l() : target_state_.r();
return r;
}
[[nodiscard]] std::vector<std::pair<int, Armor>> match(const std::vector<Armor>& armors
) noexcept;
[[nodiscard]] inline bool checkTargetAppear() const noexcept {
const bool appear = is_tracking
&& wust_vl::common::utils::time_utils::durationSec(
timestamp_,
wust_vl::common::utils::time_utils::now()
) < target_config_->lost_time_thres_param.get();
return appear;
}
[[nodiscard]] Eigen::Matrix<double, MModel::Z_N, 1> getMeasure(const Armor& a) noexcept {
const auto p = a.target_pos;
const double measured_yaw = utils::orientationToYaw<Target>(a.target_ori);
double ypd_y = std::atan2(p.y(), p.x());
static double last_ypd_y = 0;
ypd_y = last_ypd_y + angles::shortest_angular_distance(last_ypd_y, ypd_y);
last_ypd_y = ypd_y;
const double ypd_p = std::atan2(p.z(), std::sqrt(p.x() * p.x() + p.y() * p.y()));
const double ypd_d = std::sqrt(p.x() * p.x() + p.y() * p.y() + p.z() * p.z());
return Eigen::Vector4d(ypd_y, ypd_p, ypd_d, measured_yaw);
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,209 @@
#include "trackerv3.hpp"
namespace wust_vision {
namespace auto_aim {
struct Tracker::Impl {
public:
enum State {
LOST,
DETECTING,
TRACKING,
TEMP_LOST,
} tracker_state = LOST;
Impl(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
tracker_state = LOST;
target_config_ = TargetConfig::create();
auto_aim_config_parameter->registerGroup(*target_config_);
auto_aim_config_parameter->reloadFromOldPath();
target_ = Target();
}
Target track(const Armors& armors_msg) noexcept {
const double dt =
std::chrono::duration<double>(armors_msg.timestamp - last_time_).count();
last_time_ = armors_msg.timestamp;
lost_thres_ =
std::abs(static_cast<int>(target_config_->lost_time_thres_param.get() / dt));
Armors armors;
armors = armors_msg;
std::erase_if(armors.armors, [this](const Armor& a) {
double center_yaw =
std::atan2(target_.target_state_.cy(), target_.target_state_.cx());
bool state_check = tracker_state == TRACKING;
bool outpost_check = target_.tracked_id_ == ArmorNumber::OUTPOST && !a.is_ok;
bool pose_check =
(std::abs(angles::normalize_angle(
orientationToYaw(a.target_ori, center_yaw) - center_yaw
)) > (target_config_->max_yaw_diff_deg_param.get() * M_PI / 180.0)
|| std::abs((a.target_pos - target_.target_state_.pos()).norm())
> target_config_->max_dis_diff_param.get())
&& target_.is_inited
&& std::abs(wust_vl::common::utils::time_utils::durationMs(
target_.timestamp_,
wust_vl::common::utils::time_utils::now()
))
< 1000.0;
return state_check && outpost_check || pose_check;
});
std::sort(
armors.armors.begin(),
armors.armors.end(),
[](const Armor& a, const Armor& b) {
return a.distance_to_image_center < b.distance_to_image_center;
}
);
bool found;
if (tracker_state == LOST) {
found = initTarget(armors);
} else {
found = updateTarget(armors);
}
updateFsm(found);
return target_;
}
void updateFsm(bool found) noexcept {
switch (tracker_state) {
case DETECTING:
if (found) {
if (++detect_count_ > target_config_->tracking_thres_param.get()) {
detect_count_ = 0;
tracker_state = TRACKING;
}
} else {
detect_count_ = 0;
tracker_state = LOST;
}
break;
case TRACKING:
if (!found) {
tracker_state = TEMP_LOST;
lost_count_ = 1;
}
break;
case TEMP_LOST:
if (!found) {
if (++lost_count_ > lost_thres_) {
lost_count_ = 0;
tracker_state = LOST;
}
} else {
lost_count_ = 0;
tracker_state = TRACKING;
}
break;
default:
break;
}
target_.is_tracking = (tracker_state == TRACKING || tracker_state == TEMP_LOST);
if (found)
++found_count_;
}
bool initTarget(const Armors& armors) noexcept {
if (armors.armors.empty()) {
return false;
}
bool found = false;
Armor init_target;
Armors others = armors;
others.armors.clear();
for (auto& a: armors.armors) {
if (!a.is_none_purple && !found) {
init_target = a;
found = true;
continue;
}
others.armors.push_back(a);
}
if (!found) {
return false;
}
target_ = Target(init_target, target_config_);
// updateTarget(others);
tracker_state = DETECTING;
return true;
}
bool updateTarget(const Armors& armors) noexcept {
if (armors.armors.empty())
return false;
target_.predict(armors.timestamp);
std::vector<Armor> candidates;
candidates.reserve(armors.armors.size());
for (const auto& a: armors.armors) {
if (isSameTarget(a.number, target_.tracked_id_) && !a.is_none_purple) {
candidates.emplace_back(a);
}
}
if (candidates.empty())
return false;
int updated = 0;
const auto matches = target_.match(candidates);
for (const auto& m: matches) {
if (m.second.is_none_purple) {
if (++is_none_purple_count_ > 100)
continue;
} else {
is_none_purple_count_ = 0;
}
if (target_.update(m))
++updated;
}
return updated > 0;
}
int lost_thres_;
int detect_count_ = 0;
int lost_count_ = 0;
int is_none_purple_count_ = 0;
int found_count_ = 0;
Target target_;
std::chrono::steady_clock::time_point last_time_;
TargetConfig::Ptr target_config_;
double orientationToYaw(const Eigen::Quaterniond& q, double from) noexcept {
double roll, pitch, yaw;
Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false);
yaw = euler[0];
yaw = from + angles::shortest_angular_distance(from, yaw);
return yaw;
}
};
Tracker::Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
_impl = std::make_unique<Impl>(auto_aim_config_parameter);
}
Tracker::~Tracker() {
_impl.reset();
}
Target Tracker::track(const Armors& armors) noexcept {
return _impl->track(armors);
}
int Tracker::getFoundCount() const noexcept {
return _impl->found_count_;
}
void Tracker::setFoundCount(int count) noexcept {
_impl->found_count_ = count;
}
std::chrono::steady_clock::time_point Tracker::getLastTime() const noexcept {
return _impl->last_time_;
}
void Tracker::setLastTime(std::chrono::steady_clock::time_point t) noexcept {
_impl->last_time_ = t;
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,23 @@
#pragma once
#include "target.hpp"
namespace wust_vision {
namespace auto_aim {
class Tracker {
public:
using Ptr = std::unique_ptr<Tracker>;
Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter);
static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
return std::make_unique<Tracker>(auto_aim_config_parameter);
}
~Tracker();
[[nodiscard]] Target track(const Armors& armors) noexcept;
int getFoundCount() const noexcept;
void setFoundCount(int count) noexcept;
void setLastTime(std::chrono::steady_clock::time_point t) noexcept;
std::chrono::steady_clock::time_point getLastTime() const noexcept;
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,307 @@
// Created by Labor 2023.8.25
// Maintained by Chengfu Zou, Labor
// Copyright (C) FYT Vision Group. All rights reserved.
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "armor_where.hpp"
#include "wust_vl/algorithm/pnp_solver.hpp"
#include <opencv2/core/eigen.hpp>
namespace wust_vision {
namespace auto_aim {
struct ArmorWhere::Impl {
public:
Impl(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
camera_info_ = camera_info;
params_.load(config);
pnp_solver_ = std::make_unique<wust_vl::algorithm::PnPSolver>(cv::SOLVEPNP_IPPE);
pnp_solver_->setObjectPoints(
"small",
ArmorObject::buildObjectPoints<cv::Point3f>(SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT)
);
pnp_solver_->setObjectPoints(
"large",
ArmorObject::buildObjectPoints<cv::Point3f>(LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT)
);
}
struct Params {
enum class OptMode : int { GOLDEN = 0, CERES = 1, NONE = 2 } opt_mode;
OptMode fromString(const std::string& mode) {
if (mode == "golden" || mode == "GOLDEN") {
return OptMode::GOLDEN;
} else if (mode == "none" || mode == "NONE") {
return OptMode::NONE;
} else {
return OptMode::NONE;
}
}
int golden_search_side_deg = 60;
double distance_fix_a2 = 0;
void load(const YAML::Node& node) {
opt_mode = fromString(node["yaw_opt"]["mode"].as<std::string>());
golden_search_side_deg = node["yaw_opt"]["golden_search_side_deg"].as<int>();
distance_fix_a2 = node["distance_fix_a2"].as<double>();
}
} params_;
std::vector<Armor> where(
const std::vector<ArmorObject>& armors,
Eigen::Matrix4d T_camera_to_odom
) const noexcept {
std::vector<Armor> armors_msg;
const Eigen::Matrix3d R_imu_cam = T_camera_to_odom.block<3, 3>(0, 0);
auto makeArmor =
[&](const ArmorObject& obj, const Eigen::Vector3d& t, const Eigen::Matrix3d& R) {
Armor msg;
msg.type = (obj.number == ArmorNumber::NO1 || obj.number == ArmorNumber::BASE)
? "large"
: "small";
msg.number = obj.number;
Eigen::Quaterniond q(R);
Eigen::Quaterniond add_roll {
Eigen::AngleAxisd(M_PI / 2, Eigen::Vector3d::UnitX())
};
Eigen::Quaterniond new_q = q * add_roll;
auto [yaw, pitch, dist] = utils::xyz2ypd_rad(t.x(), t.y(), t.z());
dist += params_.distance_fix_a2 * dist * dist;
auto [x, y, z] = utils::ypd2xyz_rad(yaw, pitch, dist);
msg.pos = { x, y, z };
msg.ori = new_q;
auto_aim::transformArmorData(msg, T_camera_to_odom);
msg.distance_to_image_center =
pnp_solver_->calculateDistanceToCenter(obj.center, camera_info_.first);
msg.is_ok = obj.is_ok;
if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) {
msg.is_none_purple = true;
} else {
msg.is_none_purple = false;
}
return msg;
};
for (auto const& a: armors) {
cv::Mat rvec, tvec;
std::string type = (a.number == ArmorNumber::NO1 || a.number == ArmorNumber::BASE)
? "large"
: "small";
if (!pnp_solver_->solvePnP(
a.landmarks(),
rvec,
tvec,
type,
camera_info_.first,
camera_info_.second
))
{
WUST_WARN("PNP") << "PNP failed";
continue;
}
cv::Mat R_cv;
cv::Rodrigues(rvec, R_cv);
Eigen::Matrix3d R = utils::cvToEigen(R_cv);
Eigen::Vector3d t = utils::cvToEigen(tvec);
if (params_.opt_mode != Params::OptMode::NONE) {
Eigen::Matrix3d R0 = R;
R = solveBa_R(a, t, R0, R_imu_cam, type);
}
armors_msg.push_back(makeArmor(a, t, R));
}
return armors_msg;
}
std::vector<Eigen::Vector2d> reprojectionArmor(
double yaw,
const std::vector<cv::Point3f>& object_points,
const std::vector<cv::Point2f>& landmarks,
const Eigen::Matrix3d& Rci,
double pitch,
double roll,
const Eigen::Vector3d& t
) const noexcept {
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY());
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
const Eigen::Matrix3d R = Rci * (ay * ap * ar).toRotationMatrix();
cv::Mat rvec, R_cv;
cv::eigen2cv(R, R_cv);
cv::Rodrigues(R_cv, rvec);
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << t.x(), t.y(), t.z());
std::vector<cv::Point2f> pts_2d;
pts_2d.reserve(object_points.size());
cv::projectPoints(
object_points,
rvec,
tvec,
camera_info_.first,
camera_info_.second,
pts_2d
);
std::vector<Eigen::Vector2d> image_points;
image_points.reserve(pts_2d.size());
for (const auto& p: pts_2d) {
image_points.emplace_back(p.x, p.y);
}
return image_points;
}
double reprojectionErrorYaw(
double yaw,
const std::vector<cv::Point3f>& object_points,
const std::vector<cv::Point2f>& landmarks,
const std::vector<std::pair<int, int>>& sym_pairs,
const Eigen::Matrix3d& Rci,
double pitch,
double roll,
const Eigen::Vector3d& t
) const noexcept {
const auto image_points =
reprojectionArmor(yaw, object_points, landmarks, Rci, pitch, roll, t);
double cost = 0.0;
// for (size_t i = 0; i < image_points.size(); ++i) {
// Eigen::Vector2d obs(landmarks[i].x, landmarks[i].y);
// cost += (image_points[i] - obs).squaredNorm();
// }
for (auto& p: sym_pairs) {
const Eigen::Vector2d mid = 0.5 * (image_points[p.first] + image_points[p.second]);
const Eigen::Vector2d meas = 0.5
* (Eigen::Vector2d(landmarks[p.first].x, landmarks[p.first].y)
+ Eigen::Vector2d(landmarks[p.second].x, landmarks[p.second].y));
cost += (mid - meas).squaredNorm();
}
return cost;
}
double goldenYaw(
double init,
const std::vector<cv::Point3f>& obj,
const std::vector<cv::Point2f>& lm,
const std::vector<std::pair<int, int>>& sym_pairs,
const Eigen::Matrix3d& Rci,
double pitch,
double roll,
const Eigen::Vector3d& t
) const noexcept {
constexpr double phi = 1.618033988749894848; //(1.0 + std::sqrt(5.0)) * 0.5;
double l = init - params_.golden_search_side_deg * M_PI / 180.0;
double r = init + params_.golden_search_side_deg * M_PI / 180.0;
double y1 = r - (r - l) / phi;
double y2 = l + (r - l) / phi;
double f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t);
double f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t);
while (r - l > 0.0001) {
if (f1 < f2) {
r = y2;
y2 = y1;
f2 = f1;
y1 = r - (r - l) / phi;
f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t);
} else {
l = y1;
y1 = y2;
f1 = f2;
y2 = l + (r - l) / phi;
f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t);
}
}
return 0.5 * (l + r);
}
Eigen::Matrix3d solveBa_R(
const ArmorObject& armor,
const Eigen::Vector3d& t_camera_armor,
const Eigen::Matrix3d& R_camera_armor,
const Eigen::Matrix3d& R_imu_camera,
const std::string& type
) const noexcept {
const Eigen::Matrix3d R_imu_armor = R_imu_camera * R_camera_armor;
const Eigen::Matrix3d R_camera_imu = R_imu_camera.transpose();
//double roll = std::atan2(R_imu_armor(2, 1), R_imu_armor(2, 2));
const double roll = 0;
// initial yaw
const double yaw_init = std::atan2(-R_imu_armor(0, 1), R_imu_armor(1, 1));
const double armor_pitch =
(armor.number == ArmorNumber::OUTPOST) ? -FIFTTEN_DEGREE_RAD : FIFTTEN_DEGREE_RAD;
const Eigen::Vector2d armor_size = (type == "large")
? Eigen::Vector2d { LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT }
: Eigen::Vector2d { SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT };
const auto objPts =
ArmorObject::buildObjectPoints<cv::Point3f>(armor_size.x(), armor_size.y());
const auto& lm = armor.landmarks();
const auto& sym_pairs = ArmorObject::buildSymPairs<int>();
double yaw = yaw_init;
if (params_.opt_mode == Params::OptMode::GOLDEN) {
yaw = goldenYaw(
yaw_init,
objPts,
lm,
sym_pairs,
R_camera_imu,
armor_pitch,
roll,
t_camera_armor
);
}
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
const Eigen::AngleAxisd ap(armor_pitch, Eigen::Vector3d::UnitY());
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
const Eigen::Matrix3d R_result = R_camera_imu * (ay * ap * ar).toRotationMatrix();
return R_result;
}
private:
std::pair<cv::Mat, cv::Mat> camera_info_;
std::unique_ptr<wust_vl::algorithm::PnPSolver> pnp_solver_;
};
ArmorWhere::ArmorWhere(
const YAML::Node& config,
const std::pair<cv::Mat, cv::Mat>& camera_info
) {
_impl = std::make_unique<Impl>(config, camera_info);
}
ArmorWhere::~ArmorWhere() {
_impl.reset();
}
std::vector<Armor> ArmorWhere::where(
const std::vector<ArmorObject>& armors,
Eigen::Matrix4d T_camera_to_odom
) const noexcept {
return _impl->where(armors, T_camera_to_odom);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,26 @@
#pragma once
#include "tasks/auto_aim/type.hpp"
namespace wust_vision {
namespace auto_aim {
class ArmorWhere {
public:
using Ptr = std::unique_ptr<ArmorWhere>;
ArmorWhere(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info);
static Ptr
create(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
return std::make_unique<ArmorWhere>(config, camera_info);
}
~ArmorWhere();
std::vector<Armor> where(
const std::vector<ArmorObject>& armors,
Eigen::Matrix4d T_camera_to_odom
) const noexcept;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,404 @@
#include "auto_aim.hpp"
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp"
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/auto_aim/armor_tracker/trackerv3.hpp"
#include "tasks/auto_aim/armor_where/armor_where.hpp"
#include "tasks/auto_aim/debug.hpp"
#include "tasks/type_common.hpp"
#include "tasks/utils/config.hpp"
#include "wust_vl/common/concurrency/queues.hpp"
namespace wust_vision {
namespace auto_aim {
struct AutoAim::Impl {
~Impl() {
run_flag_ = false;
if (armor_detector_) {
armor_detector_.reset();
}
armor_queue_->stop();
if (processing_thread_) {
processing_thread_->stop();
wust_vl::common::concurrency::ThreadManager::instance().unregisterThread(
processing_thread_->getName()
);
}
}
Impl(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
) {
tf_config_ = tf_config;
camera_info_ = camera_info;
auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create();
auto config = YAML::LoadFile(config_path);
auto_aim_config_parameter_->loadFromFile(config_path);
auto_exposure_cfg_ = AutoExposureCfg::create();
very_aimer_ = VeryAimer::create(auto_aim_config_parameter_);
auto_aim_config_parameter_->registerGroup(*auto_exposure_cfg_);
auto_aim_config_parameter_->registerGroup(*auto_aim_fsm_cl_.config_);
tracker_ = Tracker::create(auto_aim_config_parameter_);
auto_aim_config_parameter_->reloadFromOldPath();
wust_vl::common::utils::ParameterManager::instance().registerParameter(
*auto_aim_config_parameter_.get()
);
max_detect_armors_ = config["max_detect_armors"].as<int>(10);
armor_where_ = ArmorWhere::create(config["armor_where"], camera_info_);
const std::string armor_detect_backend =
config["armor_detect_backend"].as<std::string>("");
armor_detector_ = DetectorFactory::createArmorDetector(armor_detect_backend, true);
armor_detector_->setCallback(std::bind(
&AutoAim::Impl::ArmorDetectCallback,
this,
std::placeholders::_1,
std::placeholders::_2
));
WUST_MAIN(logger_) << "Using Armor Detector: " << armor_detect_backend;
armor_queue_ =
std::make_unique<wust_vl::common::concurrency::OrderedQueue<Armors>>(50, 500);
latency_averager_ =
std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
}
void start() {
if (run_flag_) {
return;
}
run_flag_ = true;
processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create(
"AutoAimProcessingThread",
[this](wust_vl::common::concurrency::MonitoredThread::Ptr self) {
this->processingLoop(self);
}
);
wust_vl::common::concurrency::ThreadManager::instance().registerThread(
processing_thread_
);
}
void pushInput(CommonFrame& frame) {
img_recv_count_++;
auto bbox = target_.expanded(
T_camera_to_odom_,
camera_info_.first,
camera_info_.second,
frame.img_frame.src_img.size()
);
if (bbox.area() > 100) {
frame.expanded = bbox;
frame.offset = cv::Point2f(bbox.x, bbox.y);
}
expanded_ = frame.expanded;
const std::optional<ArmorNumber> target_number = target_.getArmorNumber();
if (armor_detector_) {
armor_detector_->pushInput(frame, target_number);
}
}
void ArmorDetectCallback(const std::vector<ArmorObject>& objs, const CommonFrame& frame) {
std::vector<ArmorObject> sorted_objs = objs;
if (sorted_objs.size() > max_detect_armors_) [[unlikely]] {
WUST_WARN(logger_) << "Detected " << sorted_objs.size()
<< " objects, too many, keeping top " << max_detect_armors_;
std::partial_sort(
sorted_objs.begin(),
sorted_objs.begin() + max_detect_armors_,
sorted_objs.end(),
[](const ArmorObject& a, const ArmorObject& b) {
return a.confidence > b.confidence;
}
);
sorted_objs.resize(max_detect_armors_);
}
for (auto& obj: sorted_objs) {
obj.addOffset(frame.offset);
}
Armors armors;
armors.timestamp = frame.img_frame.timestamp;
armors.id = frame.id;
Eigen::Vector3d v = Eigen::Vector3d::Zero();
Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity();
auto ctx = std::any_cast<VisionCtx>(frame.any_ctx);
std::pair<double, double> gimbal_py;
if (ctx.motion_buffer) {
const auto delay =
std::chrono::microseconds(static_cast<int64_t>(ctx.communication_delay_μs));
const auto t_query = armors.timestamp + delay;
auto apply_motion = [&](const auto& att) {
v << att.data.vx, att.data.vy, att.data.vz;
R_gimbal2odom = Eigen::AngleAxisd(att.data.yaw, Eigen::Vector3d::UnitZ())
* Eigen::AngleAxisd(-att.data.pitch, Eigen::Vector3d::UnitY())
* Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX());
gimbal_py = std::make_pair(att.data.pitch, att.data.yaw);
};
if (auto past_att = ctx.motion_buffer->get_interpolated(t_query)) {
apply_motion(*past_att);
} else if (auto last_att = ctx.motion_buffer->get_last()) {
apply_motion(*last_att);
}
}
autoExposureControl(frame.img_frame.src_img, ctx.camera);
T_camera_to_odom_ = utils::computeCameraToOdomTransform(
R_gimbal2odom,
tf_config_->R_camera2gimbal,
tf_config_->t_camera2gimbal
);
armors.armors = armor_where_->where(sorted_objs, T_camera_to_odom_);
armors.v = v;
for (auto& armor: armors.armors) {
armor.timestamp = armors.timestamp;
}
armor_queue_->enqueue(armors);
++detect_finish_count_;
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
auto& dbg = auto_aim_debug_;
dbg.img_frame = frame.img_frame;
dbg.armors = armors;
dbg.T_camera_to_odom = T_camera_to_odom_;
dbg.detect_color = frame.detect_color;
dbg.armor_objs = sorted_objs;
dbg.expanded = frame.expanded;
dbg.gimbal_py = gimbal_py;
}
}
void armorsCallback(const Armors& armors) {
if (armors.timestamp <= tracker_->getLastTime()) {
WUST_WARN(logger_) << "Received out-of-order armor data, discarded.";
return;
}
Target target = tracker_->track(armors);
auto_aim_fsm_cl_.update(std::abs(target.target_state_.vyaw()), target.jumped);
const auto now = std::chrono::steady_clock::now();
{
std::lock_guard<std::mutex> lock(target_mutex_);
target_ = target;
}
const auto latency_ms =
wust_vl::common::utils::time_utils::durationMs(armors.timestamp, now);
latency_averager_->add(latency_ms);
auto& dbg = auto_aim_debug_;
dbg.latency_ms = latency_averager_->average();
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
dbg.target = target;
dbg.fsm = auto_aim_fsm_cl_.fsm_state_;
}
}
Target getTarget() {
Target target;
{
std::lock_guard<std::mutex> lock(target_mutex_);
target = target_;
}
return target;
}
GimbalCmd solve(double bullet_speed) {
GimbalCmd gimbal_cmd;
Target target;
{
std::lock_guard<std::mutex> lock(target_mutex_);
target = target_;
}
AimTarget aim_target;
const bool appear = target.checkTargetAppear();
if (appear && target.target_state_.pos().norm() > 0.1) {
try {
gimbal_cmd =
very_aimer_->veryAim(target, bullet_speed, auto_aim_fsm_cl_.fsm_state_);
aim_target = gimbal_cmd.aim_target;
} catch (...) {
WUST_ERROR(logger_) << "VeryAim error";
}
}
if (gimbal_cmd.fire_advice) {
fire_count_++;
}
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
auto_aim_debug_.gimbal_cmd = gimbal_cmd;
auto_aim_debug_.aim_target = aim_target;
}
timer_cout_++;
return gimbal_cmd;
}
void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) {
while (!self->isAlive()) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
while (self->isAlive() && run_flag_) {
if (!self->waitPoint())
break;
self->heartbeat();
printStats();
Armors armors;
// bool skip;
// if (armor_queue_->dequeue_wait(armors, skip)) {
// armorsCallback(armors);
// tracker_finish_count_++;
// if (skip) {
// WUST_DEBUG(logger_) << "OrderQueue skip";
// }
// }
if (armor_queue_->try_dequeue(armors)) {
armorsCallback(armors);
tracker_finish_count_++;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
}
void doDebug() {
debug_mode_ = true;
AutoAimDebug dbg;
{
std::lock_guard<std::mutex> lock(dbg_mutex_);
dbg = auto_aim_debug_;
}
drawDebugOverlayShm(dbg, camera_info_, false);
debuglog(dbg);
}
void printStats() {
utils::XSecOnce(
[&] {
double found_ratio = 0.0;
if (img_recv_count_ > 0) {
found_ratio = static_cast<double>(tracker_->getFoundCount())
/ static_cast<double>(img_recv_count_);
}
WUST_INFO(logger_)
<< "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_
<< ", Fin: " << tracker_finish_count_ << ", Tc: " << timer_cout_
<< ", Lat: " << auto_aim_debug_.latency_ms << "ms"
<< ", Fire: " << fire_count_ << ", Found: " << tracker_->getFoundCount()
<< ", Found_ratio: " << found_ratio;
img_recv_count_ = 0;
detect_finish_count_ = 0;
fire_count_ = 0;
tracker_finish_count_ = 0;
timer_cout_ = 0;
tracker_->setFoundCount(0);
},
1.0
);
}
void
autoExposureControl(const cv::Mat& frame, std::shared_ptr<wust_vl::video::Camera> camera) {
const double dt = auto_exposure_cfg_->control_interval_ms_param.get() / 1000.0;
utils::XSecOnce(
[&] {
if (!auto_exposure_cfg_->enable_param.get() || frame.empty()) {
return;
}
if (auto* hik = dynamic_cast<wust_vl::video::HikCamera*>(camera->getDevice())) {
cv::Mat i_use = frame(expanded_);
if (expanded_.area() < 100 || i_use.empty()) {
i_use = frame;
}
const double brightness = utils::computeBrightness(i_use);
const double diff =
brightness - auto_exposure_cfg_->target_brightness_param.get();
const double exposure_min = auto_exposure_cfg_->exposure_min_param.get();
const double exposure_max = auto_exposure_cfg_->exposure_max_param.get();
double exposure_time = hik->getExposureTime();
const double last_exposure_time = exposure_time;
if (std::fabs(diff) > auto_exposure_cfg_->tolerance_param.get()
&& exposure_time > 0.0) {
exposure_time -= diff * auto_exposure_cfg_->step_gain_param.get();
} else {
exposure_time -= auto_exposure_cfg_->decay_step_param.get();
}
if (exposure_time < exposure_min)
exposure_time = exposure_min;
if (exposure_time > exposure_max)
exposure_time = exposure_max;
if (std::abs(exposure_time - last_exposure_time) > 10) {
hik->setExposureTime(exposure_time);
}
}
},
dt
);
}
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_;
Tracker::Ptr tracker_;
ArmorDetectorBase::Ptr armor_detector_;
std::string logger_ = "auto_aim";
std::unique_ptr<wust_vl::common::concurrency::OrderedQueue<Armors>> armor_queue_;
wust_vl::common::concurrency::MonitoredThread::Ptr processing_thread_;
std::unique_ptr<wust_vl::common::utils::Timer> timer_;
VeryAimer::Ptr very_aimer_;
ArmorWhere::Ptr armor_where_;
AutoAimFsmController auto_aim_fsm_cl_;
AutoExposureCfg::Ptr auto_exposure_cfg_;
cv::Rect expanded_;
int max_detect_armors_;
bool run_flag_ = false;
int detect_finish_count_ = 0;
int img_recv_count_ = 0;
int tracker_finish_count_ = 0;
int fire_count_ = 0;
int timer_cout_ = 0;
Target target_;
bool debug_mode_ = false;
AutoAimDebug auto_aim_debug_;
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
TFConfig::Ptr tf_config_;
std::pair<cv::Mat, cv::Mat> camera_info_;
Eigen::Matrix4d T_camera_to_odom_;
std::mutex target_mutex_;
std::mutex dbg_mutex_;
};
AutoAim::AutoAim(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
):
_impl(std::make_unique<Impl>(config_path, tf_config, camera_info, debug)) {}
AutoAim::~AutoAim() {
_impl.reset();
}
void AutoAim::start() {
_impl->start();
}
void AutoAim::pushInput(CommonFrame& frame) {
_impl->pushInput(frame);
}
GimbalCmd AutoAim::solve(double bullet_speed) {
return _impl->solve(bullet_speed);
}
wust_vl::common::concurrency::MonitoredThread::Ptr AutoAim::getThread() {
return _impl->processing_thread_;
}
Target AutoAim::getTarget() {
return _impl->getTarget();
}
void AutoAim::doDebug() {
_impl->doDebug();
}
wust_vl::common::utils::Parameter::Ptr AutoAim::getParameter() {
return _impl->auto_aim_config_parameter_;
}
VeryAimer::Ptr AutoAim::getVeryAimer() {
return _impl->very_aimer_;
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,46 @@
#pragma once
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/imodule.hpp"
#include "tasks/type_common.hpp"
#include "wust_vl/video/camera.hpp"
#include <memory>
namespace wust_vision {
namespace auto_aim {
class AutoAim: public IModule {
public:
using Ptr = std::shared_ptr<AutoAim>;
AutoAim(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
);
static Ptr create(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
) {
return std::make_shared<AutoAim>(config_path, tf_config, camera_info, debug);
}
~AutoAim();
void start() override;
void doDebug() override;
void pushInput(CommonFrame& frame) override;
Target getTarget();
GimbalCmd solve(double bullet_speed) override;
wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override;
wust_vl::common::utils::Parameter::Ptr getParameter();
VeryAimer::Ptr getVeryAimer();
struct Impl;
std::unique_ptr<Impl> _impl;
};
inline AutoAim::Ptr toAutoAim(IModule::Ptr module) {
return std::dynamic_pointer_cast<AutoAim>(module);
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,134 @@
#pragma once
#include "wust_vl/common/utils/parameter.hpp"
namespace wust_vision {
namespace auto_aim {
enum class AutoAimFsm {
AIM_WHOLE_CAR_ARMOR,
AIM_WHOLE_CAR_CENTER,
AIM_SINGLE_ARMOR,
AIM_WHOLE_CAR_PAIR
};
inline std::string auto_aim_fsm_to_string(AutoAimFsm state) {
switch (state) {
case AutoAimFsm::AIM_WHOLE_CAR_ARMOR:
return "AIM_WHOLE_CAR_ARMOR";
case AutoAimFsm::AIM_WHOLE_CAR_CENTER:
return "AIM_WHOLE_CAR_CENTER";
case AutoAimFsm::AIM_SINGLE_ARMOR:
return "AIM_SINGLE_ARMOR";
case AutoAimFsm::AIM_WHOLE_CAR_PAIR:
return "AIM_WHOLE_CAR_PAIR";
default:
return "UNKNOWN";
}
}
class AutoAimFsmController {
public:
AutoAimFsmController() {
config_ = std::make_shared<AutoAimFsmConfig>();
}
AutoAimFsm fsm_state_ { AutoAimFsm::AIM_SINGLE_ARMOR };
struct AutoAimFsmConfig: wust_vl::common::utils::ParamGroup {
public:
static constexpr const char* Logger = "Config: auto_aim::auto_aim_fsm";
static constexpr const char* kKey = "auto_aim_fsm";
const char* key() const override {
return kKey;
}
using Ptr = std::shared_ptr<AutoAimFsmConfig>;
AutoAimFsmConfig() {}
GEN_PARAM(int, transfer_thresh);
GEN_PARAM(double, single_whole_up);
GEN_PARAM(double, single_whole_down);
GEN_PARAM(double, whole_pair_up);
GEN_PARAM(double, whole_pair_down);
GEN_PARAM(double, pair_center_up);
GEN_PARAM(double, pair_center_down);
void loadSelf(const YAML::Node& node) override {
transfer_thresh_param.load(node);
single_whole_up_param.load(node);
single_whole_down_param.load(node);
whole_pair_up_param.load(node);
whole_pair_down_param.load(node);
pair_center_up_param.load(node);
pair_center_down_param.load(node);
}
};
AutoAimFsmConfig::Ptr config_;
int overflow_count_ = 0;
void update(double v_yaw, bool target_jumped) {
// 无跳变:直接退回单装甲,并清状态
if (!target_jumped) {
fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR;
overflow_count_ = 0;
return;
}
const double av = std::abs(v_yaw);
switch (fsm_state_) {
case AutoAimFsm::AIM_SINGLE_ARMOR: {
overflow_count_ =
(av > config_->single_whole_up_param.get()) ? overflow_count_ + 1 : 0;
if (overflow_count_ > config_->transfer_thresh_param.get()) {
fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_ARMOR;
overflow_count_ = 0;
}
break;
}
case AutoAimFsm::AIM_WHOLE_CAR_ARMOR: {
if (av > config_->whole_pair_up_param.get())
++overflow_count_;
else if (av < config_->single_whole_down_param.get())
--overflow_count_;
else
overflow_count_ = 0;
if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) {
fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_PAIR
: AutoAimFsm::AIM_SINGLE_ARMOR;
overflow_count_ = 0;
}
break;
}
case AutoAimFsm::AIM_WHOLE_CAR_PAIR: {
if (av > config_->pair_center_up_param.get())
++overflow_count_;
else if (av < config_->whole_pair_down_param.get())
--overflow_count_;
else
overflow_count_ = 0;
if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) {
fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_CENTER
: AutoAimFsm::AIM_WHOLE_CAR_ARMOR;
overflow_count_ = 0;
}
break;
}
case AutoAimFsm::AIM_WHOLE_CAR_CENTER: {
overflow_count_ =
(av < config_->pair_center_down_param.get()) ? overflow_count_ + 1 : 0;
if (overflow_count_ > config_->transfer_thresh_param.get()) {
fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_PAIR;
overflow_count_ = 0;
}
break;
}
default:
fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR;
overflow_count_ = 0;
break;
}
}
};
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,684 @@
#include "debug.hpp"
namespace wust_vision::auto_aim {
void drawDebugArmorContent(
cv::Mat& debug_img,
const AutoAimDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info
) {
if (debug_img.empty()) {
std::cout << "debug_img is empty" << std::endl;
return;
}
const auto now = std::chrono::steady_clock::now();
const auto& armors = dbg.armors;
const auto& gimbal_cmd = dbg.gimbal_cmd;
const auto& target = dbg.target;
auto aim_target = dbg.aim_target;
const auto& armor_objs = dbg.armor_objs;
const cv::Rect img_rect(0, 0, debug_img.cols, debug_img.rows);
const cv::Rect roi = dbg.expanded & img_rect;
cv::rectangle(debug_img, roi, cv::Scalar(255, 255, 255), 2);
static const int next_indices[] = { 2, 0, 3, 1 };
for (size_t i = 0; i < armor_objs.size(); i++) {
const auto pts = armor_objs[i].toPts();
for (size_t j = 0; j < 4; ++j) {
const cv::Scalar color =
armor_objs[i].is_ok ? cv::Scalar(50, 255, 50) : cv::Scalar(50, 255, 255);
cv::line(debug_img, pts[j], pts[next_indices[j]], color, 2);
}
auto armorName = [](auto_aim::ArmorNumber num) {
switch (num) {
case auto_aim::ArmorNumber::SENTRY:
return "SENTRY";
case auto_aim::ArmorNumber::BASE:
return "BASE";
case auto_aim::ArmorNumber::OUTPOST:
return "OUTPOST";
case auto_aim::ArmorNumber::NO1:
return "NO1";
case auto_aim::ArmorNumber::NO2:
return "NO2";
case auto_aim::ArmorNumber::NO3:
return "NO3";
case auto_aim::ArmorNumber::NO4:
return "NO4";
case auto_aim::ArmorNumber::NO5:
return "NO5";
default:
return "UNKNOWN";
}
};
const std::string armor_str = armorName(armor_objs[i].number);
cv::putText(
debug_img,
armor_str,
pts[1] + cv::Point2f(0, 50),
cv::FONT_HERSHEY_SIMPLEX,
0.7,
cv::Scalar(0, 200, 200),
2
);
}
const std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms);
cv::putText(
debug_img,
latency_str,
cv::Point(10, 30),
cv::FONT_HERSHEY_SIMPLEX,
0.8,
cv::Scalar(255, 255, 255),
2
);
{
// static std::deque<std::pair<Eigen::Vector3d, double>> traj3d;
// double _now =
// std::chrono::duration<double>(std::chrono::steady_clock::now().time_since_epoch()).count();
// traj3d.emplace_back(aim_target.pos, _now);
// while (!traj3d.empty() && _now - traj3d.front().second > 1.0)
// traj3d.pop_front();
aim_target.tf(dbg.T_camera_to_odom.inverse());
const auto pts = aim_target.toPts(camera_info.first, camera_info.second);
if (!pts.empty()) {
// if (traj3d.size() > 1) {
// std::vector<std::pair<cv::Point, double>> img_pts;
// for (auto& p: traj3d) {
// auto p3d_odom = p.first;
// Eigen::Vector4d p_odom(p3d_odom.x(), p3d_odom.y(), p3d_odom.z(), 1);
// Eigen::Vector4d p_camera = dbg.T_camera_to_odom.inverse() * p_odom;
// std::vector<cv::Point3f> obj;
// obj.emplace_back(p_camera.x(), p_camera.y(), p_camera.z());
// std::vector<cv::Point2f> proj;
// cv::projectPoints(
// obj,
// cv::Vec3d(0, 0, 0),
// cv::Vec3d(0, 0, 0),
// camera_info.first,
// camera_info.second,
// proj
// );
// if (!proj.empty()) {
// const auto& pt = proj[0];
// if (std::isfinite(pt.x) && std::isfinite(pt.y)) {
// img_pts.emplace_back(cv::Point(int(pt.x), int(pt.y)), p.second);
// }
// }
// }
// if (img_pts.size() >= 2) {
// double now = std::chrono::duration<double>(
// std::chrono::steady_clock::now().time_since_epoch()
// )
// .count();
// const double max_age = 1.0;
// for (size_t i = 1; i < img_pts.size(); ++i) {
// double age = now - img_pts[i].second;
// double t = std::clamp(age / max_age, 0.0, 1.0);
// int r = int(255 * (1.0 - t));
// int b = int(255 * t);
// cv::Scalar color(b, 0, r);
// cv::line(
// debug_img,
// img_pts[i - 1].first,
// img_pts[i].first,
// color,
// 2,
// cv::LINE_AA
// );
// }
// }
// }
cv::Point2f center(0.f, 0.f);
for (auto pt: pts)
center += pt;
center *= 1.0f / pts.size();
cv::Scalar color(255, 255, 255);
for (int i = 0; i < 4; i++)
cv::line(debug_img, pts[i], pts[(i + 1) % 4], color, 2);
for (int i = 4; i < 8; i++)
cv::line(debug_img, pts[i], pts[4 + (i + 1) % 4], color, 2);
for (int i = 0; i < 4; i++)
cv::line(debug_img, pts[i], pts[i + 4], color, 2);
if (gimbal_cmd.fire_advice) {
int cross_len = 60;
cv::line(
debug_img,
center + cv::Point2f(-cross_len, -cross_len),
center + cv::Point2f(+cross_len, +cross_len),
cv::Scalar(0, 0, 255),
5
);
cv::line(
debug_img,
center + cv::Point2f(-cross_len, +cross_len),
center + cv::Point2f(+cross_len, -cross_len),
cv::Scalar(0, 0, 255),
5
);
}
const double scale = 10.0;
const double v_yaw = gimbal_cmd.v_yaw;
const double v_pitch = gimbal_cmd.v_pitch;
const double dx = -scale * v_yaw;
const double dy = scale * v_pitch;
const cv::Point2f start_pt = center;
const cv::Point2f end_pt = start_pt + cv::Point2f(dx, dy);
const cv::Scalar color_x =
dbg.detect_color ? cv::Scalar(255, 50, 50) : cv::Scalar(50, 50, 255);
cv::arrowedLine(debug_img, start_pt, end_pt, color_x, 4, cv::LINE_AA, 0, 0.2);
}
}
std::vector<cv::Point2f> all_corners;
auto visualizeTargetProjection = [&](auto_aim::Target armor_target) -> auto_aim::Armors {
auto_aim::Armors armor_data;
armor_data.timestamp = armor_target.timestamp_;
if (armor_target.is_tracking) {
Eigen::Vector3d pos = armor_target.target_state_.pos();
if (pos.norm() > 0.5) {
armor_data.armors.clear();
const size_t a_n = armor_target.armor_num_;
armor_data.armors.reserve(a_n);
const auto now = wust_vl::common::utils::time_utils::now();
armor_target.predictSimple(now);
const std::vector<Eigen::Vector4d> armors_posandyaw =
armor_target.getArmorPosAndYaw();
for (size_t i = 0; i < a_n; ++i) {
const Eigen::Vector3d pos = { armors_posandyaw[i][0],
armors_posandyaw[i][1],
armors_posandyaw[i][2] };
Eigen::Vector3d euler;
euler.z() = M_PI / 2.0;
euler.y() = (armor_target.tracked_id_ == auto_aim::ArmorNumber::OUTPOST)
? -0.2618
: 0.2618;
euler.x() = armors_posandyaw[i][3];
const Eigen::Quaterniond ori =
utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
armor_data.armors.emplace_back(auto_aim::Armor {
.type = armor_target.type_,
.pos = pos,
.ori = ori,
.is_ok = true,
.id = (int)(i),
});
}
}
}
return armor_data;
};
auto armor_data = visualizeTargetProjection(dbg.target);
transformArmorData(armor_data, dbg.T_camera_to_odom.inverse());
for (size_t i = 0; i < armor_data.armors.size(); ++i) {
const auto& pts = armor_data.armors[i].toPtsDebug(camera_info.first, camera_info.second);
const auto& pos = armor_data.armors[i].pos;
const auto& ori = armor_data.armors[i].ori;
const auto& id = armor_data.armors[i].id;
cv::Scalar color;
if (dbg.detect_color) {
color = cv::Scalar(255, 0, 0);
} else {
color = cv::Scalar(0, 0, 255);
}
// 绘制前表面
for (size_t j = 0; j < 4; ++j) {
cv::line(debug_img, pts[j], pts[(j + 1) % 4], color, 2);
}
// 绘制后表面
for (size_t j = 4; j < 8; ++j) {
cv::line(debug_img, pts[j], pts[4 + (j + 1) % 4], color, 2);
}
// 绘制侧边
for (size_t j = 0; j < 4; ++j) {
cv::line(debug_img, pts[j], pts[j + 4], color, 2);
}
all_corners.insert(all_corners.end(), pts.begin(), pts.end());
const Eigen::Vector3d euler = ori.toRotationMatrix().eulerAngles(2, 1, 0);
const double yaw = euler[0];
const double distance =
std::sqrt(pos.x() * pos.x() + pos.y() * pos.y() + pos.z() * pos.z());
const std::vector<std::string> info_lines = {
fmt::format("Dis: {:.1f}cm", distance * 100),
fmt::format("X: {:.2f}", pos.x()),
fmt::format("Y: {:.2f}", pos.y()),
fmt::format("Z: {:.2f}", pos.z()),
fmt::format("Yaw: {:.2f}", yaw * 180.0 / M_PI),
fmt::format("ID: {:d}", id)
};
const cv::Point2f text_org = pts[0] + cv::Point2f(0, 200);
for (int k = 0; k < info_lines.size(); ++k) {
cv::putText(
debug_img,
info_lines[k],
text_org + cv::Point2f(0, -10 - 20 * k),
cv::FONT_HERSHEY_SIMPLEX,
0.6,
cv::Scalar(50, 255, 255),
1
);
}
}
if (!all_corners.empty()) {
cv::Point2f avg(0.f, 0.f);
for (const auto& pt: all_corners)
avg += pt;
avg *= 1.0f / all_corners.size();
cv::circle(debug_img, avg, 10, cv::Scalar(50, 255, 50), -1);
const double scale = 50.0;
const double dy = scale * target.target_state_.vyaw();
const cv::Point2f start_pt = avg;
const cv::Point2f end_pt = start_pt + cv::Point2f(0, dy);
cv::arrowedLine(
debug_img,
start_pt,
end_pt,
cv::Scalar(50, 255, 50),
3,
cv::LINE_AA,
0,
0.1
);
cv::putText(
debug_img,
fmt::format("V_yaw: {:.2f}", target.target_state_.vyaw()),
avg + cv::Point2f(0, -20),
cv::FONT_HERSHEY_SIMPLEX,
1.0,
cv::Scalar(50, 255, 50),
2
);
}
std::string state_str;
state_str = auto_aim_fsm_to_string(dbg.fsm);
int baseline = 0;
cv::Size text_size = cv::getTextSize(state_str, cv::FONT_HERSHEY_SIMPLEX, 2.5, 2, &baseline);
// 保证在图像内
const int x =
std::clamp(debug_img.cols - text_size.width - 10, 0, debug_img.cols - text_size.width);
const int y = std::clamp(text_size.height + 10, text_size.height, debug_img.rows - 1);
cv::putText(
debug_img,
state_str,
{ x, y },
cv::FONT_HERSHEY_SIMPLEX,
2.5,
cv::Scalar(0, 0, 255),
2
);
const std::string id_str =
fmt::format("Attack: {}", armorNumberToString(dbg.target.tracked_id_));
const cv::Size id_size = cv::getTextSize(id_str, cv::FONT_HERSHEY_SIMPLEX, 1.6, 2, &baseline);
// 保证在图像内
const int id_x = std::clamp(debug_img.cols - 300, 0, debug_img.cols - id_size.width - 10);
const int id_y = std::clamp(150, id_size.height, debug_img.rows - 1);
cv::putText(
debug_img,
id_str,
{ id_x, id_y },
cv::FONT_HERSHEY_SIMPLEX,
1.6,
cv::Scalar(255, 0, 255),
2
);
if (gimbal_cmd.fire_advice) {
std::string fire_str = "Fire!";
cv::putText(
debug_img,
fire_str,
{ debug_img.cols / 2 - 100, 200 },
cv::FONT_HERSHEY_SIMPLEX,
2.85,
cv::Scalar(0, 0, 255),
2
);
}
const std::string gimbal_str = fmt::format(
"Pitch: {:.2f}, Yaw: {:.2f}, Enable_pitch_diff: {:.2f}, Enable_yaw_diff: {:.2f}, V_yaw: {:.2f}, V_pitch: {:.2f}",
gimbal_cmd.pitch,
gimbal_cmd.yaw,
gimbal_cmd.enable_pitch_diff,
gimbal_cmd.enable_yaw_diff,
gimbal_cmd.v_yaw,
gimbal_cmd.v_pitch
);
cv::putText(
debug_img,
gimbal_str,
{ 10, debug_img.rows - 30 },
cv::FONT_HERSHEY_SIMPLEX,
0.8,
cv::Scalar(255, 255, 0),
2
);
double scale = 100.0;
double armor_len = 0.135;
std::vector<Eigen::Vector2d> pts;
pts.reserve(armors.armors.size() + armor_data.armors.size());
auto collect_xy = [&](auto& list, bool use_target) {
for (auto& a: list)
pts.emplace_back(
use_target ? a.target_pos.x() : a.pos.x(),
use_target ? a.target_pos.y() : a.pos.y()
);
};
collect_xy(armors.armors, true);
collect_xy(armor_data.armors, false);
double max_abs_x = 1e-6, max_abs_y = 1e-6;
for (auto& p: pts) {
max_abs_x = std::max(max_abs_x, std::abs(p.x()));
max_abs_y = std::max(max_abs_y, std::abs(p.y()));
}
const double margin = 200.0;
const double cx = debug_img.cols * 0.5;
const double cy = debug_img.rows * 0.5;
scale = std::min({ (cx - margin) / max_abs_x,
(debug_img.cols - cx - margin) / max_abs_x,
(cy - margin) / max_abs_y,
(debug_img.rows - cy - margin) / max_abs_y,
550.0 });
const cv::Point2d origin(cx, cy);
auto to_img = [&](const Eigen::Vector3d& p) {
return cv::Point2d(origin.x + p.x() * scale, origin.y - p.y() * scale);
};
auto draw2dArmor = [&](const Eigen::Vector3d& pos, double yaw, const cv::Scalar& color) {
cv::Point2d C = to_img(pos);
cv::circle(debug_img, C, 3, color, -1, cv::LINE_AA);
double nx = -sin(yaw), ny = cos(yaw);
double half_len_px = armor_len * 0.5 * scale;
cv::Point2d P1(C.x + nx * half_len_px, C.y - ny * half_len_px);
cv::Point2d P2(C.x - nx * half_len_px, C.y + ny * half_len_px);
cv::line(debug_img, P1, P2, color, 2, cv::LINE_AA);
};
Eigen::Vector3d center(0, 0, 0);
if (!armor_data.armors.empty()) {
for (auto& a: armor_data.armors)
center += a.pos;
center /= armor_data.armors.size();
}
const cv::Point2d Cc = to_img(center);
if (!armor_data.armors.empty())
cv::circle(debug_img, Cc, 5, cv::Scalar(255, 0, 0), -1, cv::LINE_AA);
for (auto& a: armors.armors)
draw2dArmor(a.target_pos, a.yaw, cv::Scalar(0, 255, 255));
std::vector<cv::Point2d> data_pts;
for (auto& a: armor_data.armors) {
double yaw = a.ori.toRotationMatrix().eulerAngles(2, 1, 0)[0];
draw2dArmor(a.pos, yaw, cv::Scalar(255, 255, 255));
data_pts.push_back(to_img(a.pos));
}
for (auto& pt: data_pts)
cv::line(debug_img, Cc, pt, cv::Scalar(180, 180, 255), 1, cv::LINE_AA);
for (auto& a: armors.armors)
cv::line(debug_img, Cc, to_img(a.target_pos), cv::Scalar(0, 150, 255), 1, cv::LINE_AA);
cv::circle(
debug_img,
cv::Point2i(debug_img.cols / 2, debug_img.rows / 2),
5,
cv::Scalar(255, 255, 255),
2
);
}
void writeTargetLogToJson(const auto_aim::Target& armor_target) {
nlohmann::json j;
// -------- armor_target 部分 --------
nlohmann::json jt;
jt["type"] = armor_target.type_;
jt["tracking"] = armor_target.is_tracking;
jt["id"] = static_cast<int>(armor_target.tracked_id_);
jt["armors_num"] = armor_target.armor_num_;
const auto now = std::chrono::steady_clock::now();
const auto age_ms_t =
std::chrono::duration_cast<std::chrono::milliseconds>(now - armor_target.timestamp_)
.count();
jt["timestamp_age_ms"] = age_ms_t;
jt["position"] = { { "x", armor_target.target_state_.cx() },
{ "y", armor_target.target_state_.cy() },
{ "z", armor_target.target_state_.cz() } };
jt["velocity"] = { { "x", armor_target.target_state_.vcx() },
{ "y", armor_target.target_state_.vcy() },
{ "z", armor_target.target_state_.vcz() } };
jt["r"] = armor_target.target_state_.r();
jt["l"] = armor_target.target_state_.l();
jt["h"] = armor_target.target_state_.h();
jt["yaw"] = armor_target.target_state_.yaw();
jt["v_yaw"] = armor_target.target_state_.vyaw();
j["armor_target"] = jt;
// -------- 写文件 --------
std::ofstream file("/dev/shm/target_log.json");
if (file.is_open()) {
file << j.dump(2);
}
}
struct DebugLogs {
#define DEBUG_LOG_LIST(X) \
X(double, 100, time) \
X(double, 100, raw_yaw) \
X(double, 100, raw_pitch) \
X(double, 100, yaw) \
X(double, 100, pitch) \
X(double, 100, armor_dis) \
X(double, 100, armor_x) \
X(double, 100, armor_y) \
X(double, 100, armor_z) \
X(double, 100, armor_yaw) \
X(double, 100, ypd_y) \
X(double, 100, ypd_p) \
X(double, 100, gimbal_yaw) \
X(double, 100, gimbal_pitch) \
X(double, 100, target_v_yaw) \
X(double, 100, control_v_yaw) \
X(double, 100, control_v_pitch) \
X(double, 100, yaw_diff) \
X(double, 100, fire) \
X(double, 100, rune_dis) \
X(double, 100, fly_time) \
X(double, 100, control_a_yaw) \
X(double, 100, control_a_pitch)
#define GEN_LOG(TYPE, SIZE, NAME) LogsStream<TYPE, SIZE> NAME##_log { #NAME };
#define X(TYPE, SIZE, NAME) GEN_LOG(TYPE, SIZE, NAME)
DEBUG_LOG_LIST(X)
#undef X
void clear() {
#define X(TYPE, SIZE, NAME) NAME##_log.clear();
DEBUG_LOG_LIST(X)
#undef X
}
};
void debuglog(const AutoAimDebug& dbg_armor) {
static bool first_log = true;
static std::chrono::steady_clock::time_point start_time;
static auto_aim::Armor last_armor_;
static double last_armor_yaw_ = 0.0;
static double last_ypd_y_ = 0.0;
static double last_ypd_p_ = 0.0;
static double last_distance_ = 0.0;
static DebugLogs log;
static GimbalCmd last_cmd_;
static double rune_dis = 0.0;
if (first_log) {
start_time = std::chrono::steady_clock::now();
first_log = false;
}
const auto now = std::chrono::steady_clock::now();
const auto_aim::Armors& armors = dbg_armor.armors;
const double t = std::chrono::duration<double>(now - start_time).count();
const auto_aim::Target& target = dbg_armor.target;
writeTargetLogToJson(target);
double armor_yaw = 0.0, ypd_y = 0.0, ypd_p = 0.0, armor_distance = 0.0;
if (!armors.armors.empty()) {
std::vector<auto_aim::Armor> ok_armors;
for (const auto& armor: armors.armors) {
if (armor.number != auto_aim::ArmorNumber::OUTPOST)
ok_armors.push_back(armor);
}
if (!ok_armors.empty()) {
const auto_aim::Armor& min_armor = *std::min_element(
ok_armors.begin(),
ok_armors.end(),
[](const auto_aim::Armor& a, const auto_aim::Armor& b) {
return a.distance_to_image_center < b.distance_to_image_center;
}
);
last_armor_ = min_armor;
armor_distance = std::hypot(
min_armor.target_pos.x(),
min_armor.target_pos.y(),
min_armor.target_pos.z()
);
auto orientationToYaw = [](const Eigen::Quaterniond& q) noexcept -> double {
Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false);
double yaw = euler[0];
yaw = last_armor_yaw_ + angles::shortest_angular_distance(last_armor_yaw_, yaw);
last_armor_yaw_ = yaw;
return yaw;
};
armor_yaw = orientationToYaw(min_armor.target_ori);
ypd_y = std::atan2(min_armor.target_pos.y(), min_armor.target_pos.x());
ypd_y = last_ypd_y_ + angles::shortest_angular_distance(last_ypd_y_, ypd_y);
last_ypd_y_ = ypd_y;
ypd_p = std::atan2(
min_armor.target_pos.z(),
std::hypot(min_armor.target_pos.x(), min_armor.target_pos.y())
);
last_ypd_p_ = ypd_p;
last_distance_ = armor_distance;
}
}
GimbalCmd i_use;
if (dbg_armor.gimbal_cmd.appear) {
i_use = dbg_armor.gimbal_cmd;
} else {
i_use = last_cmd_;
}
last_cmd_ = i_use;
nlohmann::json j;
log.time_log.handleOnce(t, j);
log.raw_yaw_log.handleOnce(i_use.target_yaw, j);
log.raw_pitch_log.handleOnce(i_use.target_pitch, j);
log.yaw_log.handleOnce(i_use.yaw, j);
log.pitch_log.handleOnce(i_use.pitch, j);
log.armor_yaw_log.handleOnce(armor_yaw * 180.0 / M_PI, j);
log.armor_x_log.handleOnce(last_armor_.target_pos.x(), j);
log.armor_y_log.handleOnce(last_armor_.target_pos.y(), j);
log.armor_z_log.handleOnce(last_armor_.target_pos.z(), j);
log.ypd_y_log.handleOnce(last_ypd_y_ * 180.0 / M_PI, j);
log.ypd_p_log.handleOnce(last_ypd_p_ * 180.0 / M_PI, j);
log.armor_dis_log.handleOnce(last_distance_, j);
log.gimbal_pitch_log.handleOnce(dbg_armor.gimbal_py.first * 180.0 / M_PI, j);
log.gimbal_yaw_log.handleOnce(dbg_armor.gimbal_py.second * 180.0 / M_PI, j);
log.target_v_yaw_log.handleOnce(target.target_state_.vyaw(), j);
log.control_v_pitch_log.handleOnce(i_use.v_pitch, j);
log.control_v_yaw_log.handleOnce(i_use.v_yaw, j);
log.fire_log.handleOnce(i_use.fire_advice, j);
log.rune_dis_log.handleOnce(rune_dis, j);
log.fly_time_log.handleOnce(i_use.fly_time, j);
log.control_a_yaw_log.handleOnce(i_use.a_yaw / 180.0 * M_PI, j);
log.control_a_pitch_log.handleOnce(i_use.a_pitch / 180.0 * M_PI, j);
log.yaw_diff_log.handleOnce(
std::abs(dbg_armor.gimbal_py.second * 180.0 / M_PI - dbg_armor.gimbal_cmd.yaw),
j
);
std::ofstream file("/dev/shm/cmd_log.json");
if (file.is_open()) {
file << j.dump();
}
}
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,38 @@
#pragma once
#include "tasks/auto_aim/armor_tracker/target.hpp"
#include "tasks/auto_aim/auto_aim_fsm.hpp"
#include "tasks/auto_aim/type.hpp"
#include "tasks/utils/debug_utils.hpp"
namespace wust_vision::auto_aim {
struct AutoAimDebug {
wust_vl::video::ImageFrame img_frame;
auto_aim::Armors armors;
auto_aim::Target target;
GimbalCmd gimbal_cmd;
auto_aim::AutoAimFsm fsm;
AimTarget aim_target;
double latency_ms;
Eigen::Matrix4d T_camera_to_odom;
std::vector<auto_aim::ArmorObject> armor_objs;
int detect_color = 0;
cv::Rect expanded;
std::pair<double, double> gimbal_py;
};
void drawDebugArmorContent(
cv::Mat& debug_img,
const AutoAimDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info
);
void writeTargetLogToJson(const auto_aim::Target& armor_target);
inline void drawDebugOverlayShm(
const AutoAimDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info,
bool auto_fps
) {
static ShmWriter shm { "/debug_frame" };
drawDebugOverlayImpl(dbg, camera_info, auto_fps, drawDebugArmorContent, shm);
}
void debuglog(const AutoAimDebug& dbg_armor);
} // namespace wust_vision::auto_aim

View File

@@ -0,0 +1,399 @@
#include "type.hpp"
#include "tasks/utils/config.hpp"
#include "wust_vl/common/utils/logger.hpp"
#include <numeric>
namespace wust_vision {
namespace auto_aim {
Light::Light(const std::vector<cv::Point>& contour): cv::RotatedRect(cv::minAreaRect(contour)) {
this->center = std::accumulate(
contour.begin(),
contour.end(),
cv::Point2f(0, 0),
[n = static_cast<float>(contour.size())](const cv::Point2f& a, const cv::Point& b) {
return a + cv::Point2f(b.x, b.y) / n;
}
);
cv::Point2f p[4];
this->points(p);
std::sort(p, p + 4, [](const cv::Point2f& a, const cv::Point2f& b) { return a.y < b.y; });
top = (p[0] + p[1]) / 2;
bottom = (p[2] + p[3]) / 2;
length = cv::norm(top - bottom);
width = cv::norm(p[0] - p[1]);
axis = (top - bottom) / cv::norm(top - bottom);
tilt_angle =
std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f;
}
void Light::addOffset(const cv::Point2f& offset) noexcept {
this->center += offset;
top += offset;
bottom += offset;
}
void Light::transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept {
top = utils::transformPoint2D(transform_matrix, top);
bottom = utils::transformPoint2D(transform_matrix, bottom);
length = cv::norm(top - bottom);
cv::Point2f p[4];
this->points(p);
width = cv::norm(
utils::transformPoint2D(transform_matrix, p[0])
- utils::transformPoint2D(transform_matrix, p[1])
);
const cv::Point2f p0 = center;
const cv::Point2f p1 = center + axis;
const cv::Point2f p0_t = utils::transformPoint2D(transform_matrix, p0);
const cv::Point2f p1_t = utils::transformPoint2D(transform_matrix, p1);
axis = p1_t - p0_t;
axis /= cv::norm(axis);
tilt_angle =
std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f;
center = utils::transformPoint2D(transform_matrix, center);
}
int formArmorColor(const ArmorColor& color) noexcept {
switch (color) {
case ArmorColor::RED:
return 0;
case ArmorColor::BLUE:
return 1;
case ArmorColor::NONE:
return 2;
case ArmorColor::PURPLE:
return 3;
}
return -1;
}
std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept {
switch (number) {
case ArmorNumber::SENTRY:
return os << "SENTRY";
case ArmorNumber::NO1:
return os << "NO1";
case ArmorNumber::NO2:
return os << "NO2";
case ArmorNumber::NO3:
return os << "NO3";
case ArmorNumber::NO4:
return os << "NO4";
case ArmorNumber::NO5:
return os << "NO5";
case ArmorNumber::OUTPOST:
return os << "OUTPOST";
case ArmorNumber::BASE:
return os << "BASE";
case ArmorNumber::UNKNOWN:
return os << "UNKNOWN";
default:
return os << "InvalidArmorNumber(" << static_cast<int>(number) << ")";
}
}
int formArmorNumber(const ArmorNumber& number) noexcept {
switch (number) {
case ArmorNumber::SENTRY:
return 0;
case ArmorNumber::NO1:
return 1;
case ArmorNumber::NO2:
return 2;
case ArmorNumber::NO3:
return 3;
case ArmorNumber::NO4:
return 4;
case ArmorNumber::NO5:
return 5;
case ArmorNumber::OUTPOST:
return 6;
case ArmorNumber::BASE:
return 7;
case ArmorNumber::UNKNOWN:
return 8;
}
return -1;
}
ArmorNumber armorNumberFromString(const std::string& s) noexcept {
if (s == "SENTRY")
return ArmorNumber::SENTRY;
if (s == "BASE")
return ArmorNumber::BASE;
if (s == "OUTPOST")
return ArmorNumber::OUTPOST;
if (s == "NO1")
return ArmorNumber::NO1;
if (s == "NO2")
return ArmorNumber::NO2;
if (s == "NO3")
return ArmorNumber::NO3;
if (s == "NO4")
return ArmorNumber::NO4;
if (s == "NO5")
return ArmorNumber::NO5;
return ArmorNumber::UNKNOWN;
}
std::string armorNumberToString(const ArmorNumber& num) noexcept {
switch (num) {
case ArmorNumber::SENTRY:
return "SENTRY";
case ArmorNumber::BASE:
return "BASE";
case ArmorNumber::OUTPOST:
return "OUTPOST";
case ArmorNumber::NO1:
return "NO1";
case ArmorNumber::NO2:
return "NO2";
case ArmorNumber::NO3:
return "NO3";
case ArmorNumber::NO4:
return "NO4";
case ArmorNumber::NO5:
return "NO5";
default:
return "UNKNOWN";
}
}
namespace {
std::unordered_map<std::string, int> armor_map;
std::unordered_map<int, std::vector<ArmorNumber>> tracker_to_armors;
bool loaded = false;
void loadArmorMapOnce() {
if (loaded)
return;
try {
YAML::Node config = YAML::LoadFile(AUTO_AIM_CONFIG)["armor_map"];
for (auto it = config.begin(); it != config.end(); ++it) {
const std::string key = it->first.as<std::string>();
const int tracker_id = it->second.as<int>();
ArmorNumber armor_num = armorNumberFromString(key);
armor_map[key] = tracker_id;
tracker_to_armors[tracker_id].emplace_back(armor_num);
}
loaded = true;
} catch (const std::exception& e) {
std::cerr << "[ArmorMap] Failed to load armor_map.yaml: " << e.what() << std::endl;
}
}
} // namespace
int retypetotracker(const ArmorNumber& a) noexcept {
loadArmorMapOnce();
const std::string key = armorNumberToString(a);
const auto it = armor_map.find(key);
if (it != armor_map.end())
return it->second;
std::cerr << "[retypetotracker] Invalid ArmorNumber: " << static_cast<int>(a) << std::endl;
return -1;
}
bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept {
return retypetotracker(a) == retypetotracker(b);
}
std::string armorTypeToString(const ArmorType& type) noexcept {
switch (type) {
case ArmorType::SMALL:
return "small";
case ArmorType::LARGE:
return "large";
default:
return "invalid";
}
}
std::vector<cv::Point2f> ArmorObject::toPts() const noexcept {
if (is_ok) {
return { lights[0].top, lights[0].bottom, lights[1].bottom, lights[1].top };
} else {
return { pts[0], pts[1], pts[2], pts[3] };
}
}
bool ArmorObject::checkOkptsRight(double max_error) const noexcept {
double error = 0.0;
for (int i = 0; i < 4; i++) {
error += cv::norm(pts[i] - toPts()[i]);
}
return error < max_error;
}
std::array<cv::Point2f, 4> ArmorObject::sortCorners(const std::vector<cv::Point2f>& pts
) const noexcept {
std::array<cv::Point2f, 4> ordered;
// 先按 x 坐标分成左右两组
std::vector<cv::Point2f> left, right;
std::vector<cv::Point2f> sorted = pts;
std::sort(sorted.begin(), sorted.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
return a.x < b.x;
});
left.push_back(sorted[0]);
left.push_back(sorted[1]);
right.push_back(sorted[2]);
right.push_back(sorted[3]);
// 左边两个点,按 y 分为上/下
std::sort(left.begin(), left.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
return a.y < b.y;
});
ordered[1] = left[0]; // 左上
ordered[0] = left[1]; // 左下
// 右边两个点,按 y 分为上/下
std::sort(right.begin(), right.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
return a.y < b.y;
});
ordered[2] = right[0]; // 右上
ordered[3] = right[1]; // 右下
return ordered; // 顺序: 左下, 左上, 右上, 右下
}
std::vector<cv::Point2f> ArmorObject::landmarks() const noexcept {
if constexpr (N_LANDMARKS == 4) {
if (is_ok) {
return { lights[0].bottom, lights[0].top, lights[1].top, lights[1].bottom };
} else {
const auto ordered = sortCorners(pts);
return { ordered[0], ordered[1], ordered[2], ordered[3] };
}
} else {
if (is_ok) {
return { lights[0].bottom, lights[0].center, lights[0].top,
lights[1].top, lights[1].center, lights[1].bottom };
} else {
const auto ordered = sortCorners(pts);
return { ordered[0], (ordered[0] + ordered[1]) / 2.0, ordered[1],
ordered[2], (ordered[2] + ordered[3]) / 2.0, ordered[3] };
}
}
}
ArmorObject::ArmorObject(const Light& l1, const Light& l2) {
pts.resize(4);
if (l1.center.x < l2.center.x) {
lights.push_back(l1);
lights.push_back(l2);
pts[0] = l1.top;
pts[1] = l1.bottom;
pts[2] = l2.bottom;
pts[3] = l2.top;
} else {
lights.push_back(l2);
lights.push_back(l1);
pts[0] = l2.top;
pts[1] = l2.bottom;
pts[2] = l1.bottom;
pts[3] = l1.top;
}
is_ok = true;
}
std::vector<cv::Point2f> Armor::toPtsDebug(
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion
) const noexcept {
std::vector<cv::Point2f> image_points;
const std::vector<cv::Point3f>* model_points;
static std::vector<cv::Point3f> SMALL_ARMOR_3D_POINTS_BLOCK = {
{ 0, 0.025, -0.066 }, // 左上前
{ 0, -0.025, -0.066 }, // 左下前
{ 0, -0.025, 0.066 }, // 右下前
{ 0, 0.025, 0.066 }, // 右上前
{ 0.015, 0.025, -0.066 }, // 左上后
{ 0.015, -0.025, -0.066 }, // 左下后
{ 0.015, -0.025, 0.066 }, // 右下后
{ 0.015, 0.025, 0.066 }, // 右上后
};
static std::vector<cv::Point3f> BIG_ARMOR_3D_POINTS_BLOCK = {
{ 0, 0.025, -0.1125 }, { 0, -0.025, -0.1125 }, { 0, -0.025, 0.1125 },
{ 0, 0.025, 0.1125 }, { 0.015, 0.025, -0.1125 }, { 0.015, -0.025, -0.1125 },
{ 0.015, -0.025, 0.1125 }, { 0.015, 0.025, 0.1125 },
};
if (type == "large") {
model_points = &BIG_ARMOR_3D_POINTS_BLOCK;
} else if (type == "small") {
model_points = &SMALL_ARMOR_3D_POINTS_BLOCK;
}
const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix();
const cv::Mat rot_mat =
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
tf_rot(0, 1),
tf_rot(0, 2),
tf_rot(1, 0),
tf_rot(1, 1),
tf_rot(1, 2),
tf_rot(2, 0),
tf_rot(2, 1),
tf_rot(2, 2));
// 旋转矩阵 -> 旋转向量
cv::Mat rvec;
cv::Rodrigues(rot_mat, rvec);
// 平移向量
const cv::Mat tvec =
(cv::Mat_<double>(3, 1) << target_pos.x(), target_pos.y(), target_pos.z());
// 反投影
cv::projectPoints(
*model_points,
rvec,
tvec,
camera_intrinsic,
camera_distortion,
image_points
);
return image_points;
}
void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept {
for (auto& armor: armors.armors) {
transformArmorData(armor, T_camera_to_odom);
}
}
void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept {
try {
// 位置
const Eigen::Vector3d pos_camera = armor.pos;
armor.target_pos = utils::transformPosition(pos_camera, T_camera_to_odom);
// 姿态
const Eigen::Quaterniond
q_camera(armor.ori.w(), armor.ori.x(), armor.ori.y(), armor.ori.z());
const Eigen::Quaterniond q_odom =
utils::transformOrientation(q_camera, T_camera_to_odom);
armor.target_ori = q_odom;
// 提取 yaw
const Eigen::Vector3d euler = q_odom.toRotationMatrix().eulerAngles(2, 1, 0); // ZYX
armor.yaw = euler[0]; // yaw
} catch (const std::exception& e) {
WUST_ERROR("tf") << "Error in camera-to-odom transform: " << e.what();
}
}
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,172 @@
#pragma once
#include "tasks/type_common.hpp"
#include "tasks/utils/utils.hpp"
namespace wust_vision {
namespace auto_aim {
constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135
constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0;
constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180;
struct Light: public cv::RotatedRect {
Light() = default;
explicit Light(const std::vector<cv::Point>& contour);
void addOffset(const cv::Point2f& offset) noexcept;
void transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept;
cv::Point2f top, bottom;
int color = 0;
cv::Point2f axis;
double length = 0;
double width = 0;
float tilt_angle = 0;
};
enum class ArmorColor : int { BLUE = 0, RED, NONE, PURPLE };
int formArmorColor(const ArmorColor& color) noexcept;
enum class ArmorNumber : int { SENTRY = 0, NO1, NO2, NO3, NO4, NO5, OUTPOST, BASE, UNKNOWN };
std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept;
int formArmorNumber(const ArmorNumber& number) noexcept;
std::string armorNumberToString(const ArmorNumber& num) noexcept;
ArmorNumber armorNumberFromString(const std::string& s) noexcept;
int retypetotracker(const ArmorNumber& a) noexcept;
bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept;
enum class ArmorsNum { NORMAL_4 = 4, OUTPOST_3 = 3 };
enum class ArmorType { SMALL, LARGE, INVALID };
std::string armorTypeToString(const ArmorType& type) noexcept;
struct ArmorObject {
ArmorColor color;
ArmorNumber number;
std::vector<cv::Point2f> pts;
cv::Rect box;
cv::Mat number_img;
double confidence;
cv::Mat whole_binary_img;
cv::Mat whole_rgb_img;
cv::Mat whole_gray_img;
std::vector<Light> lights;
cv::Point2f local_offset;
cv::Point2f center;
bool is_ok = false;
ArmorType type;
static constexpr const int N_LANDMARKS = 6;
static constexpr const int N_LANDMARKS_2 = N_LANDMARKS * 2;
template<typename PointType>
static std::vector<PointType> buildObjectPoints(const double& w, const double& h) noexcept {
if constexpr (N_LANDMARKS == 4) {
return {
PointType(0, w / 2, -h / 2), // 右下
PointType(0, w / 2, h / 2), // 右上
PointType(0, -w / 2, h / 2), // 左上
PointType(0, -w / 2, -h / 2) // 左下
};
} else {
return {
PointType(0, w / 2, -h / 2), // 右下
PointType(0, w / 2, 0.0), // 右中
PointType(0, w / 2, h / 2), // 右上
PointType(0, -w / 2, h / 2), // 左上
PointType(0, -w / 2, 0.0), // 左中
PointType(0, -w / 2, -h / 2) // 左下
};
}
}
template<typename IDType>
static std::vector<std::pair<IDType, IDType>> buildSymPairs() noexcept {
if constexpr (N_LANDMARKS == 4) {
static const std::vector<std::pair<IDType, IDType>> pairs = {
{ 0, 3 },
{ 1, 2 },
// { 0, 2 },
// { 1, 3 }
};
return pairs;
} else {
static const std::vector<std::pair<IDType, IDType>> pairs = {
{ 0, 5 },
{ 1, 4 },
{ 2, 3 },
// { 0, 3 },
// { 2, 5 }
};
return pairs;
}
}
std::vector<cv::Point2f> toPts() const noexcept;
bool checkOkptsRight(double max_error) const noexcept;
std::array<cv::Point2f, 4> sortCorners(const std::vector<cv::Point2f>& pts) const noexcept;
// Landmarks start from bottom left in clockwise order
std::vector<cv::Point2f> landmarks() const noexcept;
void addOffset(const cv::Point2f& offset) noexcept {
for (auto& pt: pts) {
pt += offset;
}
center += offset;
box.x += offset.x;
box.y += offset.y;
for (auto& l: lights) {
l.addOffset(offset);
}
}
void transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept {
for (auto& l: lights) {
l.transform(transform_matrix);
}
center = utils::transformPoint2D(transform_matrix, center);
box = utils::transformRect(transform_matrix, box);
for (auto& pt: pts) {
pt = utils::transformPoint2D(transform_matrix, pt);
}
}
ArmorObject(const Light& l1, const Light& l2);
ArmorObject() = default;
};
struct Armor {
public:
ArmorNumber number;
std::string type;
Eigen::Vector3d pos;
Eigen::Quaterniond ori;
Eigen::Vector3d target_pos;
Eigen::Quaterniond target_ori;
float distance_to_image_center;
float yaw;
std::chrono::steady_clock::time_point timestamp;
bool is_ok = false;
bool is_none_purple = false;
int id = -1;
std::vector<cv::Point2f> toPtsDebug(
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion
) const noexcept;
};
struct Armors {
public:
std::vector<Armor> armors;
std::chrono::steady_clock::time_point timestamp;
int id;
Eigen::Vector3d v;
};
void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept;
void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept;
} // namespace auto_aim
} // namespace wust_vision

View File

@@ -0,0 +1,360 @@
#include "auto_buff.hpp"
#include "tasks/auto_buff/debug.hpp"
#include "tasks/auto_buff/rune_control/aimer.hpp"
#include "tasks/auto_buff/rune_detector/rune_detector.hpp"
#include "tasks/auto_buff/rune_tracker/rune_tracker.hpp"
#include "tasks/auto_buff/rune_where/rune_where.hpp"
#include "tasks/type_common.hpp"
#include "tasks/utils/utils.hpp"
namespace wust_vision {
namespace auto_buff {
struct AutoBuff::Impl {
~Impl() {
run_flag_ = false;
rune_queue_->stop();
if (processing_thread_) {
processing_thread_->stop();
wust_vl::common::concurrency::ThreadManager::instance().unregisterThread(
processing_thread_->getName()
);
}
}
Impl(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
) {
debug_mode_ = debug;
auto_buff_config_parameter_ = wust_vl::common::utils::Parameter::create();
auto_buff_config_parameter_->loadFromFile(config_path);
auto_exposure_cfg_ = AutoExposureCfg::create();
aimer_ = auto_buff::Aimer::create(auto_buff_config_parameter_);
rune_tracker_ = RuneTracker::create(auto_buff_config_parameter_);
auto_buff_config_parameter_->registerGroup(*auto_exposure_cfg_);
auto_buff_config_parameter_->reloadFromOldPath();
auto config = YAML::LoadFile(config_path);
wust_vl::common::utils::ParameterManager::instance().registerParameter(
*auto_buff_config_parameter_.get()
);
tf_config_ = tf_config;
camera_info_ = camera_info;
rune_where_ = auto_buff::RuneWhere::create(config["rune_where"], camera_info);
rune_detector_ = RuneDetectorCV::make_detector(config["rune_detector"]);
rune_detector_->setCallback(std::bind(
&AutoBuff::Impl::runeDetectCallback,
this,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3
));
rune_queue_ =
std::make_unique<wust_vl::common::concurrency::OrderedQueue<auto_buff::RuneFan>>(
50,
500
);
latency_averager_ =
std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
}
void start() {
if (run_flag_) {
return;
}
run_flag_ = true;
processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create(
"AutoBuffProcessingThread",
[this](wust_vl::common::concurrency::MonitoredThread::Ptr self) {
this->processingLoop(self);
}
);
wust_vl::common::concurrency::ThreadManager::instance().registerThread(
processing_thread_
);
}
void pushInput(CommonFrame& frame) {
img_recv_count_++;
auto bbox = rune_target_.expanded(
T_camera_to_odom_,
camera_info_.first,
camera_info_.second,
frame.img_frame.src_img.size()
);
if (bbox.area() > 100) {
frame.expanded = bbox;
frame.offset = cv::Point2f(bbox.x, bbox.y);
}
expanded_ = frame.expanded;
rune_detector_->pushInput(frame, debug_mode_);
}
void runeDetectCallback(
const auto_buff::RuneFan& fan,
const CommonFrame& frame,
cv::Mat& debug_img
) {
std::lock_guard<std::mutex> lock(callback_mutex_);
Eigen::Vector3d v = Eigen::Vector3d::Zero();
Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity();
auto ctx = std::any_cast<VisionCtx>(frame.any_ctx);
std::pair<double, double> gimbal_py;
if (ctx.motion_buffer) {
const auto delay =
std::chrono::microseconds(static_cast<int64_t>(ctx.communication_delay_μs));
const auto t_query = fan.timestamp + delay;
auto apply_motion = [&](const auto& att) {
v << att.data.vx, att.data.vy, att.data.vz;
R_gimbal2odom = Eigen::AngleAxisd(att.data.yaw, Eigen::Vector3d::UnitZ())
* Eigen::AngleAxisd(-att.data.pitch, Eigen::Vector3d::UnitY())
* Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX());
gimbal_py = std::make_pair(att.data.pitch, att.data.yaw);
};
if (auto past_att = ctx.motion_buffer->get_interpolated(t_query)) {
apply_motion(*past_att);
} else if (auto last_att = ctx.motion_buffer->get_last()) {
apply_motion(*last_att);
}
}
autoExposureControl(frame.img_frame.src_img, ctx.camera);
Eigen::Matrix4d T_camera_to_odom = utils::computeCameraToOdomTransform(
R_gimbal2odom,
tf_config_->R_camera2gimbal,
tf_config_->t_camera2gimbal
);
T_camera_to_odom_ = T_camera_to_odom;
auto_buff::RuneFan copy_fan = rune_where_->where(fan, T_camera_to_odom);
copy_fan.is_big =
InfantryMode::toAttackMode(ctx.mode) == InfantryMode::AttackMode::BIG_RUNE;
rune_queue_->enqueue(copy_fan);
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
auto_buff_debug_.img_frame = frame.img_frame;
auto_buff_debug_.T_camera_to_odom = T_camera_to_odom_;
auto_buff_debug_.expanded = frame.expanded;
auto_buff_debug_.pnp_distance =
copy_fan.fans.empty() ? 0.0 : copy_fan.fans[0].pos.norm();
auto_buff_debug_.gimbal_py = gimbal_py;
}
detect_finish_count_++;
}
void runeTargetCallback(const auto_buff::RuneFan& fan) {
if (fan.timestamp <= last_rune_target_time_) {
WUST_WARN(logger_) << "Received out-of-order auto_buff data, discarded.";
return;
}
last_rune_target_time_ = fan.timestamp;
auto rune_target = rune_tracker_->track(fan);
{
std::lock_guard<std::mutex> lock(target_mutex_);
rune_target_ = rune_target;
}
auto now = std::chrono::steady_clock::now();
auto latency_ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now() - fan.timestamp
)
.count();
auto latency_ms = wust_vl::common::utils::time_utils::durationMs(fan.timestamp, now);
latency_averager_->add(latency_ms);
auto_buff_debug_.latency_ms = latency_averager_->average();
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
static double last_unwrapped_roll = 0.0;
static double last_raw_roll = 0.0;
const double raw_roll = rune_target.roll();
const double raw_pred = rune_target.predictAngle(0.5);
const double obs_angle = last_unwrapped_roll
+ angles::shortest_angular_distance(last_raw_roll, raw_roll);
const double pre_angle =
obs_angle + angles::shortest_angular_distance(raw_roll, raw_pred);
last_unwrapped_roll = obs_angle;
last_raw_roll = raw_roll;
auto_buff_debug_.obs_v = rune_target.v_roll();
auto_buff_debug_.fitter_v = rune_target.getFitterSpd(
wust_vl::common::utils::time_utils::now()
+ std::chrono::microseconds(int(0.2 * 1e6))
);
auto_buff_debug_.obs_angle = obs_angle;
auto_buff_debug_.pre_angle = pre_angle;
auto_buff_debug_.target = rune_target;
auto_buff_debug_.power_rune = rune_target.getPowerRune();
}
}
GimbalCmd solve(double bullet_speed) {
GimbalCmd gimbal_cmd;
auto_buff::RuneTarget rune_target;
{
std::lock_guard<std::mutex> lock(target_mutex_);
rune_target = rune_target_;
}
if (rune_target.checkTargetAppear()) {
gimbal_cmd = aimer_->aim(rune_target, bullet_speed);
}
if (gimbal_cmd.fire_advice) {
fire_count_++;
}
if (debug_mode_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
auto_buff_debug_.gimbal_cmd = gimbal_cmd;
auto_buff_debug_.aim_target = gimbal_cmd.aim_target;
}
timer_cout_++;
return gimbal_cmd;
}
void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) {
while (!self->isAlive()) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
while (self->isAlive() && run_flag_) {
if (!self->waitPoint())
break;
self->heartbeat();
printStats();
auto_buff::RuneFan auto_buff;
// bool skip;
// if (rune_queue_->dequeue_wait(auto_buff, skip)) {
// runeTargetCallback(auto_buff);
// tracker_finish_count_++;
// if (skip) {
// WUST_DEBUG(logger_) << "OrderQueue skip";
// }
// }
if (rune_queue_->try_dequeue(auto_buff)) {
runeTargetCallback(auto_buff);
tracker_finish_count_++;
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
}
void doDebug() {
debug_mode_ = true;
AutoBuffDebug dbg;
{
std::lock_guard<std::mutex> lock(dbg_mutex_);
dbg = auto_buff_debug_;
}
drawDebugOverlayShm(dbg, camera_info_, false);
debuglog(dbg);
}
void printStats() {
utils::XSecOnce(
[&] {
WUST_INFO(logger_)
<< "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_
<< ", Fin: " << tracker_finish_count_ << ", Tc: " << timer_cout_
<< ", Lat: " << auto_buff_debug_.latency_ms << "ms"
<< ", Fire: " << fire_count_;
img_recv_count_ = 0;
detect_finish_count_ = 0;
fire_count_ = 0;
tracker_finish_count_ = 0;
timer_cout_ = 0;
},
1.0
);
}
void
autoExposureControl(const cv::Mat& frame, std::shared_ptr<wust_vl::video::Camera> camera) {
const double dt = auto_exposure_cfg_->control_interval_ms_param.get() / 1000.0;
utils::XSecOnce(
[&] {
if (!auto_exposure_cfg_->enable_param.get() || frame.empty()) {
return;
}
if (auto* hik = dynamic_cast<wust_vl::video::HikCamera*>(camera->getDevice())) {
cv::Mat i_use = frame(expanded_);
if (expanded_.area() < 100 || i_use.empty()) {
i_use = frame;
}
const double brightness = utils::computeBrightness(i_use);
const double diff =
brightness - auto_exposure_cfg_->target_brightness_param.get();
const double exposure_min = auto_exposure_cfg_->exposure_min_param.get();
const double exposure_max = auto_exposure_cfg_->exposure_max_param.get();
double exposure_time = hik->getExposureTime();
const double last_exposure_time = exposure_time;
if (std::fabs(diff) > auto_exposure_cfg_->tolerance_param.get()
&& exposure_time > 0.0) {
exposure_time -= diff * auto_exposure_cfg_->step_gain_param.get();
} else {
exposure_time -= auto_exposure_cfg_->decay_step_param.get();
}
if (exposure_time < exposure_min)
exposure_time = exposure_min;
if (exposure_time > exposure_max)
exposure_time = exposure_max;
if (std::abs(exposure_time - last_exposure_time) > 10) {
hik->setExposureTime(exposure_time);
}
}
},
dt
);
}
std::mutex callback_mutex_;
RuneDetectorCV::Ptr rune_detector_;
RuneTracker::Ptr rune_tracker_;
auto_buff::Aimer::Ptr aimer_;
RuneWhere::Ptr rune_where_;
std::string logger_ = "auto_buff";
std::unique_ptr<wust_vl::common::concurrency::OrderedQueue<auto_buff::RuneFan>> rune_queue_;
wust_vl::common::concurrency::MonitoredThread::Ptr processing_thread_;
AutoExposureCfg::Ptr auto_exposure_cfg_;
cv::Rect expanded_;
auto_buff::RuneTarget rune_target_;
bool run_flag_ = false;
int detect_finish_count_ = 0;
int img_recv_count_ = 0;
int tracker_finish_count_ = 0;
int timer_cout_ = 0;
int fire_count_;
std::chrono::steady_clock::time_point last_rune_target_time_;
bool debug_mode_ = false;
AutoBuffDebug auto_buff_debug_;
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
TFConfig::Ptr tf_config_;
std::pair<cv::Mat, cv::Mat> camera_info_;
wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter_;
Eigen::Matrix4d T_camera_to_odom_;
std::mutex target_mutex_;
std::mutex dbg_mutex_;
};
AutoBuff::AutoBuff(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
):
_impl(std::make_unique<Impl>(config_path, tf_config, camera_info, debug)) {}
AutoBuff::~AutoBuff() {
_impl.reset();
}
void AutoBuff::start() {
_impl->start();
}
void AutoBuff::pushInput(CommonFrame& frame) {
_impl->pushInput(frame);
}
GimbalCmd AutoBuff::solve(double bullet_speed) {
return _impl->solve(bullet_speed);
}
wust_vl::common::concurrency::MonitoredThread::Ptr AutoBuff::getThread() {
return _impl->processing_thread_;
}
void AutoBuff::doDebug() {
_impl->doDebug();
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,38 @@
#pragma once
#include "tasks/imodule.hpp"
#include "tasks/type_common.hpp"
#include "wust_vl/video/camera.hpp"
namespace wust_vision {
namespace auto_buff {
class AutoBuff: public IModule {
public:
using Ptr = std::shared_ptr<AutoBuff>;
AutoBuff(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
);
static Ptr create(
const std::string& config_path,
TFConfig::Ptr tf_config,
const std::pair<cv::Mat, cv::Mat>& camera_info,
bool debug
) {
return std::make_shared<AutoBuff>(config_path, tf_config, camera_info, debug);
}
~AutoBuff();
void start() override;
void doDebug() override;
void pushInput(CommonFrame& frame) override;
GimbalCmd solve(double bullet_speed) override;
wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override;
struct Impl;
std::unique_ptr<Impl> _impl;
};
inline AutoBuff::Ptr toAutoBuff(IModule::Ptr module) {
return std::dynamic_pointer_cast<AutoBuff>(module);
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,241 @@
#include "debug.hpp"
#include "tasks/auto_buff/auto_buff.hpp"
namespace wust_vision::auto_buff {
void drawDebugRuneContent(
cv::Mat& debug_img,
const AutoBuffDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info
) {
const auto& gimbal_cmd = dbg.gimbal_cmd;
double predict_angle = dbg.predict_angle;
auto aim_target = dbg.aim_target;
auto auto_buff = dbg.power_rune;
const cv::Rect img_rect(0, 0, debug_img.cols, debug_img.rows);
const cv::Rect roi = dbg.expanded & img_rect;
cv::rectangle(debug_img, roi, cv::Scalar(255, 255, 255), 2);
const std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms);
cv::putText(
debug_img,
latency_str,
cv::Point(10, 30),
cv::FONT_HERSHEY_SIMPLEX,
0.8,
cv::Scalar(255, 255, 255),
2
);
aim_target.tf(dbg.T_camera_to_odom.inverse());
{
const auto pts = aim_target.toPts(camera_info.first, camera_info.second);
if (!pts.empty()) {
cv::Scalar color = cv::Scalar(255, 255, 255);
for (int i = 0; i < 4; i++)
cv::line(debug_img, pts[i], pts[(i + 1) % 4], color, 2);
// 后表面
for (int i = 4; i < 8; i++)
cv::line(debug_img, pts[i], pts[4 + (i + 1) % 4], color, 2);
// 侧边
for (int i = 0; i < 4; i++)
cv::line(debug_img, pts[i], pts[i + 4], color, 2);
cv::Point2f center(0.f, 0.f);
for (auto pt: pts) {
center += pt;
}
center *= 1.0 / pts.size();
if (gimbal_cmd.fire_advice) {
int cross_len = 60;
cv::line(
debug_img,
center + cv::Point2f(-cross_len, -cross_len),
center + cv::Point2f(+cross_len, +cross_len),
cv::Scalar(0, 0, 255),
5
);
cv::line(
debug_img,
center + cv::Point2f(-cross_len, +cross_len),
center + cv::Point2f(+cross_len, -cross_len),
cv::Scalar(0, 0, 255),
5
);
}
const double scale = 10.0;
const double v_yaw = gimbal_cmd.v_yaw;
const double v_pitch = gimbal_cmd.v_pitch;
const double dx = -scale * v_yaw;
const double dy = scale * v_pitch;
const cv::Point2f start_pt = center;
const cv::Point2f end_pt = start_pt + cv::Point2f(dx, dy);
const cv::Scalar color_x = cv::Scalar(50, 50, 255);
cv::arrowedLine(debug_img, start_pt, end_pt, color_x, 4, cv::LINE_AA, 0, 0.2);
}
}
if (gimbal_cmd.fire_advice) {
const std::string fire_str = "Fire!";
cv::putText(
debug_img,
fire_str,
{ debug_img.cols / 2 - 100, 200 },
cv::FONT_HERSHEY_SIMPLEX,
2.85,
cv::Scalar(0, 0, 255),
2
);
}
const std::string gimbal_str = fmt::format(
"Pitch: {:.2f}, Yaw: {:.2f}, Enable_pitch_diff: {:.2f}, Enable_yaw_diff: {:.2f}, V_yaw: {:.2f}, V_pitch: {:.2f}",
gimbal_cmd.pitch,
gimbal_cmd.yaw,
gimbal_cmd.enable_pitch_diff,
gimbal_cmd.enable_yaw_diff,
gimbal_cmd.v_yaw,
gimbal_cmd.v_pitch
);
cv::putText(
debug_img,
gimbal_str,
{ 10, debug_img.rows - 30 },
cv::FONT_HERSHEY_SIMPLEX,
0.8,
cv::Scalar(255, 255, 0),
2
);
auto_buff.tf(dbg.T_camera_to_odom.inverse());
auto_buff.draw(debug_img, camera_info.first, camera_info.second);
cv::circle(
debug_img,
cv::Point2i(debug_img.cols / 2, debug_img.rows / 2),
5,
cv::Scalar(255, 255, 255),
2
);
}
void writeTargetLogToJson(const auto_buff::RuneTarget& rune_target) {
nlohmann::json j;
const auto now = std::chrono::steady_clock::now();
nlohmann::json jr;
jr["tracking"] = true;
jr["id"] = static_cast<int>(rune_target.last_id);
const auto age_ms_r =
std::chrono::duration_cast<std::chrono::milliseconds>(now - rune_target.timestamp_).count();
jr["timestamp_age_ms"] = age_ms_r;
jr["position"] = { { "x", rune_target.centerPos().x() },
{ "y", rune_target.centerPos().y() },
{ "z", rune_target.centerPos().z() } };
jr["roll"] = rune_target.roll() * 180.0 / M_PI;
jr["yaw"] = rune_target.yaw() * 180.0 / M_PI;
jr["v_roll"] = rune_target.v_roll() * 180.0 / M_PI;
j["rune_target"] = jr;
// -------- 写文件 --------
std::ofstream file("/dev/shm/target_log.json");
if (file.is_open()) {
file << j.dump(2);
}
}
struct DebugLogs {
#define DEBUG_LOG_LIST(X) \
X(double, 100, time) \
X(double, 100, raw_yaw) \
X(double, 100, raw_pitch) \
X(double, 100, yaw) \
X(double, 100, pitch) \
X(double, 100, ypd_y) \
X(double, 100, ypd_p) \
X(double, 100, rune_obs) \
X(double, 100, rune_pre) \
X(double, 100, rune_obsv) \
X(double, 100, rune_fitv) \
X(double, 100, gimbal_yaw) \
X(double, 100, gimbal_pitch) \
X(double, 100, target_v_yaw) \
X(double, 100, control_v_yaw) \
X(double, 100, control_v_pitch) \
X(double, 100, yaw_diff) \
X(double, 100, fire) \
X(double, 100, rune_dis) \
X(double, 100, fly_time) \
X(double, 100, control_a_yaw) \
X(double, 100, control_a_pitch)
#define GEN_LOG(TYPE, SIZE, NAME) LogsStream<TYPE, SIZE> NAME##_log { #NAME };
#define X(TYPE, SIZE, NAME) GEN_LOG(TYPE, SIZE, NAME)
DEBUG_LOG_LIST(X)
#undef X
void clear() {
#define X(TYPE, SIZE, NAME) NAME##_log.clear();
DEBUG_LOG_LIST(X)
#undef X
}
};
void debuglog(const AutoBuffDebug& dbg_rune) {
static bool first_log = true;
static std::chrono::steady_clock::time_point start_time;
static DebugLogs log;
static GimbalCmd last_cmd_;
static double rune_dis = 0.0;
if (first_log) {
start_time = std::chrono::steady_clock::now();
first_log = false;
}
const auto now = std::chrono::steady_clock::now();
const double t = std::chrono::duration<double>(now - start_time).count();
const auto_buff::RuneTarget& rune_target = dbg_rune.target;
writeTargetLogToJson(rune_target);
double armor_yaw = 0.0, ypd_y = 0.0, ypd_p = 0.0, armor_distance = 0.0;
if (dbg_rune.pnp_distance > 1.0) {
rune_dis = dbg_rune.pnp_distance;
}
GimbalCmd i_use;
if (dbg_rune.gimbal_cmd.appear) {
i_use = dbg_rune.gimbal_cmd;
} else {
i_use = last_cmd_;
}
last_cmd_ = i_use;
nlohmann::json j;
log.time_log.handleOnce(t, j);
log.raw_yaw_log.handleOnce(i_use.target_yaw, j);
log.raw_pitch_log.handleOnce(i_use.target_pitch, j);
log.yaw_log.handleOnce(i_use.yaw, j);
log.pitch_log.handleOnce(i_use.pitch, j);
log.rune_obs_log.handleOnce(dbg_rune.obs_angle, j);
log.rune_pre_log.handleOnce(dbg_rune.pre_angle, j);
log.rune_fitv_log.handleOnce(dbg_rune.fitter_v * 180.0 / M_PI, j);
log.rune_obsv_log.handleOnce(dbg_rune.obs_v * 180.0 / M_PI, j);
log.gimbal_pitch_log.handleOnce(dbg_rune.gimbal_py.first * 180.0 / M_PI, j);
log.gimbal_yaw_log.handleOnce(dbg_rune.gimbal_py.second * 180.0 / M_PI, j);
log.control_v_pitch_log.handleOnce(i_use.v_pitch, j);
log.control_v_yaw_log.handleOnce(i_use.v_yaw, j);
log.fire_log.handleOnce(i_use.fire_advice, j);
log.rune_dis_log.handleOnce(rune_dis, j);
log.fly_time_log.handleOnce(i_use.fly_time, j);
log.control_a_yaw_log.handleOnce(i_use.a_yaw / 180.0 * M_PI, j);
log.control_a_pitch_log.handleOnce(i_use.a_pitch / 180.0 * M_PI, j);
std::ofstream file("/dev/shm/cmd_log.json");
if (file.is_open()) {
file << j.dump();
}
}
} // namespace wust_vision::auto_buff

View File

@@ -0,0 +1,39 @@
#pragma once
#include "tasks/auto_buff/auto_buff.hpp"
#include "tasks/auto_buff/rune_tracker/rune_target.hpp"
#include "tasks/utils/debug_utils.hpp"
#include <wust_vl/video/icamera.hpp>
namespace wust_vision::auto_buff {
struct AutoBuffDebug {
wust_vl::video::ImageFrame img_frame;
auto_buff::RuneTarget target;
AimTarget aim_target;
auto_buff::PowerRune power_rune;
double predict_angle;
GimbalCmd gimbal_cmd;
Eigen::Matrix4d T_camera_to_odom;
double latency_ms;
double obs_angle;
double pre_angle;
double fitter_v;
double obs_v;
cv::Rect expanded;
double pnp_distance;
std::pair<double, double> gimbal_py;
};
void drawDebugRuneContent(
cv::Mat& debug_img,
const AutoBuffDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info
);
void writeTargetLogToJson(const auto_buff::RuneTarget& rune_target);
inline void drawDebugOverlayShm(
const AutoBuffDebug& dbg,
std::pair<cv::Mat, cv::Mat> camera_info,
bool auto_fps
) {
static ShmWriter shm { "/debug_frame" };
drawDebugOverlayImpl(dbg, camera_info, auto_fps, drawDebugRuneContent, shm);
}
void debuglog(const AutoBuffDebug& dbg);
} // namespace wust_vision::auto_buff

View File

@@ -0,0 +1,237 @@
#include "aimer.hpp"
#include "tasks/auto_buff/rune_tracker/rune_target.hpp"
#include "wust_vl/common/utils/manual_compensator.hpp"
namespace wust_vision {
namespace auto_buff {
struct Aimer::Impl {
public:
struct AimerConfig: wust_vl::common::utils::ParamGroup {
static constexpr const char* Logger = "Config: auto_buff::aimer";
static constexpr const char* kKey = "aimer";
const char* key() const override {
return kKey;
}
using Ptr = std::shared_ptr<AimerConfig>;
AimerConfig() {}
static Ptr create() {
return std::make_shared<AimerConfig>();
}
std::shared_ptr<wust_vl::common::utils::ManualCompensator> manual_compensator;
GEN_PARAM(double, prediction_delay);
GEN_PARAM(double, shooting_range_h);
GEN_PARAM(double, shooting_range_w);
GEN_PARAM(double, min_enable_pitch_deg);
GEN_PARAM(double, min_enable_yaw_deg);
bool first_load = false;
using OffsetEntry = wust_vl::common::utils::OffsetEntry;
static std::vector<OffsetEntry>
parseTrajectoryOffset(const YAML::Node& node, double& base_pitch, double& base_yaw) {
std::vector<OffsetEntry> entries;
if (!node || !node["trajectory_offset"]) {
return entries;
}
for (const auto& n: node["trajectory_offset"]) {
entries.push_back(OffsetEntry { .d_min = n["d_min"].as<double>(),
.d_max = n["d_max"].as<double>(),
.h_min = n["h_min"].as<double>(),
.h_max = n["h_max"].as<double>(),
.pitch_off = n["pitch_off"].as<double>(),
.yaw_off = n["yaw_off"].as<double>() });
}
base_pitch = node["base_offset"]["pitch"].as<double>();
base_yaw = node["base_offset"]["yaw"].as<double>();
return entries;
}
std::vector<OffsetEntry> last_entries_;
double last_base_pitch_ = 0.0;
double last_base_yaw_ = 0.0;
bool first_load_ = true;
void loadSelf(const YAML::Node& node) override {
double base_pitch = 0.0;
double base_yaw = 0.0;
auto entries = parseTrajectoryOffset(node, base_pitch, base_yaw);
const bool trajectory_changed = first_load_ || entries != last_entries_
|| base_pitch != last_base_pitch_ || base_yaw != last_base_yaw_;
if (trajectory_changed) {
manual_compensator =
std::make_shared<wust_vl::common::utils::ManualCompensator>();
manual_compensator->setBasePitch(base_pitch);
manual_compensator->setBaseYaw(base_yaw);
if (!manual_compensator->updateMapFlow(entries) || entries.empty()) {
std::cout << "manual_compensator init failed" << std::endl;
}
last_entries_ = entries;
last_base_pitch_ = base_pitch;
last_base_yaw_ = base_yaw;
first_load_ = false;
}
shooting_range_h_param.load(node);
shooting_range_w_param.load(node);
min_enable_pitch_deg_param.load(node);
min_enable_yaw_deg_param.load(node);
prediction_delay_param.load(node);
}
};
Impl(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) {
aimer_config_ = AimerConfig::create();
trajectory_compensator_config_ = TrajectoryCompensatorConfig::create();
auto_buff_config_parameter->registerGroup(*aimer_config_);
auto_buff_config_parameter->registerGroup(*trajectory_compensator_config_);
auto_buff_config_parameter->reloadFromOldPath();
}
std::tuple<double, double> calEnableDiff(Eigen::Vector3d aim_target_pos) const noexcept {
const double distance = aim_target_pos.norm();
double shooting_range_yaw =
std::abs(atan2(aimer_config_->shooting_range_w_param.get() / 2, distance));
double shooting_range_pitch =
std::abs(atan2(aimer_config_->shooting_range_h_param.get() / 2, distance));
constexpr double yaw_factor = 1.0;
constexpr double pitch_factor = 1.0;
shooting_range_yaw = std::max(
shooting_range_yaw,
aimer_config_->min_enable_yaw_deg_param.get() * M_PI / 180
);
shooting_range_pitch = std::max(
shooting_range_pitch,
aimer_config_->min_enable_pitch_deg_param.get() * M_PI / 180
);
shooting_range_yaw *= yaw_factor;
shooting_range_pitch *= pitch_factor;
return std::make_tuple(std::abs(shooting_range_yaw), std::abs(shooting_range_pitch));
}
struct ControlPoint {
double yaw;
double pitch;
Eigen::Vector3d aim_pos;
};
ControlPoint getControlPoint(RuneTarget target, double bullet_speed) const noexcept {
auto [aim_target_pos, _] = target.getHitPoint();
ControlPoint cp;
double control_yaw = std::atan2(aim_target_pos.y(), aim_target_pos.x());
double raw_pitch = std::atan2(
aim_target_pos.z(),
std::sqrt(
aim_target_pos.x() * aim_target_pos.x()
+ aim_target_pos.y() * aim_target_pos.y()
)
);
try {
trajectory_compensator_config_->trajectory_compensator
->compensate(aim_target_pos, raw_pitch, bullet_speed);
} catch (std::exception& e) {
std::cout << "compensate error: " << e.what() << std::endl;
}
double control_pitch = raw_pitch;
const auto offs = aimer_config_->manual_compensator->angleHardCorrect(
aim_target_pos.head(2).norm(),
aim_target_pos.z()
);
control_yaw = angles::normalize_angle((control_yaw + offs[1] * M_PI / 180.0));
control_pitch = (control_pitch + offs[0] * M_PI / 180.0);
cp.pitch = control_pitch;
cp.yaw = control_yaw;
cp.aim_pos = aim_target_pos;
return cp;
}
GimbalCmd aim(RuneTarget target, double bullet_speed) {
GimbalCmd cmd;
// 当前时间
const auto now = wust_vl::common::utils::time_utils::now();
// 首次预测目标位置
target.predictWithFitter(now);
auto [p0, q0] = target.getHitPoint();
// 迭代计算飞行时间
bool converged = false;
double prev_fly_time = trajectory_compensator_config_->trajectory_compensator
->getFlyingTime(p0, bullet_speed);
std::vector<RuneTarget> iteration_target(10, target);
for (int iter = 0; iter < 10; ++iter) {
iteration_target[iter].predictWithFitter(prev_fly_time);
auto [pb, qb] = iteration_target[iter].getHitPoint();
double iter_fly_time = trajectory_compensator_config_->trajectory_compensator
->getFlyingTime(pb, bullet_speed);
if (std::abs(iter_fly_time - prev_fly_time) < 0.001) {
converged = true;
break;
}
prev_fly_time = iter_fly_time;
}
const double predict_time = prev_fly_time + aimer_config_->prediction_delay_param.get();
target.predictWithFitter(predict_time);
const auto cp = getControlPoint(target, bullet_speed);
// RuneTarget target_prev, target_next = target;
RuneTarget target_prev = target;
RuneTarget target_next = target;
const double dt = 0.01;
target_prev.predictWithFitter(-dt);
auto cp_prev = getControlPoint(target_prev, bullet_speed);
target_next.predictWithFitter(dt);
auto cp_next = getControlPoint(target_next, bullet_speed);
double yaw_speed = (cp_next.yaw - cp_prev.yaw) / (2.0 * dt);
double pitch_speed = (cp_next.pitch - cp_prev.pitch) / (2.0 * dt);
double yaw_acc = (cp_next.yaw - 2.0 * cp.yaw + cp_prev.yaw) / (dt * dt);
double pitch_acc = (cp_next.pitch - 2.0 * cp.pitch + cp_prev.pitch) / (dt * dt);
AimTarget aim_target;
aim_target.pos = cp.aim_pos;
// 填充 GimbalCmd
cmd.distance = cp.aim_pos.norm();
cmd.aim_target = aim_target;
cmd.yaw = cp.yaw * 180.0 / M_PI;
cmd.pitch = cp.pitch * 180.0 / M_PI;
// cmd.v_yaw = yaw_speed * 180.0 / M_PI; // 转为度/秒
// cmd.v_pitch = pitch_speed * 180.0 / M_PI;
// cmd.a_yaw = yaw_acc * 180.0 / M_PI; // 转为度/秒²
// cmd.a_pitch = pitch_acc * 180.0 / M_PI;
cmd.v_yaw = 0.0; // 转为度/秒
cmd.v_pitch = 0.0;
cmd.a_yaw = 0.0; // 转为度/秒²
cmd.a_pitch = 0.0;
const auto [enable_yaw, enable_pitch] = calEnableDiff(cp.aim_pos);
cmd.fire_advice = true;
cmd.enable_yaw_diff = enable_yaw;
cmd.enable_pitch_diff = enable_pitch;
cmd.target_yaw = cp.yaw * 180.0 / M_PI;
cmd.target_pitch = cp.pitch * 180.0 / M_PI;
cmd.fly_time = prev_fly_time;
cmd.appear = true;
return cmd;
}
TrajectoryCompensatorConfig::Ptr trajectory_compensator_config_;
AimerConfig::Ptr aimer_config_;
};
Aimer::Aimer(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) {
_impl = std::make_unique<Impl>(auto_buff_config_parameter);
}
Aimer::~Aimer() {
_impl.reset();
}
GimbalCmd Aimer::aim(const auto_buff::RuneTarget& target, double bullet_speed) {
return _impl->aim(target, bullet_speed);
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,22 @@
#pragma once
#include "tasks/auto_buff/rune_tracker/rune_target.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_buff {
class Aimer {
public:
using Ptr = std::unique_ptr<Aimer>;
Aimer(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter);
static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) {
return std::make_unique<Aimer>(auto_buff_config_parameter);
}
~Aimer();
GimbalCmd aim(const auto_buff::RuneTarget& target, double bullet_speed);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,497 @@
#include "rune_detector.hpp"
#include "tasks/utils/utils.hpp"
#include <opencv2/highgui.hpp>
namespace wust_vision {
namespace auto_buff {
struct RuneDetectorCV::Impl {
public:
Impl(const YAML::Node& node) {
params_.load(node);
}
void setCallback(DetectorCallback callback) {
callback_ = callback;
}
cv::Mat preProcess(const cv::Mat& src, bool use_red = false) {
// cv::Mat bin;
// cv::cvtColor(src, bin, cv::COLOR_RGB2GRAY);
// cv::threshold(bin, bin, params_.bin_threshold, 255, cv::THRESH_BINARY);
std::vector<cv::Mat> channels;
cv::split(src, channels); // BGR
cv::Mat diff;
if (use_red) {
cv::subtract(channels[2], channels[0], diff); // R - B
} else {
cv::subtract(channels[0], channels[2], diff); // B - R
}
cv::Mat bin;
cv::threshold(diff, bin, params_.bin_threshold, 255, cv::THRESH_BINARY);
// cv::imshow("bin", bin);
// cv::waitKey(1);
return bin;
}
inline auto_buff::RuneCenter getRuneCenter(
const std::vector<std::vector<cv::Point>>& contours,
const std::vector<cv::Vec4i>& hierarchy,
cv::Size image_size,
cv::Point2f offset,
cv::Mat& debug_img,
std::vector<bool>& used_flags
) {
auto_buff::RuneCenter result;
struct Node {
cv::Point2f center;
int idx;
cv::RotatedRect rr;
};
std::vector<Node> nodes;
for (int i = 0; i < contours.size(); i++) {
if (used_flags[i])
continue;
if (hierarchy[i][3] != -1)
continue;
double area = cv::contourArea(contours[i]);
if (area < params_.rune_center_min_area || area > params_.rune_center_max_area)
continue;
cv::RotatedRect rr = cv::minAreaRect(contours[i]);
float w = rr.size.width;
float h = rr.size.height;
if (w < 5 || h < 5)
continue;
double ratio = (w > h ? w / h : h / w);
if (ratio - 1.0 > params_.rune_center_1x1ratio_tol)
continue;
double rect_area = w * h;
if (rect_area <= 1e-5)
continue;
double fill_ratio = area / rect_area;
if (fill_ratio < params_.rune_center_fill_ratio_min)
continue;
nodes.push_back({ rr.center, i, rr });
if (!debug_img.empty()) {
cv::Point2f pts[4];
rr.points(pts);
for (size_t k = 0; k < 4; k++) {
pts[k] += offset;
}
for (int k = 0; k < 4; k++) {
cv::line(debug_img, pts[k], pts[(k + 1) % 4], cv::Scalar(0, 255, 0), 2);
}
}
}
if (nodes.empty())
return result;
cv::Point2f img_center(image_size.width * 0.5f, image_size.height * 0.5f);
double best_dist = 1e18;
int best_idx = -1;
cv::RotatedRect best_rr;
for (auto& n: nodes) {
double dx = n.center.x - img_center.x;
double dy = n.center.y - img_center.y;
double dist2 = dx * dx + dy * dy;
if (dist2 < best_dist) {
best_dist = dist2;
best_idx = n.idx;
best_rr = n.rr;
}
}
if (!debug_img.empty()) {
cv::circle(
debug_img,
img_center + offset,
5,
cv::Scalar(0, 255, 255),
-1
); // 图像中心
cv::Point2f pts[4];
best_rr.points(pts);
for (size_t k = 0; k < 4; k++) {
pts[k] += offset;
}
for (int k = 0; k < 4; k++) {
cv::line(debug_img, pts[k], pts[(k + 1) % 4], cv::Scalar(0, 0, 255), 2);
}
}
return auto_buff::RuneCenter(best_rr);
}
inline int findTopParent(int idx, const std::vector<cv::Vec4i>& hierarchy) {
int p = hierarchy[idx][3]; // parent
while (p != -1 && hierarchy[p][3] != -1) {
p = hierarchy[p][3]; // 一直追溯到最顶层 parent
}
return p; // 若 p == -1 表示 contour 本身就是顶层轮廓
}
inline std::vector<auto_buff::RunePan> markRuneTarget(
const std::vector<std::vector<cv::Point>>& contours,
const std::vector<cv::Vec4i>& hierarchy,
std::vector<bool>& used_flags
) {
std::vector<auto_buff::RunePan> results;
if (hierarchy.empty())
return results;
struct Node {
int idx;
cv::Point2f center;
int parent_top_id;
};
std::vector<Node> candidates;
for (int i = 0; i < contours.size(); i++) {
if (used_flags[i])
continue;
const auto& cnt = contours[i];
double contour_area = cv::contourArea(cnt);
if (contour_area < params_.rune_target_min_area
|| contour_area > params_.rune_target_max_area)
continue;
cv::Moments m = cv::moments(cnt);
if (m.m00 == 0)
continue;
cv::Point2f center(m.m10 / m.m00, m.m01 / m.m00);
int top_parent = findTopParent(i, hierarchy);
candidates.push_back({ i, center, top_parent });
}
if (candidates.size() < 3)
return results;
std::unordered_map<int, std::vector<int>> groups;
for (int i = 0; i < candidates.size(); i++) {
groups[candidates[i].parent_top_id].push_back(i);
}
for (auto& [parent_top_id, idx_list]: groups) {
int M = idx_list.size();
if (M < 3 || M > 7)
continue;
std::vector<int> cluster_id(M, -1);
int cluster_count = 0;
for (int i = 0; i < M; i++) {
if (cluster_id[i] != -1)
continue;
cluster_id[i] = cluster_count;
std::queue<int> q;
q.push(i);
while (!q.empty()) {
int u = q.front();
q.pop();
for (int v = 0; v < M; v++) {
if (cluster_id[v] != -1)
continue;
auto& cu = candidates[idx_list[u]].center;
auto& cv = candidates[idx_list[v]].center;
double dx = cu.x - cv.x;
double dy = cu.y - cv.y;
double dist = std::sqrt(dx * dx + dy * dy);
if (dist <= params_.rune_target_cluster_radius) {
cluster_id[v] = cluster_count;
q.push(v);
}
}
}
cluster_count++;
}
std::vector<int> cluster_size(cluster_count, 0);
for (int id: cluster_id)
cluster_size[id]++;
std::vector<std::vector<cv::Point2f>> cluster_points(cluster_count);
for (int i = 0; i < M; i++) {
int cid = cluster_id[i];
if (cluster_size[cid] >= 3) {
int contour_index = candidates[idx_list[i]].idx;
used_flags[contour_index] = true;
cluster_points[cid].push_back(candidates[idx_list[i]].center);
}
}
for (int cid = 0; cid < cluster_count; cid++) {
if (cluster_points[cid].size() < 3)
continue;
cv::RotatedRect rr = cv::minAreaRect(cluster_points[cid]);
double w = rr.size.width;
double h = rr.size.height;
if (w < 1 || h < 1)
continue;
double ratio = (w > h ? w / h : h / w);
if (ratio > params_.rune_target_max_square_ratio)
continue;
std::vector<std::pair<double, cv::Point2f>> dist_list;
dist_list.reserve(cluster_points[cid].size());
for (auto& p: cluster_points[cid]) {
double dx = p.x - rr.center.x;
double dy = p.y - rr.center.y;
double dist = dx * dx + dy * dy;
dist_list.emplace_back(dist, p);
}
std::sort(dist_list.begin(), dist_list.end(), [](auto& a, auto& b) {
return a.first > b.first;
});
std::vector<cv::Point2f> corner_points;
for (int i = 0; i < 4 && i < dist_list.size(); i++)
corner_points.push_back(dist_list[i].second);
auto_buff::RunePan pan;
pan.center = rr.center;
pan.corners = corner_points;
if (corner_points.size() > 3)
pan.is_valid = true;
results.push_back(pan);
}
}
return results;
}
inline void markInvalidContours(
cv::Mat& color,
cv::Mat& debug_img,
const std::vector<std::vector<cv::Point>>& contours,
std::vector<bool>& used_flags,
const cv::Rect& valid_rect,
bool filter_red,
double diff_thresh
) {
used_flags.assign(contours.size(), false);
for (int i = 0; i < contours.size(); i++) {
cv::Rect r = cv::boundingRect(contours[i]);
if (r.width < 5 || r.height < 5)
continue;
cv::Rect rr = r & cv::Rect(0, 0, color.cols, color.rows);
if (rr.width < 2 || rr.height < 2)
continue;
const cv::Mat roi = color(rr);
const cv::Scalar avg = cv::mean(roi);
const double B = avg[0], G = avg[1], R = avg[2];
const double diff_RB = R - B;
const double diff_BR = B - R;
const bool is_red = (diff_RB > diff_thresh);
const bool is_blue = (diff_BR > diff_thresh);
bool invalid = false;
if (filter_red) {
if (is_red)
invalid = true;
} else {
if (is_blue)
invalid = true;
}
cv::Rect inter = r & valid_rect;
bool inside_region = (inter.area() > 0);
used_flags[i] = !invalid || !inside_region;
if (!used_flags[i]) {
if (!debug_img.empty())
cv::drawContours(debug_img, contours, i, cv::Scalar(255, 0, 0), 2);
}
}
}
static bool isUpscaled(const cv::Rect& roi, int model_w, int model_h) {
float scale = std::min(model_w / float(roi.width), model_h / float(roi.height));
return scale > 1.0f;
}
void pushInput(CommonFrame& frame, bool debug) {
frame.id = current_id_++;
auto_buff::RuneFan fan {
.is_valid = false,
.id = frame.id,
.timestamp = frame.img_frame.timestamp,
};
cv::Mat debug_img;
if (debug) {
debug_img = frame.img_frame.src_img.clone();
}
cv::Mat roi = frame.img_frame.src_img(frame.expanded);
cv::Mat processed_img = preProcess(roi, frame.detect_color);
std::vector<std::vector<cv::Point>> contours;
std::vector<cv::Vec4i> hierarchy;
cv::findContours(
processed_img,
contours,
hierarchy,
cv::RETR_TREE,
cv::CHAIN_APPROX_SIMPLE
);
std::vector<bool> used_flags;
used_flags.assign(contours.size(), false);
markInvalidContours(
roi,
debug_img,
contours,
used_flags,
cv::Rect(0, 0, roi.cols, roi.rows),
frame.detect_color,
params_.color_diff_threshold
);
auto rune_center =
getRuneCenter(contours, hierarchy, roi.size(), frame.offset, debug_img, used_flags);
std::vector<auto_buff::RunePan> rune_pans =
markRuneTarget(contours, hierarchy, used_flags);
for (auto& rune_pan: rune_pans) {
if (rune_center.is_valid) {
rune_pan.addReferRuneCenter(rune_center);
}
if (rune_pan.is_valid && rune_pan.has_refer) {
auto_buff::RuneFan::Simple simple;
simple.points2d.push_back(rune_center.center);
for (auto& pt: rune_pan.corners) {
simple.points2d.push_back(pt);
}
simple.points2d.push_back(rune_pan.center);
fan.fans.push_back(simple);
}
if (!debug_img.empty())
rune_pan.draw(debug_img, frame.offset);
}
auto_buff::RuneFan tmp = fan;
for (int i = 0; i < tmp.fans.size(); i++) {
for (int j = 0; j < tmp.fans.size(); j++) {
if (i == j)
continue;
fan.fans[i].addOther(tmp.fans[j]);
}
}
fan.addOffset(frame.offset);
if (callback_) {
callback_(fan, frame, debug_img);
}
}
DetectorCallback callback_;
cv::Mat tmp_R_;
int current_id_ = 0;
struct Params {
double rune_center_min_area = 100.0;
double rune_center_max_area = 2000.0;
double rune_center_1x1ratio_tol = 0.7;
double rune_center_fill_ratio_min = 0.7;
double rune_target_min_area = 100.0;
double rune_target_max_area = 3000.0;
double rune_target_max_square_ratio = 1.3;
double rune_target_cluster_radius = 70.0;
double bin_threshold = 150.0;
double color_diff_threshold = 40.0;
int target_width = 416;
int target_height = 416;
void load(const YAML::Node& node) {
// center params
rune_center_min_area = node["rune_center_min_area"]
? node["rune_center_min_area"].as<double>()
: rune_center_min_area;
rune_center_max_area = node["rune_center_max_area"]
? node["rune_center_max_area"].as<double>()
: rune_center_max_area;
rune_center_1x1ratio_tol = node["rune_center_1x1ratio_tol"]
? node["rune_center_1x1ratio_tol"].as<double>()
: rune_center_1x1ratio_tol;
rune_center_fill_ratio_min = node["rune_center_fill_ratio_min"]
? node["rune_center_fill_ratio_min"].as<double>()
: rune_center_fill_ratio_min;
// target params
rune_target_min_area = node["rune_target_min_area"]
? node["rune_target_min_area"].as<double>()
: rune_target_min_area;
rune_target_max_area = node["rune_target_max_area"]
? node["rune_target_max_area"].as<double>()
: rune_target_max_area;
rune_target_max_square_ratio = node["rune_target_max_square_ratio"]
? node["rune_target_max_square_ratio"].as<double>()
: rune_target_max_square_ratio;
rune_target_cluster_radius = node["rune_target_cluster_radius"]
? node["rune_target_cluster_radius"].as<double>()
: rune_target_cluster_radius;
bin_threshold =
node["bin_threshold"] ? node["bin_threshold"].as<double>() : bin_threshold;
color_diff_threshold = node["color_diff_threshold"]
? node["color_diff_threshold"].as<double>()
: color_diff_threshold;
target_width = node["target_width"] ? node["target_width"].as<int>() : target_width;
target_height =
node["target_height"] ? node["target_height"].as<int>() : target_height;
}
} params_;
};
RuneDetectorCV::RuneDetectorCV(const YAML::Node& node) {
_impl = std::make_unique<Impl>(node);
}
RuneDetectorCV::~RuneDetectorCV() {
_impl.reset();
}
void RuneDetectorCV::pushInput(CommonFrame& frame, bool debug) {
_impl->pushInput(frame, debug);
}
void RuneDetectorCV::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,24 @@
#pragma once
#include "tasks/auto_buff/type.hpp"
namespace wust_vision {
namespace auto_buff {
class RuneDetectorCV {
public:
using DetectorCallback =
std::function<void(const auto_buff::RuneFan&, const CommonFrame&, cv::Mat&)>;
using Ptr = std::unique_ptr<RuneDetectorCV>;
RuneDetectorCV(const YAML::Node& node);
static inline std::unique_ptr<RuneDetectorCV> make_detector(const YAML::Node& node) {
return std::make_unique<RuneDetectorCV>(node);
}
~RuneDetectorCV();
void pushInput(CommonFrame& frame, bool debug = false);
void setCallback(DetectorCallback callback);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,71 @@
// Copyright 2025 Xiaojian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ceres/ceres.h>
#include "KalmanHyLib/kalman_hybird_lib.hpp"
namespace ypdrune_motion_model {
constexpr int X_N = 6, Z_N = 5;
using VecZ = Eigen::Matrix<double, Z_N, 1>;
using VecX = Eigen::Matrix<double, X_N, 1>;
enum class Meas : uint8_t { YPD_Y = 0, YPD_P = 1, YPD_D = 2, ORI_YAW = 3, ORI_ROLL = 4, Z_N = 5 };
enum class State : uint8_t { CX = 0, CY = 1, CZ = 2, YAW = 3, ROLL = 4, VROLL = 5, X_N = 6 };
struct Predict {
Predict() = default;
explicit Predict(double dt): dt(dt) {}
template<typename T>
void operator()(const T x0[X_N], T x1[X_N]) const {
for (int i = 0; i < X_N; ++i) {
x1[i] = x0[i];
}
x1[(int)State::ROLL] += x0[(int)State::VROLL] * dt;
}
double dt;
};
template<typename T>
T normalize_angle_t(T angle) {
T two_pi = T(2.0 * M_PI);
return angle - two_pi * floor((angle + T(M_PI)) / two_pi);
}
struct Measure {
Measure() = default;
explicit Measure(int id): id(id) {}
template<typename T>
void operator()(const T x[X_N], T z[Z_N]) const {
T xy_dist = ceres::sqrt(
x[(int)State::CX] * x[(int)State::CX] + x[(int)State::CY] * x[(int)State::CY]
);
T dist = ceres::sqrt(xy_dist * xy_dist + x[(int)State::CZ] * x[(int)State::CZ]);
// Observation model
z[(int)Meas::YPD_Y] = ceres::atan2(x[1], x[0]); // yaw
z[(int)Meas::YPD_P] = ceres::atan2(x[2], xy_dist); // pitch
z[(int)Meas::YPD_D] = dist; // distance
z[(int)Meas::ORI_YAW] = x[(int)State::YAW]; // orientation_yaw
z[(int)Meas::ORI_ROLL] = normalize_angle_t(x[(int)State::ROLL] + id * 2 * M_PI / 5); // roll
}
void h(const VecX& x, VecZ& z) const {
assert(x.size() == X_N);
assert(z.size() == Z_N);
operator()(x.data(), z.data());
}
int id = 0;
};
using RuneESKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
} // namespace ypdrune_motion_model

View File

@@ -0,0 +1,336 @@
#include "rune_target.hpp"
#include <iostream>
namespace wust_vision {
namespace auto_buff {
RuneTarget::RuneTarget(
const auto_buff::RuneFan& fan,
RuneTargetConfig::Ptr target_config,
double pre_v_roll
) {
is_big_ = false;
start_time_ = fan.timestamp;
target_config_ = target_config;
fitter_.setWindow(target_config_->big_window_sec_param.get());
auto f = MModel::Predict(0.005);
auto h = MModel::Measure(0);
auto u_q = [this]() {
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
return q;
};
auto u_r = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
return r;
};
Eigen::DiagonalMatrix<double, MModel::X_N> p0;
p0.setIdentity();
esekf_ypd_ = MModel::RuneESKF(f, h, u_q, u_r, p0);
esekf_ypd_.setResidualFunc([](const Eigen::VectorXd& z_pred, const Eigen::VectorXd& z) {
Eigen::VectorXd r = z - z_pred;
r[(int)MModel::Meas::YPD_Y] = angles::shortest_angular_distance(
z_pred[(int)MModel::Meas::YPD_Y],
z[(int)MModel::Meas::YPD_Y]
); // yaw
r[(int)MModel::Meas::ORI_YAW] = angles::shortest_angular_distance(
z_pred[(int)MModel::Meas::ORI_YAW],
z[(int)MModel::Meas::ORI_YAW]
); // ori_yaw
r[(int)MModel::Meas::ORI_ROLL] = angles::shortest_angular_distance(
z_pred[(int)MModel::Meas::ORI_ROLL],
z[(int)MModel::Meas::ORI_ROLL]
); // ori_roll
return r;
});
esekf_ypd_.setIterationNum(target_config_->esekf_iter_num_param.get());
esekf_ypd_.setInjectFunc([](const Eigen::Matrix<double, MModel::X_N, 1>& delta,
Eigen::Matrix<double, MModel::X_N, 1>& nominal) {
for (int i = 0; i < MModel::X_N; i++) {
if (i == (int)MModel::Meas::ORI_YAW || i == (int)MModel::Meas::ORI_ROLL)
continue;
nominal[i] += delta[i];
}
nominal[(int)MModel::Meas::ORI_YAW] = angles::normalize_angle(
nominal[(int)MModel::Meas::ORI_YAW] + delta[(int)MModel::Meas::ORI_YAW]
);
nominal[(int)MModel::Meas::ORI_ROLL] = angles::normalize_angle(
nominal[(int)MModel::Meas::ORI_ROLL] + delta[(int)MModel::Meas::ORI_ROLL]
);
});
double xc = fan.fans.front().target_pos.x();
double yc = fan.fans.front().target_pos.y();
double zc = fan.fans.front().target_pos.z();
double yaw = utils::orientationToYaw<RuneTarget>(fan.fans.front().target_ori);
double roll = utils::orientationToRoll<RuneTarget>(fan.fans.front().target_ori);
target_state_ = Eigen::VectorXd::Zero(MModel::X_N);
target_state_ << xc, yc, zc, yaw, roll, pre_v_roll;
esekf_ypd_.setState(target_state_);
fitter_.update(0, 0);
last_time_ = 0;
is_inited = true;
last_t_ = fan.timestamp;
timestamp_ = fan.timestamp;
}
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
RuneTarget::computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z) const {
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
// clang-format off
r << target_config_->yp_r_param.get() , 0 , 0 , 0 , 0,
0 , target_config_->yp_r_param.get() , 0 , 0 , 0,
0 , 0 , target_config_->dis_r_param.get() , 0 , 0,
0 , 0 , 0 , target_config_->yaw_r_param.get() , 0,
0 , 0 , 0 , 0 , target_config_->roll_r_param.get();
// clang-format on
return r;
}
Eigen::Matrix<double, MModel::X_N, MModel::X_N> RuneTarget::computeProcessNoise(double dt
) const {
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
double t = dt;
double v1 = target_config_->q_roll_param.get();
double q_roll_roll = pow(t, 4) / 4 * v1, q_roll_vroll = pow(t, 3) / 2 * v1,
q_vroll_vroll = pow(t, 2) * v1;
double q_xyz = target_config_->q_xyz_param.get();
double q_yaw = target_config_->q_yaw_param.get();
// clang-format off
// xc yc zc yaw roll v_roll
q << q_xyz, 0, 0, 0, 0, 0,
0, q_xyz, 0, 0, 0, 0,
0, 0, q_xyz, 0, 0, 0,
0, 0, 0, q_yaw, 0, 0,
0, 0, 0, 0, q_roll_roll, q_roll_vroll,
0, 0, 0, 0, q_roll_vroll, q_vroll_vroll;
// clang-format on
return q;
}
void RuneTarget::predict(std::chrono::steady_clock::time_point t) {
double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
predict(dt);
last_t_ = t;
}
void RuneTarget::predict(double dt) {
dt_ = dt;
esekf_ypd_.setPredictFunc(MModel::Predict { dt });
auto u_q = [dt, this]() { return computeProcessNoise(dt); };
esekf_ypd_.setUpdateQ(u_q);
target_state_ = esekf_ypd_.predict();
}
bool RuneTarget::update(const auto_buff::RuneFan& fans) {
timestamp_ = fans.timestamp;
if (fans.fans.empty()) {
return false;
}
update_ids.clear();
auto matched = match(fans.fans);
bool has_match = false;
for (auto [id, fan]: matched) {
measurement_ = getMeasure(fan);
update_ids.push_back(id);
auto yu_rv2 = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
return this->computeMeasurementCovariance(z);
};
esekf_ypd_.setUpdateR(yu_rv2);
esekf_ypd_.setMeasureFunc(MModel::Measure { id });
esekf_ypd_.update(measurement_);
if (!is_big_)
last_id = id;
has_match = true;
}
bool no_change = true;
for (auto id: update_ids) {
if (id != last_id)
no_change = false;
}
if (!no_change && update_ids.size() > 1)
last_id = update_ids[0];
// if (update_ids.size() > 1)
// is_big_ = true;
double tostart =
wust_vl::common::utils::time_utils::durationSec(start_time_, fans.timestamp);
fitter_.update(tostart, v_roll());
fitter_.setAngleRef(tostart, roll());
fitter_.fitAsync();
last_time_ = tostart;
return has_match;
}
cv::Rect RuneTarget::expanded(
Eigen::Matrix4d T_camera_to_odom,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const cv::Size& image_size
) {
double dt = wust_vl::common::utils::time_utils::durationSec(
timestamp_,
wust_vl::common::utils::time_utils::now()
);
if (!is_inited || dt > target_config_->lost_time_thres_param.get()) {
return cv::Rect(0, 0, 0, 0);
}
const float car_box_half = 1.0;
static std::vector<cv::Point3f> CAR_BOX;
CAR_BOX = { { 0, car_box_half, -car_box_half },
{ 0, -car_box_half, -car_box_half },
{ 0, -car_box_half, car_box_half },
{ 0, car_box_half, car_box_half } };
Eigen::Matrix4d T_odom_to_camera = T_camera_to_odom.inverse();
Eigen::Vector4d pos_odom(centerPos().x(), centerPos().y(), centerPos().z(), 1.0);
Eigen::Vector4d pos_cam = T_odom_to_camera * pos_odom;
if (pos_cam.z() <= 0.2) {
return cv::Rect(0, 0, 0, 0);
}
cv::Mat tvec = (cv::Mat_<double>(3, 1) << pos_cam.x(), pos_cam.y(), pos_cam.z());
Eigen::Vector3d euler;
euler.z() = M_PI / 2.0;
euler.y() = 0;
euler.x() = std::atan2(pos_odom.y(), pos_odom.x());
Eigen::Quaterniond ori = utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
auto target_ori = utils::transformOrientation(ori, T_odom_to_camera);
Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix();
cv::Mat rot_mat =
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
tf_rot(0, 1),
tf_rot(0, 2),
tf_rot(1, 0),
tf_rot(1, 1),
tf_rot(1, 2),
tf_rot(2, 0),
tf_rot(2, 1),
tf_rot(2, 2));
cv::Mat rvec;
cv::Rodrigues(rot_mat, rvec);
std::vector<cv::Point2f> pts_2d;
cv::projectPoints(CAR_BOX, rvec, tvec, camera_intrinsic, camera_distortion, pts_2d);
cv::Rect rect = cv::boundingRect(pts_2d);
cv::Rect img_rect(0, 0, image_size.width, image_size.height);
if ((rect & img_rect).area() <= 0) {
return cv::Rect(0, 0, 0, 0);
}
int base_side = std::max(rect.width, rect.height);
int max_side = std::max(image_size.width, image_size.height);
double lost_dt = target_config_->lost_time_thres_param.get();
double dt_clamped = std::max(0.0, std::min(dt, lost_dt));
int side = static_cast<int>(base_side + (max_side - base_side) * (dt_clamped / lost_dt));
if (dt >= lost_dt) {
side = max_side;
}
int cx = rect.x + rect.width / 2;
int cy = rect.y + rect.height / 2;
cv::Rect square(cx - side / 2, cy - side / 2, side, side);
square &= img_rect;
return square;
}
std::vector<std::pair<int, auto_buff::RuneFan::Simple>>
RuneTarget::match(const std::vector<auto_buff::RuneFan::Simple>& fans) {
std::vector<std::pair<int, auto_buff::RuneFan::Simple>> result;
const int n_obs = (int)(fans.size());
const int armors_num = 5;
const double GATE = target_config_->match_gate_param.get();
const double max_cost = 1e9;
std::vector<std::vector<double>> cost(n_obs, std::vector<double>(armors_num, max_cost + 1));
std::vector<MModel::VecZ> meas_list(n_obs);
for (int j = 0; j < n_obs; ++j) {
meas_list[j] = getMeasure(fans[j]);
}
for (int j = 0; j < n_obs; ++j) {
for (int id = 0; id < armors_num; ++id) {
MModel::Measure measure(id);
MModel::VecZ z_pred;
measure.h(target_state_, z_pred);
MModel::VecZ nu = meas_list[j] - z_pred;
nu[(int)MModel::Meas::YPD_Y] =
angles::normalize_angle(nu[(int)MModel::Meas::YPD_Y]);
nu[(int)MModel::Meas::YPD_P] =
angles::normalize_angle(nu[(int)MModel::Meas::YPD_P]);
nu[(int)MModel::Meas::ORI_YAW] =
angles::normalize_angle(nu[(int)MModel::Meas::ORI_YAW]);
nu[(int)MModel::Meas::ORI_ROLL] =
angles::normalize_angle(nu[(int)MModel::Meas::ORI_ROLL]);
auto R = computeMeasurementCovariance(z_pred);
double d2 = nu.transpose() * R.ldlt().solve(nu);
// 门控
if (std::isfinite(d2) && d2 < GATE) {
cost[j][id] = d2;
}
}
}
std::vector<bool> used_obs(n_obs, false);
std::vector<bool> used_id(armors_num, false);
while (true) {
double best = max_cost;
int best_j = -1;
int best_id = -1;
for (int j = 0; j < n_obs; ++j) {
if (used_obs[j])
continue;
for (int id = 0; id < armors_num; ++id) {
if (used_id[id])
continue;
if (cost[j][id] < best) {
best = cost[j][id];
best_j = j;
best_id = id;
}
}
}
if (best_j < 0 || best_id < 0) {
break;
}
used_obs[best_j] = true;
used_id[best_id] = true;
result.push_back(std::make_pair(best_id, fans[best_j]));
}
// for (auto fan: fans) {
// int id;
// auto min_angle_error = 1e10;
// const auto angles = getAngles();
// for (int i = 0; i < angles.size(); i++) {
// auto angle_error = std::abs(angles::normalize_angle(
// angles::normalize_angle(orientationToRoll(fan.target_ori)) - angles[i]
// ));
// if (angle_error < min_angle_error) {
// min_angle_error = angle_error;
// id = i;
// }
// }
// result.push_back(std::make_pair(id, fan));
// }
return result;
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,230 @@
#pragma once
#include "spd_fitter.hpp"
#include "tasks/auto_buff/rune_tracker/motion_models/motion_modelrypd.hpp"
#include "tasks/auto_buff/type.hpp"
#include "tasks/utils/utils.hpp"
#include <wust_vl/common/utils/timer.hpp>
namespace wust_vision {
namespace auto_buff {
namespace MModel = ypdrune_motion_model;
struct RuneTargetConfig: wust_vl::common::utils::ParamGroup {
static constexpr const char* kKey = "rune_tracker";
const char* key() const override {
return kKey;
}
GEN_PARAM(int, esekf_iter_num);
GEN_PARAM(double, lost_time_thres);
GEN_PARAM(int, tracking_thres);
GEN_PARAM(double, max_dis_diff);
GEN_PARAM(double, match_gate);
GEN_PARAM(double, q_roll);
GEN_PARAM(double, qyaw_output);
GEN_PARAM(double, q_xyz);
GEN_PARAM(double, q_yaw);
GEN_PARAM(double, yp_r);
GEN_PARAM(double, dis_r);
GEN_PARAM(double, yaw_r);
GEN_PARAM(double, roll_r);
GEN_PARAM(double, big_window_sec);
using Ptr = std::shared_ptr<RuneTargetConfig>;
RuneTargetConfig() {}
static Ptr create() {
return std::make_shared<RuneTargetConfig>();
}
void loadSelf(const YAML::Node& node) override {
esekf_iter_num_param.load(node);
lost_time_thres_param.load(node);
tracking_thres_param.load(node);
max_dis_diff_param.load(node);
match_gate_param.load(node);
q_roll_param.load(node);
q_xyz_param.load(node);
q_yaw_param.load(node);
yp_r_param.load(node);
dis_r_param.load(node);
roll_r_param.load(node);
big_window_sec_param.load(node);
}
};
class RuneTarget {
public:
RuneTarget() = default;
RuneTarget(
const auto_buff::RuneFan& fan,
RuneTargetConfig::Ptr target_config,
double pre_v_roll
);
bool is_big_ = false;
double last_ypd_y = 0;
bool is_inited = false;
int last_id;
std::vector<int> update_ids;
RuneTargetConfig::Ptr target_config_;
std::chrono::steady_clock::time_point last_t_;
std::chrono::steady_clock::time_point timestamp_;
std::chrono::steady_clock::time_point start_time_;
double dt_;
double last_time_ = 0;
SinSpeedFitter fitter_;
MModel::RuneESKF esekf_ypd_;
Eigen::Matrix<double, MModel::Z_N, 1> measurement_ =
Eigen::Matrix<double, MModel::Z_N, 1>::Zero();
Eigen::Matrix<double, MModel::X_N, 1> target_state_ =
Eigen::Matrix<double, MModel::X_N, 1>::Zero();
cv::Rect expanded(
Eigen::Matrix4d T_camera_to_odom,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const cv::Size& image_size
);
bool update(const auto_buff::RuneFan& fan);
void predict(std::chrono::steady_clock::time_point t);
void predict(double dt);
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z) const;
Eigen::Matrix<double, MModel::X_N, MModel::X_N> computeProcessNoise(double dt) const;
inline bool checkTargetAppear() {
bool appear = is_inited
&& wust_vl::common::utils::time_utils::durationSec(
timestamp_,
wust_vl::common::utils::time_utils::now()
) < target_config_->lost_time_thres_param.get();
return appear;
}
double predictAngle(std::chrono::steady_clock::time_point t) const {
double to_start = wust_vl::common::utils::time_utils::durationSec(start_time_, t);
return fitter_.predictAngle(to_start);
}
double predictAngle(double dt) const {
return fitter_.predictAngle(last_time_ + dt);
}
void predictWithFitter(double dt) {
if (is_big_) {
double to_start = last_time_ + dt;
double angle = fitter_.predictAngle(to_start);
double speed = fitter_.predictSpeed(to_start);
auto state = esekf_ypd_.getState();
state[(int)MModel::State::ROLL] = angles::normalize_angle(angle);
state[(int)MModel::State::VROLL] = speed;
esekf_ypd_.setState(state);
} else {
predict(dt);
}
}
void predictWithFitter(std::chrono::steady_clock::time_point t) {
double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
predictWithFitter(dt);
last_t_ = t;
}
double getFitterSpd(std::chrono::steady_clock::time_point t) {
double to_start = wust_vl::common::utils::time_utils::durationSec(start_time_, t);
return fitter_.predictSpeed(to_start);
}
Eigen::Vector3d centerPos() const {
return { target_state_((int)MModel::State::CX),
target_state_((int)MModel::State::CY),
target_state_((int)MModel::State::CZ) };
}
std::vector<double> getAngles() {
std::vector<double> angles;
for (int i = 0; i < 5; i++) {
auto angle = angles::normalize_angle(
target_state_[(int)MModel::State::ROLL] + i * 2 * M_PI / 5
);
angles.push_back(angle);
}
return angles;
}
bool diverged() const {
return diverged(target_state_);
}
bool diverged(Eigen::VectorXd target_state) const {
return false;
}
double roll() const {
return target_state_[(int)MModel::State::ROLL];
}
double curr_roll() const {
return roll() + last_id * 2 * M_PI / 5;
}
double real_roll(int id) const {
return roll() + id * 2 * M_PI / 5;
}
double yaw() const {
return target_state_[(int)MModel::State::YAW];
}
double v_roll() const {
return target_state_[(int)MModel::State::VROLL];
}
std::vector<std::pair<int, auto_buff::RuneFan::Simple>>
match(const std::vector<auto_buff::RuneFan::Simple>& fans);
std::vector<std::pair<Eigen::Vector3d, Eigen::Quaterniond>> getAllPose() const {
std::vector<std::pair<Eigen::Vector3d, Eigen::Quaterniond>> poses;
for (int i = 0; i < 5; i++) {
poses.emplace_back(getPose(i));
}
return poses;
}
std::pair<Eigen::Vector3d, Eigen::Quaterniond> getPose(int id) const {
Eigen::Vector3d euler = Eigen::Vector3d(yaw(), 0.0, real_roll(id));
auto q = utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
return computeBladeTipPose(centerPos(), q, id);
}
std::pair<Eigen::Vector3d, Eigen::Quaterniond>
computeBladeTipPose(const Eigen::Vector3d& center_pos, const Eigen::Quaterniond& q, int id)
const {
// tip 的局部坐标(沿 local X 方向)
Eigen::Vector3d local_tip(0.0, 0.0, RUNE_R2PANCENTER);
Eigen::Vector3d tip_pos = center_pos + q * local_tip;
Eigen::Vector3d euler = Eigen::Vector3d(yaw(), 0.0, real_roll(id));
return { tip_pos, utils::eulerToQuat(euler, utils::EulerOrder::ZYX) };
}
std::pair<Eigen::Vector3d, Eigen::Quaterniond> getHitPoint() const {
return getPose(last_id);
}
auto_buff::PowerRune getPowerRune() const {
auto_buff::PowerRune power_rune;
if (!is_inited) {
return power_rune;
}
power_rune.center.pos = centerPos();
Eigen::Vector3d euler = Eigen::Vector3d(yaw(), 0.0, real_roll(last_id));
auto q = Eigen::Quaterniond();
power_rune.center.ori = q;
auto all_pose = getAllPose();
for (int i = 0; i < all_pose.size(); i++) {
auto_buff::PowerRune::Pose pose;
pose.pos = all_pose[i].first;
pose.ori = all_pose[i].second;
power_rune.fans.push_back(pose);
}
power_rune.hit_id = last_id;
return power_rune;
}
Eigen::Matrix<double, MModel::Z_N, 1> getMeasure(const auto_buff::RuneFan::Simple& fan) {
auto p = fan.target_pos;
double measured_yaw = utils::orientationToYaw<RuneTarget>(fan.target_ori);
double measured_roll = utils::orientationToRoll<RuneTarget>(fan.target_ori);
double ypd_y = std::atan2(p.y(), p.x());
ypd_y = this->last_ypd_y + angles::shortest_angular_distance(this->last_ypd_y, ypd_y);
this->last_ypd_y = ypd_y;
double ypd_p = std::atan2(p.z(), std::sqrt(p.x() * p.x() + p.y() * p.y()));
double ypd_d = std::sqrt(p.x() * p.x() + p.y() * p.y() + p.z() * p.z());
Eigen::Matrix<double, MModel::Z_N, 1> measure;
measure << ypd_y, ypd_p, ypd_d, measured_yaw, measured_roll;
return measure;
}
};
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,121 @@
#include "rune_tracker.hpp"
namespace wust_vision {
namespace auto_buff {
struct RuneTracker::Impl {
public:
Impl(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) {
tracker_state = LOST;
target_ = auto_buff::RuneTarget();
target_config_ = RuneTargetConfig::create();
auto_buff_config_parameter->registerGroup(*target_config_);
auto_buff_config_parameter->reloadFromOldPath();
}
auto_buff::RuneTarget track(const auto_buff::RuneFan& fan) {
double dt = std::chrono::duration<double>(fan.timestamp - last_time_).count();
last_time_ = fan.timestamp;
lost_thres_ =
std::abs(static_cast<int>(target_config_->lost_time_thres_param.get() / dt));
bool found;
if (tracker_state == LOST) {
found = initTarget(fan);
} else {
found = updateTarget(fan);
}
updateFsm(found);
return target_;
}
void updateFsm(bool found) {
switch (tracker_state) {
case DETECTING:
if (found) {
if (++detect_count_ > target_config_->tracking_thres_param.get()) {
detect_count_ = 0;
tracker_state = TRACKING;
}
} else {
detect_count_ = 0;
tracker_state = LOST;
}
break;
case TRACKING:
if (!found) {
tracker_state = TEMP_LOST;
lost_count_ = 1;
}
break;
case TEMP_LOST:
if (!found) {
if (++lost_count_ > lost_thres_) {
lost_count_ = 0;
tracker_state = LOST;
}
} else {
lost_count_ = 0;
tracker_state = TRACKING;
}
break;
default:
break;
}
// target_.is_tracking = (tracker_state == TRACKING || tracker_state == TEMP_LOST);
if (found)
++found_count_;
// if (target_.is_tracking) {
// pre_v_roll_ = target_.v_roll();
// }
}
bool initTarget(const auto_buff::RuneFan& fan) {
if (!fan.is_valid || fan.fans.empty()) {
return false;
}
target_ = auto_buff::RuneTarget(fan, target_config_, pre_v_roll_);
tracker_state = DETECTING;
return true;
}
bool updateTarget(const auto_buff::RuneFan& fan) {
if (!fan.is_valid || fan.fans.empty()) {
return false;
}
auto fan_copy = fan;
std::erase_if(fan_copy.fans, [this](const auto_buff::RuneFan::Simple& f) {
bool pose_check = std::abs((f.target_pos - target_.centerPos()).norm())
< target_config_->max_dis_diff_param.get()
&& f.target_pos.norm() > 1.0;
return !pose_check;
});
target_.predict(fan_copy.timestamp);
return target_.update(fan_copy);
}
enum State {
LOST,
DETECTING,
TRACKING,
TEMP_LOST,
} tracker_state = LOST;
auto_buff::RuneTarget target_;
int detect_count_ = 0;
int lost_count_ = 0;
int found_count_ = 0;
double pre_v_roll_ = 0;
int lost_thres_ = 0;
std::chrono::steady_clock::time_point last_time_;
RuneTargetConfig::Ptr target_config_;
};
RuneTracker::RuneTracker(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) {
_impl = std::make_unique<Impl>(auto_buff_config_parameter);
}
RuneTracker::~RuneTracker() {
_impl.reset();
}
auto_buff::RuneTarget RuneTracker::track(const auto_buff::RuneFan& fan) {
return _impl->track(fan);
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,20 @@
#pragma once
#include "rune_target.hpp"
namespace wust_vision {
namespace auto_buff {
class RuneTracker {
public:
using Ptr = std::unique_ptr<RuneTracker>;
RuneTracker(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter);
static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_buff_config_parameter) {
return std::make_unique<RuneTracker>(auto_buff_config_parameter);
}
~RuneTracker();
auto_buff::RuneTarget track(const auto_buff::RuneFan& fan);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,265 @@
#pragma once
#include <Eigen/Dense>
#include <algorithm>
#include <atomic>
#include <ceres/ceres.h>
#include <cmath>
#include <iostream>
#include <limits>
#include <mutex>
#include <thread>
#include <vector>
namespace wust_vision {
namespace auto_buff {
class SinSpeedFitter {
public:
struct P {
double a, w, t0;
};
static constexpr double a_min_ = 0.780;
static constexpr double a_max_ = 1.045;
static constexpr double w_min_ = 1.884;
static constexpr double w_max_ = 2.000;
SinSpeedFitter() {}
void setWindow(double w) {
window_sec_ = w;
}
SinSpeedFitter(const SinSpeedFitter& other) {
std::scoped_lock lock(other.mtx_);
params_ = other.params_;
times_ = other.times_;
speeds_ = other.speeds_;
has_angle_ref_ = other.has_angle_ref_;
angle_ref_time_ = other.angle_ref_time_;
angle_ref_value_ = other.angle_ref_value_;
sign_ = other.sign_;
fitting_ = false;
}
SinSpeedFitter& operator=(const SinSpeedFitter& other) {
if (this != &other) {
std::scoped_lock lock(mtx_, other.mtx_);
params_ = other.params_;
times_ = other.times_;
speeds_ = other.speeds_;
has_angle_ref_ = other.has_angle_ref_;
angle_ref_time_ = other.angle_ref_time_;
angle_ref_value_ = other.angle_ref_value_;
sign_ = other.sign_;
fitting_ = false;
}
return *this;
}
void update(double time_s, double speed_rads) {
std::scoped_lock lock(mtx_);
auto it = std::lower_bound(times_.begin(), times_.end(), time_s);
size_t idx = std::distance(times_.begin(), it);
times_.insert(it, time_s);
speeds_.insert(speeds_.begin() + idx, speed_rads);
const double window_sec = 5.0;
if (!times_.empty()) {
double latest = times_.back();
while (!times_.empty() && latest - times_.front() > window_sec) {
times_.erase(times_.begin());
speeds_.erase(speeds_.begin());
}
}
}
void fit(bool verbose = false) {
std::scoped_lock lock(mtx_);
fitImpl(verbose);
}
void fitAsync(bool verbose = false) {
if (fitting_.exchange(true)) {
if (verbose)
std::cout << "[SinSpeedFitter] Previous async fit still running, skip.\n";
return;
}
std::vector<double> t_copy, s_copy;
P params_snapshot;
{
std::scoped_lock lock(mtx_);
t_copy = times_;
s_copy = speeds_;
params_snapshot = params_;
}
std::thread([this, t_copy, s_copy, params_snapshot, verbose]() {
fitImpl(verbose, &t_copy, &s_copy, &params_snapshot);
fitting_ = false;
}).detach();
}
double predictSpeed(double t) const {
std::scoped_lock lock(mtx_);
double a = params_.a;
double w = params_.w;
double b = 2.090 - a;
return sign_ * (a * std::sin(w * (t - params_.t0)) + b);
}
double predictAngle(double t) const {
std::scoped_lock lock(mtx_);
if (!has_angle_ref_)
return 0.0;
double a = params_.a;
double w = params_.w;
double b = 2.090 - a;
double theta = sign_ * (-a / w * std::cos(w * (t - params_.t0)) + b * (t - params_.t0));
double theta_ref = sign_
* (-a / w * std::cos(w * (angle_ref_time_ - params_.t0))
+ b * (angle_ref_time_ - params_.t0));
return angle_ref_value_ + (theta - theta_ref);
}
void setAngleRef(double time_s, double angle_rad) {
std::scoped_lock lock(mtx_);
angle_ref_time_ = time_s;
angle_ref_value_ = angle_rad;
has_angle_ref_ = true;
}
const P& params() const {
return params_;
}
int sign() const {
return sign_;
}
bool isFitting() const {
return fitting_.load();
}
private:
struct SinResidual {
SinResidual(double t, double s, int sign): t_(t), s_(s), sign_(sign) {}
template<typename T>
bool operator()(const T* const p, T* residual) const {
const T& a = p[0];
const T& w = p[1];
const T& t0 = p[2];
T b = T(2.090) - a;
T pred = T(sign_) * (a * sin(w * (T(t_) - t0)) + b);
residual[0] = T(s_) - pred;
return true;
}
double t_, s_;
int sign_;
};
bool fitImpl(
bool verbose,
const std::vector<double>* t_ptr = nullptr,
const std::vector<double>* s_ptr = nullptr,
const P* params_snapshot = nullptr
) {
const auto& t_in = t_ptr ? *t_ptr : times_;
const auto& s_in = s_ptr ? *s_ptr : speeds_;
if (t_in.size() < 3) {
if (verbose)
std::cerr << "[SinSpeedFitter] need >=3 samples\n";
return false;
}
std::vector<std::pair<double, double>> tmp;
tmp.reserve(t_in.size());
for (size_t i = 0; i < t_in.size(); ++i)
tmp.emplace_back(t_in[i], s_in[i]);
std::sort(tmp.begin(), tmp.end());
std::vector<double> t_unique, s_unique;
t_unique.reserve(tmp.size());
s_unique.reserve(tmp.size());
double last_t = std::numeric_limits<double>::quiet_NaN();
for (auto& [t, s]: tmp) {
if (t_unique.empty() || std::abs(t - last_t) > 1e-9) {
t_unique.push_back(t);
s_unique.push_back(s);
last_t = t;
}
}
P params_initial = params_snapshot ? *params_snapshot : params_;
double err_pos = fit_with_sign(+1, t_unique, s_unique, params_initial, verbose);
double err_neg = fit_with_sign(-1, t_unique, s_unique, params_initial, verbose);
std::scoped_lock lock(mtx_);
sign_ = (err_pos <= err_neg) ? +1 : -1;
return true;
}
double fit_with_sign(
int sgn,
const std::vector<double>& t_unique,
const std::vector<double>& s_unique,
P params_initial,
bool verbose
) {
ceres::Problem problem;
double params[3] = { params_initial.a, params_initial.w, params_initial.t0 };
for (size_t i = 0; i < t_unique.size(); ++i) {
problem.AddResidualBlock(
new ceres::AutoDiffCostFunction<SinResidual, 1, 3>(
new SinResidual(t_unique[i], s_unique[i], sgn)
),
nullptr,
params
);
}
problem.SetParameterLowerBound(params, 0, a_min_);
problem.SetParameterUpperBound(params, 0, a_max_);
problem.SetParameterLowerBound(params, 1, w_min_);
problem.SetParameterUpperBound(params, 1, w_max_);
ceres::Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
options.max_num_iterations = 100;
options.minimizer_progress_to_stdout = verbose;
ceres::Solver::Summary summary;
ceres::Solve(options, &problem, &summary);
double err_sum = 0.0;
for (size_t i = 0; i < t_unique.size(); ++i) {
double pred = sgn
* (params[0] * std::sin(params[1] * (t_unique[i] - params[2]))
+ (2.090 - params[0]));
double e = s_unique[i] - pred;
err_sum += e * e;
}
if (verbose)
std::cout << (sgn > 0 ? "[+] " : "[-] ") << summary.BriefReport()
<< " error=" << err_sum << std::endl;
std::scoped_lock lock(mtx_);
params_.a = params[0];
params_.w = params[1];
params_.t0 = params[2];
return err_sum;
}
private:
mutable std::mutex mtx_;
P params_ { 1.0, 1.9, 0.0 };
std::vector<double> times_;
std::vector<double> speeds_;
bool has_angle_ref_ = false;
double angle_ref_time_ = 0.0;
double angle_ref_value_ = 0.0;
int sign_ = 1;
std::atomic<bool> fitting_ { false };
double window_sec_ = 1.0;
};
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,213 @@
#include "rune_where.hpp"
#include "tasks/utils/utils.hpp"
#include <Eigen/Dense>
#include <opencv2/core/eigen.hpp>
namespace wust_vision {
namespace auto_buff {
struct RuneWhere::Impl {
public:
Impl(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
camera_info_ = camera_info;
}
struct Params {
enum class OptMode : int { GOLDEN = 0, CERES = 1, NONE = 2 } opt_mode;
OptMode fromString(const std::string& mode) {
if (mode == "golden" || mode == "GOLDEN") {
return OptMode::GOLDEN;
} else if (mode == "none" || mode == "NONE") {
return OptMode::NONE;
} else {
return OptMode::NONE;
}
}
int golden_search_side_deg = 60;
void load(const YAML::Node& node) {
opt_mode = fromString(node["roll_opt"]["mode"].as<std::string>());
golden_search_side_deg = node["roll_opt"]["golden_search_side_deg"].as<int>();
}
} params_;
auto_buff::RuneFan
where(auto_buff::RuneFan f, Eigen::Matrix4d T_camera_to_odom) const noexcept {
const Eigen::Matrix3d R_imu_cam = T_camera_to_odom.block<3, 3>(0, 0);
for (auto& fan: f.fans) {
cv::Mat rvec, tvec;
cv::solvePnP(
fan.getObjs(),
fan.landmarks(),
camera_info_.first,
camera_info_.second,
rvec,
tvec,
false,
cv::SOLVEPNP_IPPE //平移更稳定,(旋转这里纯靠后面优化)
);
cv::Mat R_cv;
cv::Rodrigues(rvec, R_cv);
Eigen::Matrix3d R = utils::cvToEigen(R_cv);
Eigen::Vector3d t = utils::cvToEigen(tvec);
if (params_.opt_mode != Params::OptMode::NONE) {
R = solveBa_R(fan, t, R, R_imu_cam);
}
fan.ori = Eigen::Quaterniond(R);
fan.pos = t;
Eigen::Vector3d pos_camera = fan.pos;
fan.target_pos = utils::transformPosition(pos_camera, T_camera_to_odom);
const Eigen::Quaterniond
q_camera(fan.ori.w(), fan.ori.x(), fan.ori.y(), fan.ori.z());
const Eigen::Quaterniond q_odom =
utils::transformOrientation(q_camera, T_camera_to_odom);
fan.target_ori = q_odom;
f.is_valid = true;
}
return f;
}
std::vector<Eigen::Vector2d> reprojection(
double roll,
const std::vector<cv::Point3f>& object_points,
const std::vector<cv::Point2f>& landmarks,
const Eigen::Matrix3d& Rci,
double pitch,
double yaw,
const Eigen::Vector3d& t
) const noexcept {
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY());
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
const Eigen::Matrix3d R = Rci * (ay * ap * ar).toRotationMatrix();
cv::Mat rvec, R_cv;
cv::eigen2cv(R, R_cv);
cv::Rodrigues(R_cv, rvec);
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << t.x(), t.y(), t.z());
std::vector<cv::Point2f> pts_2d;
pts_2d.reserve(object_points.size());
cv::projectPoints(
object_points,
rvec,
tvec,
camera_info_.first,
camera_info_.second,
pts_2d
);
std::vector<Eigen::Vector2d> image_points;
image_points.reserve(pts_2d.size());
for (const auto& p: pts_2d) {
image_points.emplace_back(p.x, p.y);
}
return image_points;
}
double reprojectionErrorRoll(
double roll,
const std::vector<cv::Point3f>& obj,
const std::vector<cv::Point2f>& lm,
const Eigen::Matrix3d& Rci,
double pitch,
double yaw,
const Eigen::Vector3d& t
) const noexcept {
const auto image_points = reprojection(roll, obj, lm, Rci, pitch, yaw, t);
double cost = 0.0;
for (size_t i = 0; i < image_points.size(); ++i) {
Eigen::Vector2d obs(lm[i].x, lm[i].y);
cost += (image_points[i] - obs).squaredNorm();
}
return cost;
}
double goldenRoll(
double init,
const std::vector<cv::Point3f>& obj,
const std::vector<cv::Point2f>& lm,
const Eigen::Matrix3d& Rci,
double pitch,
double yaw,
const Eigen::Vector3d& t
) const noexcept {
constexpr double phi = 1.618033988749894848; // golden ratio
double l = init - params_.golden_search_side_deg * M_PI / 180.0;
double r = init + params_.golden_search_side_deg * M_PI / 180.0;
double r1 = r - (r - l) / phi;
double r2 = l + (r - l) / phi;
double f1 = reprojectionErrorRoll(r1, obj, lm, Rci, pitch, yaw, t);
double f2 = reprojectionErrorRoll(r2, obj, lm, Rci, pitch, yaw, t);
while (r - l > 0.0001) { // 约 0.0057 度
if (f1 < f2) {
r = r2;
r2 = r1;
f2 = f1;
r1 = r - (r - l) / phi;
f1 = reprojectionErrorRoll(r1, obj, lm, Rci, pitch, yaw, t);
} else {
l = r1;
r1 = r2;
f1 = f2;
r2 = l + (r - l) / phi;
f2 = reprojectionErrorRoll(r2, obj, lm, Rci, pitch, yaw, t);
}
}
return 0.5 * (l + r); // final best roll
}
Eigen::Matrix3d solveBa_R(
const auto_buff::RuneFan::Simple& rune_fan,
const Eigen::Vector3d& t_camera_armor,
const Eigen::Matrix3d& R_camera_armor,
const Eigen::Matrix3d& R_imu_camera
) const noexcept {
Eigen::Matrix3d R_imu_armor = R_imu_camera * R_camera_armor;
Eigen::Matrix3d R_camera_imu = R_imu_camera.transpose();
double initial_roll = std::atan2(R_imu_armor(2, 1), R_imu_armor(2, 2));
double roll = initial_roll;
Eigen::Vector3d t_imu_armor = R_imu_camera * t_camera_armor;
double yaw = std::atan2(t_imu_armor.y(), t_imu_armor.x());
double pitch = 0;
auto cv_points = rune_fan.getObjs();
const auto& landmarks = rune_fan.landmarks();
if (params_.opt_mode == Params::OptMode::GOLDEN) {
roll = goldenRoll(
roll,
cv_points,
landmarks,
R_camera_imu,
pitch,
yaw,
t_camera_armor
);
}
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY());
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
const Eigen::Matrix3d R_result = R_camera_imu * (ay * ap * ar).toRotationMatrix();
return R_result;
}
std::pair<cv::Mat, cv::Mat> camera_info_;
};
RuneWhere::RuneWhere(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
_impl = std::make_unique<Impl>(config, camera_info);
}
RuneWhere::~RuneWhere() {
_impl.reset();
}
auto_buff::RuneFan RuneWhere::where(auto_buff::RuneFan f, Eigen::Matrix4d T_camera_to_odom) {
return _impl->where(f, T_camera_to_odom);
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,39 @@
// Created by Labor 2023.8.25
// Maintained by Labor, Chengfu Zou
// Copyright (C) FYT Vision Group. All rights reserved.
// Copyright 2025 XiaoJian Wu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tasks/auto_buff/type.hpp"
namespace wust_vision {
namespace auto_buff {
class RuneWhere {
public:
using Ptr = std::unique_ptr<RuneWhere>;
RuneWhere(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info);
static Ptr
create(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
return std::make_unique<RuneWhere>(config, camera_info);
}
~RuneWhere();
auto_buff::RuneFan where(auto_buff::RuneFan f, Eigen::Matrix4d T_camera_to_odom);
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,339 @@
#include "type.hpp"
#include "tasks/utils/utils.hpp"
namespace wust_vision {
namespace auto_buff {
void RunePan::draw(cv::Mat& img, const cv::Point2f& offset) const {
if (!is_valid || corners.size() < 3)
return;
std::vector<cv::Point2f> sorted_corners = corners;
for (auto& pt: sorted_corners) {
pt += offset;
}
// 画边
for (size_t i = 0; i < sorted_corners.size(); ++i) {
cv::line(
img,
sorted_corners[i],
sorted_corners[(i + 1) % sorted_corners.size()],
cv::Scalar(0, 255, 255),
2
);
}
// 画中心点
cv::circle(img, center, 3, cv::Scalar(255, 0, 255), -1);
if (has_refer) {
// 画角点编号
for (size_t i = 0; i < sorted_corners.size(); ++i) {
cv::Point2f p = sorted_corners[i];
// 绘制角点位置
cv::circle(img, p, 3, cv::Scalar(0, 0, 255), -1);
// 让数字稍微往右下偏移,避免盖到角点
cv::Point2f text_pos = p + cv::Point2f(5, -5);
cv::putText(
img,
std::to_string(i),
text_pos,
cv::FONT_HERSHEY_SIMPLEX,
0.5,
cv::Scalar(0, 255, 0),
1
);
}
}
}
double RunePan::getArea() const {
if (corners.size() < 3)
return 0.0;
std::vector<cv::Point2f> sorted_corners = corners;
std::sort(
sorted_corners.begin(),
sorted_corners.end(),
[this](const cv::Point2f& a, const cv::Point2f& b) {
double angA = std::atan2(a.y - center.y, a.x - center.x);
double angB = std::atan2(b.y - center.y, b.x - center.x);
return angA < angB;
}
);
return cv::contourArea(sorted_corners);
}
void RunePan::addReferRuneCenter(const RuneCenter& rc) {
if (!rc.is_valid || !is_valid)
return;
if (corners.size() != 4)
return;
cv::Point2f down_vec = rc.center - center;
float norm = std::sqrt(down_vec.x * down_vec.x + down_vec.y * down_vec.y);
if (norm < 1e-6f)
return;
has_refer = true;
float angle_ref = std::atan2(down_vec.y, down_vec.x);
// 获取4个点在旋转后的角度
struct Node {
float ang;
cv::Point2f p;
};
std::vector<Node> arr;
arr.reserve(4);
for (auto& p: corners) {
cv::Point2f v = p - center;
// 旋转坐标,使 down_vec 对齐 angle=0
float ang = std::atan2(v.y, v.x) - angle_ref;
// 归一化到 (-π, π]
while (ang <= -CV_PI)
ang += 2 * CV_PI;
while (ang > CV_PI)
ang -= 2 * CV_PI;
arr.push_back({ ang, p });
}
// 按角度排序(从 -π 到 π)
std::sort(arr.begin(), arr.end(), [](const Node& a, const Node& b) {
return a.ang < b.ang;
});
// 准备象限变量并标记
cv::Point2f lu(0, 0), ru(0, 0), rd(0, 0), ld(0, 0);
bool has_lu = false, has_ru = false, has_rd = false, has_ld = false;
for (const auto& n: arr) {
float a = n.ang;
if (a > CV_PI / 2 && a <= CV_PI) {
lu = n.p;
has_lu = true;
} else if (a > 0 && a <= CV_PI / 2) {
ru = n.p;
has_ru = true;
} else if (a > -CV_PI / 2 && a <= 0) {
rd = n.p;
has_rd = true;
} else { // a > -CV_PI && a <= -CV_PI/2
ld = n.p;
has_ld = true;
}
}
std::array<cv::Point2f, 4> ordered;
if (has_lu && has_ru && has_rd && has_ld) {
ordered[0] = lu;
ordered[1] = ru;
ordered[2] = rd;
ordered[3] = ld;
corners.assign(ordered.begin(), ordered.end());
return;
}
float target = 3.0f * CV_PI / 4.0f; // 135°
int best_idx = 0;
float best_diff = std::numeric_limits<float>::max();
for (int i = 0; i < (int)arr.size(); ++i) {
float d = std::fabs(angles::shortest_angular_distance(target, arr[i].ang)
); // 如果没有 angles::shortest_angular_distance可以用下面替代
if (d < best_diff) {
best_diff = d;
best_idx = i;
}
}
for (int i = 0; i < 4; ++i) {
int idx = (best_idx + i) % 4;
ordered[i] = arr[idx].p;
}
corners.assign(ordered.begin(), ordered.end());
}
void RuneFan::Simple::addOther(const Simple& other) {
auto l1 = points2d[0] - points2d[5];
auto l2 = other.points2d[0] - other.points2d[5];
float a1 = std::atan2(l1.y, l1.x);
float a2 = std::atan2(l2.y, l2.x);
float d = a1 - a2;
d = normalizeAngle0to2pi(d);
int id = 0;
double min_err = 1e9;
for (int i = 0; i < angle_diffs.size(); i++) {
double err = std::abs(angle_diffs[i] - d);
if (err < min_err) {
min_err = err;
id = i;
}
}
if (id < 1) {
return;
}
has_other++;
points2d.push_back(other.points2d[1]);
points2d.push_back(other.points2d[2]);
points2d.push_back(other.points2d[3]);
points2d.push_back(other.points2d[4]);
double roll = -angle_diffs[id];
points3d.push_back(rotateX(points3d[1], roll));
points3d.push_back(rotateX(points3d[2], roll));
points3d.push_back(rotateX(points3d[3], roll));
points3d.push_back(rotateX(points3d[4], roll));
}
void RuneFan::Simple::drawLandmarks(cv::Mat& image) const {
std::vector<cv::Point2f> lm = landmarks();
for (size_t i = 0; i < lm.size(); i++) {
cv::circle(image, lm[i], 3, cv::Scalar(255, 255, 255), -1);
if (i == 0) {
cv::putText(
image,
"R",
lm[i],
cv::FONT_HERSHEY_SIMPLEX,
1.5,
cv::Scalar(40, 255, 40),
2
);
} else {
cv::putText(
image,
std::to_string(i),
lm[i],
cv::FONT_HERSHEY_SIMPLEX,
0.5,
cv::Scalar(255, 255, 255),
2
);
}
}
}
void RuneFan::addOffset(const cv::Point2f& offset) {
for (auto& fan: fans) {
for (auto& point: fan.points2d) {
point += offset;
}
}
}
void RuneFan::transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) {
for (auto& fan: fans) {
for (auto& pt: fan.points2d) {
pt = utils::transformPoint2D(transform_matrix, pt);
}
}
}
void PowerRune::Pose::tf(Eigen::Matrix4d T_camera_to_odom) {
Eigen::Vector4d pos_camera(pos.x(), pos.y(), pos.z(), 1.0);
Eigen::Vector4d pos_odom = T_camera_to_odom * pos_camera;
pos.x() = pos_odom.x();
pos.y() = pos_odom.y();
pos.z() = pos_odom.z();
Eigen::Matrix3d R_camera_to_odom = T_camera_to_odom.block<3, 3>(0, 0);
Eigen::Quaterniond q_camera(ori.w(), ori.x(), ori.y(), ori.z());
Eigen::Matrix3d R_ori_camera = q_camera.normalized().toRotationMatrix();
Eigen::Matrix3d R_ori_odom = R_camera_to_odom * R_ori_camera;
Eigen::Quaterniond q_odom(R_ori_odom);
ori.w() = q_odom.w();
ori.x() = q_odom.x();
ori.y() = q_odom.y();
ori.z() = q_odom.z();
}
std::vector<cv::Point2f> PowerRune::Pose::toPts(
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const std::vector<cv::Point3f>& obj_points
) const {
std::vector<cv::Point2f> pts;
if (pos.norm() < 0.5) {
return pts;
}
cv::Mat tvec = (cv::Mat_<double>(3, 1) << pos.x(), pos.y(), pos.z());
Eigen::Matrix3d tf_rot = ori.toRotationMatrix();
cv::Mat rot_mat =
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
tf_rot(0, 1),
tf_rot(0, 2),
tf_rot(1, 0),
tf_rot(1, 1),
tf_rot(1, 2),
tf_rot(2, 0),
tf_rot(2, 1),
tf_rot(2, 2));
// 旋转矩阵 -> 旋转向量
cv::Mat rvec;
cv::Rodrigues(rot_mat, rvec);
cv::projectPoints(obj_points, rvec, tvec, camera_intrinsic, camera_distortion, pts);
return pts;
}
void PowerRune::Pose::draw(
cv::Mat& img,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const std::vector<cv::Point3f>& obj_points,
cv::Scalar color
) const {
auto pts = toPts(camera_intrinsic, camera_distortion, obj_points);
if (!pts.empty()) {
for (int i = 0; i < 4; i++)
cv::line(img, pts[i], pts[(i + 1) % 4], color, 2);
// 后表面
for (int i = 4; i < 8; i++)
cv::line(img, pts[i], pts[4 + (i + 1) % 4], color, 2);
// 侧边
for (int i = 0; i < 4; i++)
cv::line(img, pts[i], pts[i + 4], color, 2);
cv::Point2f center(0.f, 0.f);
for (auto pt: pts) {
center += pt;
}
center *= 1.0 / pts.size();
}
}
void PowerRune::tf(Eigen::Matrix4d T_camera_to_odom) {
center.tf(T_camera_to_odom);
for (auto& fan: fans)
fan.tf(T_camera_to_odom);
}
void
PowerRune::draw(cv::Mat& img, const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion)
const {
center.draw(img, camera_intrinsic, camera_distortion);
for (int i = 0; i < fans.size(); i++) {
if (i == hit_id)
fans[i].draw(
img,
camera_intrinsic,
camera_distortion,
FAN_BLOCK,
cv::Scalar(40, 255, 40)
);
else
fans[i].draw(img, camera_intrinsic, camera_distortion, FAN_BLOCK);
}
}
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,124 @@
#pragma once
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_buff {
constexpr double RUNE_PAN_BOX_DIS = 0.16;
constexpr double RUNE_R2PANCENTER = 0.75;
struct RuneCenter {
cv::Point2f center;
cv::RotatedRect rr;
bool is_valid = false;
RuneCenter() = default;
RuneCenter(cv::RotatedRect rect): rr(rect) {
center = rect.center;
is_valid = rr.size.area() > 0;
}
};
class RunePan {
public:
cv::Point2f center;
std::vector<cv::Point2f> corners;
bool is_valid = false;
bool has_refer = false;
void draw(cv::Mat& img, const cv::Point2f& offset) const;
double getArea() const;
void addReferRuneCenter(const RuneCenter& rc);
};
struct RuneFan {
public:
bool is_valid = false;
int id;
bool is_big = false;
std::chrono::steady_clock::time_point timestamp;
struct Simple {
int has_other = 0;
std::vector<double> angle_diffs = { 0,
2 * M_PI / 5,
2 * M_PI / 5 * 2,
2 * M_PI / 5 * 3,
2 * M_PI / 5 * 4 };
std::vector<cv::Point2f> points2d;
std::vector<cv::Point3f> points3d = {
{ 0.0f, 0.0f, 0.0f }, // P0
{ 0.0f, RUNE_PAN_BOX_DIS / 2.0f, RUNE_R2PANCENTER + RUNE_PAN_BOX_DIS / 2.0f }, // P1
{ 0.0f, RUNE_PAN_BOX_DIS / 2.0f, RUNE_R2PANCENTER - RUNE_PAN_BOX_DIS / 2.0f }, // P2
{ 0.0f,
-RUNE_PAN_BOX_DIS / 2.0f,
RUNE_R2PANCENTER - RUNE_PAN_BOX_DIS / 2.0f }, // P3
{ 0.0f,
-RUNE_PAN_BOX_DIS / 2.0f,
RUNE_R2PANCENTER + RUNE_PAN_BOX_DIS / 2.0f }, // P4
{ 0.0f, 0.0f, RUNE_R2PANCENTER } // P5
};
inline cv::Point3f rotateX(const cv::Point3f& p, double roll) {
double c = std::cos(roll);
double s = std::sin(roll);
return { p.x, float(p.y * c - p.z * s), float(p.y * s + p.z * c) };
}
inline double normalizeAngle0to2pi(double a) {
a = std::fmod(a, 2 * M_PI);
if (a < 0)
a += 2 * M_PI;
return a;
}
Eigen::Vector3d pos;
Eigen::Quaterniond ori;
Eigen::Vector3d target_pos;
Eigen::Quaterniond target_ori;
void addOther(const Simple& other);
std::vector<cv::Point2f> landmarks() const {
return points2d;
}
void drawLandmarks(cv::Mat& image) const;
std::vector<cv::Point3f> getObjs() const {
return points3d;
}
};
std::vector<Simple> fans;
void addOffset(const cv::Point2f& offset);
void transform(const Eigen::Matrix<float, 3, 3>& transform_matrix);
};
static std::vector<cv::Point3f> FAN_BLOCK = {
{ -0.05f, -0.20f, -0.15f }, // 0: 左下前
{ 0.05f, -0.20f, -0.15f }, // 1: 右下前
{ 0.05f, 0.20f, -0.15f }, // 2: 右上前
{ -0.05f, 0.20f, -0.15f }, // 3: 左上前
{ -0.05f, -0.20f, 0.15f }, // 4: 左下后
{ 0.05f, -0.20f, 0.15f }, // 5: 右下后
{ 0.05f, 0.20f, 0.15f }, // 6: 右上后
{ -0.05f, 0.20f, 0.15f } // 7: 左上后
};
struct PowerRune {
bool is_valid = false;
struct Pose {
Eigen::Vector3d pos;
Eigen::Quaterniond ori;
void tf(Eigen::Matrix4d T_camera_to_odom);
std::vector<cv::Point2f> toPts(
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const std::vector<cv::Point3f>& obj_points = AIM_TARGET_BLOCK
) const;
void draw(
cv::Mat& img,
const cv::Mat& camera_intrinsic,
const cv::Mat& camera_distortion,
const std::vector<cv::Point3f>& obj_points = AIM_TARGET_BLOCK,
cv::Scalar color = cv::Scalar(255, 255, 255)
) const;
};
Pose center;
std::vector<Pose> fans;
int hit_id;
void tf(Eigen::Matrix4d T_camera_to_odom);
void
draw(cv::Mat& img, const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion) const;
};
} // namespace auto_buff
} // namespace wust_vision

View File

@@ -0,0 +1,198 @@
#include "auto_guidance.hpp"
#include "tasks/auto_guidance/guidance_detector/detector_base.hpp"
#include "tasks/auto_guidance/guidance_detector/detector_factory.hpp"
#include "tasks/auto_guidance/guidance_tracker/guidance_tracker.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/common/concurrency/queues.hpp"
#include "wust_vl/common/utils/logger.hpp"
#include "wust_vl/common/utils/timer.hpp"
namespace wust_vision {
namespace auto_guidance {
struct AutoGuidance::Impl {
~Impl() {
lights_queue_->stop();
if (processing_thread_) {
processing_thread_->stop();
wust_vl::common::concurrency::ThreadManager::instance().unregisterThread(
processing_thread_->getName()
);
}
}
void init(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
camera_info_ = camera_info;
std::string backend = config["backend"].as<std::string>();
std::cout << "backend: " << backend << std::endl;
auto detector_cfg = config["detector"];
detector_ = DetectorFactory::createDetector(backend, detector_cfg, debug_);
detector_->setCallback(std::bind(
&AutoGuidance::Impl::detectCallback,
this,
std::placeholders::_1,
std::placeholders::_2
));
tracker_ = GuidanceTracker::create(config["tracker"]);
lights_queue_ =
std::make_unique<wust_vl::common::concurrency::OrderedQueue<GreenLights>>(100, 500);
latency_averager_ =
std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
}
void pushInput(CommonFrame& frame) {
img_recv_count_++;
if (detector_) {
detector_->pushInput(frame);
}
}
void detectCallback(const std::vector<GreenLight>& objs, const CommonFrame& frame) {
detect_finish_count_++;
GreenLights lights;
lights.lights = objs;
lights.timestamp = frame.img_frame.timestamp;
lights.id = frame.id;
for (auto& light: lights.lights) {
light.solvePnP(camera_info_.first, camera_info_.second);
light.timestamp = frame.img_frame.timestamp;
light.image_size = frame.img_frame.src_img.size();
}
green_lights_ = lights;
lights_queue_->enqueue(lights);
if (debug_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
dbg_.lights = lights;
dbg_.img_frame = frame.img_frame;
}
}
void lightsCallback(const GreenLights& lights) {
if (lights.timestamp <= tracker_->getLastTime()) {
WUST_WARN(logger_) << "Received out-of-order armor data, discarded.";
return;
}
GuidanceTarget target;
target = tracker_->track(lights);
{
std::lock_guard<std::mutex> lock(target_mutex_);
guidance_target_ = target;
}
auto now = std::chrono::steady_clock::now();
auto latency_ms = wust_vl::common::utils::time_utils::durationMs(lights.timestamp, now);
latency_averager_->add(latency_ms);
dbg_.latency_ms = latency_averager_->average();
if (debug_) {
std::lock_guard<std::mutex> lock(dbg_mutex_);
dbg_.target = target;
}
}
void start() {
processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create(
"AutoAimProcessingThread",
[this](wust_vl::common::concurrency::MonitoredThread::Ptr self) {
this->processingLoop(self);
}
);
wust_vl::common::concurrency::ThreadManager::instance().registerThread(
processing_thread_
);
run_flag_ = true;
}
void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) {
while (!self->isAlive()) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
while (self->isAlive()) {
if (!self->waitPoint())
break;
self->heartbeat();
printStats();
GreenLights lights;
bool skip;
// if (lights_queue_->dequeue_wait(lights, skip)) {
// lightsCallback(lights);
// tracker_finish_count_++;
// if (skip) {
// WUST_DEBUG(logger_) << "OrderQueue skip";
// }
// }
if (!lights_queue_->try_dequeue(lights)) {
std::this_thread::sleep_for(std::chrono::milliseconds(3));
continue;
}
lightsCallback(lights);
tracker_finish_count_++;
}
}
GuidanceTarget getTarget() {
timer_count_++;
std::lock_guard<std::mutex> lock(target_mutex_);
return guidance_target_;
}
void printStats() {
utils::XSecOnce(
[&] {
WUST_INFO(logger_)
<< "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_
<< ", Fin: " << tracker_finish_count_ << ", Lat: " << dbg_.latency_ms
<< "ms"
<< ", Tc:" << timer_count_;
img_recv_count_ = 0;
detect_finish_count_ = 0;
tracker_finish_count_ = 0;
timer_count_ = 0;
},
1.0
);
}
std::unique_ptr<detector_base> detector_;
std::string logger_ = "auto_guidance";
std::chrono::steady_clock::time_point last_stat_time_steady_ =
std::chrono::steady_clock::now();
GuidanceTracker::Ptr tracker_;
bool run_flag_ = false;
int detect_finish_count_ = 0;
int img_recv_count_ = 0;
int tracker_finish_count_ = 0;
int timer_count_ = 0;
bool debug_ = false;
GuidanceTarget guidance_target_;
GreenLights green_lights_;
std::shared_ptr<wust_vl::common::concurrency::MonitoredThread> processing_thread_;
std::unique_ptr<wust_vl::common::concurrency::OrderedQueue<GreenLights>> lights_queue_;
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
std::pair<cv::Mat, cv::Mat> camera_info_;
AutoGuidanceDebug dbg_;
std::mutex target_mutex_;
std::mutex dbg_mutex_;
};
AutoGuidance::AutoGuidance(): _impl(std::make_unique<Impl>()) {}
AutoGuidance::~AutoGuidance() {
_impl.reset();
}
void
AutoGuidance::init(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
_impl->init(config, camera_info);
}
void AutoGuidance::start() {
_impl->start();
}
void AutoGuidance::pushInput(CommonFrame& frame) {
_impl->pushInput(frame);
}
void AutoGuidance::setDebug(bool debug) {
_impl->debug_ = debug;
}
AutoGuidanceDebug AutoGuidance::getDebug() {
std::lock_guard<std::mutex> lock(_impl->dbg_mutex_);
return _impl->dbg_;
}
GuidanceTarget AutoGuidance::getTarget() {
return _impl->getTarget();
}
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,24 @@
#pragma once
#include "debug.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_guidance {
class AutoGuidance {
public:
static inline std::unique_ptr<AutoGuidance> create() {
return std::make_unique<AutoGuidance>();
}
AutoGuidance();
~AutoGuidance();
void init(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info);
void start();
void pushInput(CommonFrame& frame);
void setDebug(bool debug);
GuidanceTarget getTarget();
AutoGuidanceDebug getDebug();
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,295 @@
#include "debug.hpp"
#include <fcntl.h>
#include <fmt/format.h>
#include <fstream>
#include <nlohmann/json.hpp>
#include <sys/mman.h>
#include <sys/stat.h>
namespace wust_vision {
namespace auto_guidance {
void drawAutoGuidanceDebugContent(cv::Mat& debug_img, const AutoGuidanceDebug& dbg) {
auto target = dbg.target;
auto lights = dbg.lights;
lights.drawFront(debug_img);
if (target.is_tracking_) {
auto now = std::chrono::steady_clock::now();
target.predict(now);
target.draw(debug_img);
}
std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms);
cv::putText(
debug_img,
latency_str,
cv::Point(10, 30),
cv::FONT_HERSHEY_SIMPLEX,
0.8,
cv::Scalar(255, 255, 255),
2
);
cv::circle(
debug_img,
cv::Point2i(debug_img.cols / 2, debug_img.rows / 2),
5,
cv::Scalar(255, 255, 255),
2
);
double cx_norm = target.center().x / target.image_size_.width * 2.0 - 1.0;
double diff_center_norm = (target.is_tracking_) ? cx_norm : 0;
{
std::string diff_str = fmt::format("diff_cx_norm: {:.3f}", diff_center_norm);
int margin = 10;
int font_face = cv::FONT_HERSHEY_SIMPLEX;
double font_scale = 0.7;
int thickness = 2;
int baseline = 0;
cv::Size text_size =
cv::getTextSize(diff_str, font_face, font_scale, thickness, &baseline);
// 右上角文本左下角坐标
int x = debug_img.cols - margin - text_size.width;
int y = margin + text_size.height;
// 背景框,可选
cv::rectangle(
debug_img,
cv::Rect(
x - 5,
y - text_size.height - 5,
text_size.width + 10,
text_size.height + 10
),
cv::Scalar(0, 0, 0),
cv::FILLED,
cv::LINE_AA
);
cv::putText(
debug_img,
diff_str,
cv::Point(x, y),
font_face,
font_scale,
cv::Scalar(0, 255, 255),
thickness,
cv::LINE_AA
);
}
if (target.is_tracking_) {
const auto& s = target.target_state_;
std::string line1 =
fmt::format("pos: {:.1f} {:.1f} {:.1f} {:.1f}", s(0), s(2), s(4), s(6));
std::string line2 =
fmt::format("vel: {:.2f} {:.2f} {:.2f} {:.2f}", s(1), s(3), s(5), s(7));
int x = 10;
int y = debug_img.rows - 30; // 左下角位置
int dy = 28;
cv::putText(
debug_img,
line1,
cv::Point(x, y),
cv::FONT_HERSHEY_SIMPLEX,
0.75,
cv::Scalar(0, 255, 0),
2
);
cv::putText(
debug_img,
line2,
cv::Point(x, y + dy),
cv::FONT_HERSHEY_SIMPLEX,
0.75,
cv::Scalar(0, 255, 0),
2
);
double vx = s(1);
double vy = s(3);
cv::Point2f p0 = target.center();
double scale = 0.5;
cv::Point2f p1(p0.x + vx * scale, p0.y + vy * scale);
cv::arrowedLine(
debug_img,
p0,
p1,
cv::Scalar(0, 255, 0),
2,
cv::LINE_AA,
0,
0.25 // 箭头比例
);
}
}
void drawDebugOverlayWrite(const AutoGuidanceDebug& dbg, bool auto_fps) {
static auto last_show_time = std::chrono::steady_clock::now();
if (dbg.img_frame.src_img.empty())
return;
cv::Mat src_img = dbg.img_frame.src_img;
auto now = std::chrono::steady_clock::now();
const double min_interval_ms = 1000.0 / 30.0;
if (std::chrono::duration<double, std::milli>(now - last_show_time).count()
< min_interval_ms
&& auto_fps)
return;
last_show_time = now;
// 图像构造
cv::Mat debug_img;
src_img.convertTo(debug_img, -1, 1, 0);
cv::cvtColor(debug_img, debug_img, cv::COLOR_BGR2RGB);
if (debug_img.empty())
return;
// 封装后的绘图函数
drawAutoGuidanceDebugContent(debug_img, dbg);
cv::cvtColor(debug_img, debug_img, cv::COLOR_RGB2BGR);
// 编码写入共享内存路径
std::vector<uchar> buf;
cv::imencode(".jpg", debug_img, buf);
std::ofstream ofs("/dev/shm/debug_frame.jpg.tmp", std::ios::binary);
ofs.write(reinterpret_cast<const char*>(buf.data()), buf.size());
ofs.close();
std::rename("/dev/shm/debug_frame.jpg.tmp", "/dev/shm/debug_frame.jpg");
}
void drawDebugOverlayShm(const AutoGuidanceDebug& dbg, bool auto_fps) {
static auto last_show_time = std::chrono::steady_clock::now();
const char* shm_name = "/debug_frame";
const size_t shm_max_size = 2 * 1024 * 1024; // 2MB 最大图像编码缓存
if (dbg.img_frame.src_img.empty())
return;
cv::Mat src_img = dbg.img_frame.src_img;
auto now = std::chrono::steady_clock::now();
const double min_interval_ms = 1000.0 / 30.0;
if (std::chrono::duration<double, std::milli>(now - last_show_time).count()
< min_interval_ms
&& auto_fps)
return;
last_show_time = now;
// 复制并转RGB
cv::Mat debug_img;
src_img.convertTo(debug_img, -1, 1, 0);
cv::cvtColor(debug_img, debug_img, cv::COLOR_BGR2RGB);
if (debug_img.empty())
return;
// 绘制内容
drawAutoGuidanceDebugContent(debug_img, dbg);
// 编码为 JPG
std::vector<uchar> buf;
cv::imencode(".jpg", debug_img, buf);
size_t img_size = buf.size();
if (img_size > shm_max_size) {
std::cerr << "[drawDebugOverlayWrite] 图像过大: " << img_size << " bytes\n";
return;
}
// 创建/打开共享内存
int fd = shm_open(shm_name, O_CREAT | O_RDWR, 0666);
if (fd == -1) {
perror("shm_open failed");
return;
}
// 设置共享内存大小
if (ftruncate(fd, shm_max_size) == -1) {
perror("ftruncate failed");
close(fd);
return;
}
// 映射共享内存
void* ptr = mmap(nullptr, shm_max_size, PROT_WRITE, MAP_SHARED, fd, 0);
if (ptr == MAP_FAILED) {
perror("mmap failed");
close(fd);
return;
}
// 写入图像数据
uint32_t size = static_cast<uint32_t>(img_size);
std::memcpy(ptr, &size, 4); // 前4字节写入长度
std::memcpy(static_cast<char*>(ptr) + 4, buf.data(), img_size);
// 关闭映射和文件描述符
munmap(ptr, shm_max_size);
close(fd);
}
void drawDebugOverlayShow(const AutoGuidanceDebug& dbg, bool auto_fps) {
static auto last_show_time = std::chrono::steady_clock::now();
if (dbg.img_frame.src_img.empty())
return;
cv::Mat src_img = dbg.img_frame.src_img;
auto now = std::chrono::steady_clock::now();
const double min_interval_ms = 1000.0 / 30.0;
if (std::chrono::duration<double, std::milli>(now - last_show_time).count()
< min_interval_ms
&& auto_fps)
return;
last_show_time = now;
// 图像构造
cv::Mat debug_img;
src_img.convertTo(debug_img, -1, 1, 0);
cv::cvtColor(debug_img, debug_img, cv::COLOR_BGR2RGB);
if (debug_img.empty())
return;
// 封装后的绘图函数
drawAutoGuidanceDebugContent(debug_img, dbg);
cv::imshow("debug_armor", debug_img);
cv::waitKey(1);
}
void debuglog(const GuidanceTarget& target) {
static bool first_log = true;
static std::chrono::steady_clock::time_point start_time;
static DebugLogs log;
if (first_log) {
start_time = std::chrono::steady_clock::now();
first_log = false;
}
auto now = std::chrono::steady_clock::now();
double t = std::chrono::duration<double>(now - start_time).count();
log.time_log.push_back(t);
double cx_norm = target.center().x / target.image_size_.width * 2.0 - 1.0;
log.cx_log.push_back(cx_norm);
auto trim = [](std::vector<double>& v) {
if (v.size() > 1000)
v.erase(v.begin());
};
trim(log.time_log);
trim(log.cx_log);
nlohmann::json j;
{
j["time"] = log.time_log;
j["yaw"] = log.cx_log;
}
std::ofstream file("/dev/shm/cmd_log.json");
if (file.is_open()) {
file << j.dump();
}
}
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,24 @@
#pragma once
#include "tasks/auto_guidance/guidance_tracker/guidance_target.hpp"
#include "wust_vl/video/icamera.hpp"
namespace wust_vision {
namespace auto_guidance {
struct AutoGuidanceDebug {
wust_vl::video::ImageFrame img_frame;
double latency_ms;
GuidanceTarget target;
GreenLights lights;
cv::Mat mask;
};
struct DebugLogs {
std::vector<double> time_log;
std::vector<double> cx_log;
};
void debuglog(const GuidanceTarget& target);
void drawDebugOverlayShm(const AutoGuidanceDebug& dbg, bool auto_fps);
void drawDebugOverlayWrite(const AutoGuidanceDebug& dbg, bool auto_fps);
void drawDebugOverlayShow(const AutoGuidanceDebug& dbg, bool auto_fps);
void drawAutoGuidanceDebugContent(cv::Mat& debug_img, const AutoGuidanceDebug& dbg);
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,18 @@
#pragma once
#include "tasks/auto_guidance/type.hpp"
#include "tasks/type_common.hpp"
namespace wust_vision {
namespace auto_guidance {
class detector_base {
public:
virtual ~detector_base() = default;
virtual void pushInput(CommonFrame& frame) = 0;
using DetectorCallback =
std::function<void(const std::vector<GreenLight>&, const CommonFrame&)>;
virtual void setCallback(DetectorCallback cb) = 0;
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,28 @@
#pragma once
#include "tasks/auto_guidance/guidance_detector/opencv/guidance_detector_opencv.hpp"
#ifdef USE_OPENVINO
#include "openvino/guidance_detector_openvino.hpp"
#endif
namespace wust_vision {
namespace auto_guidance {
class DetectorFactory {
public:
static std::unique_ptr<detector_base>
createDetector(const std::string& backend, const YAML::Node& config, bool debug) {
#if defined(USE_OPENVINO)
if (backend == "openvino") {
return std::make_unique<GuidanceDetectorOpenVino>(config);
}
#endif
if (backend == "opencv") {
return std::make_unique<GuidanceDetectorOpenCV>(config, debug);
}
throw std::runtime_error("Unsupported detector backend (or not compiled): " + backend);
}
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,100 @@
#include "green_light_infer.hpp"
namespace wust_vision {
namespace auto_guidance {
GreenLightInfer::GreenLightInfer(const Params& params) {
params_ = params;
}
std::vector<GreenLight> GreenLightInfer::postProcess(
const cv::Mat& output_buffer,
const Eigen::Matrix<float, 3, 3>& transform_matrix
) {
std::vector<GreenLight> Lights;
std::vector<float> confidences;
std::vector<cv::Rect> boxes;
const int num_boxes = output_buffer.rows;
const int attr = output_buffer.cols;
for (int i = 0; i < num_boxes; ++i) {
float confidence = output_buffer.at<float>(i, 4);
if (confidence < params_.conf_threshold)
continue;
cv::Mat class_scores = output_buffer.row(i).colRange(5, 5 + 9);
cv::Mat color_scores = output_buffer.row(i).colRange(5 + 9, 5 + 9 + 4);
double maxClassConfidence;
cv::Point classIdPoint;
cv::minMaxLoc(class_scores, nullptr, &maxClassConfidence, nullptr, &classIdPoint);
if (maxClassConfidence < params_.conf_threshold)
continue;
if (classIdPoint.x != 8)
continue;
float cx = output_buffer.at<float>(i, 0);
float cy = output_buffer.at<float>(i, 1);
float w = output_buffer.at<float>(i, 2);
float h = output_buffer.at<float>(i, 3);
// === coordinate transform ===
Eigen::Vector3f pt(cx, cy, 1.0f);
Eigen::Vector3f pt_trans = transform_matrix * pt;
float cx_t = pt_trans(0);
float cy_t = pt_trans(1);
// compute scale for bbox
float scale_x = std::sqrt(transform_matrix.row(0).head<2>().squaredNorm());
float scale_y = std::sqrt(transform_matrix.row(1).head<2>().squaredNorm());
float w_t = w * scale_x;
float h_t = h * scale_y;
cv::Rect2d bbox(cx_t - w_t / 2.0f, cy_t - h_t / 2.0f, w_t, h_t);
GreenLight light;
light.id = classIdPoint.x;
light.score = confidence;
light.center_point = cv::Point2f(cx_t, cy_t);
light.box = bbox;
Lights.emplace_back(light);
confidences.emplace_back(confidence);
boxes.emplace_back(bbox);
}
std::vector<int> nms_result;
cv::dnn::NMSBoxes(
boxes,
confidences,
params_.conf_threshold,
params_.nms_threshold,
nms_result,
0.5f,
params_.top_k
);
auto IoU = [](const cv::Rect2d& a, const cv::Rect2d& b) {
double inter = (a & b).area();
double uni = a.area() + b.area() - inter;
return inter / uni;
};
std::vector<GreenLight> final_result;
for (int i = 0; i < nms_result.size(); i++) {
bool keep = true;
for (int j = 0; j < final_result.size(); j++) {
if (IoU(final_result[j].box, Lights[nms_result[i]].box) > 0.3) {
keep = false;
break;
}
}
if (keep)
final_result.push_back(Lights[nms_result[i]]);
}
return final_result;
}
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,35 @@
#pragma once
#include "tasks/auto_guidance/type.hpp"
namespace wust_vision {
namespace auto_guidance {
class GreenLightInfer {
public:
using GreenLightInferPtr = std::unique_ptr<GreenLightInfer>;
struct Params {
int input_w;
int input_h;
float conf_threshold;
float nms_threshold;
int top_k;
bool use_norm;
} params_;
GreenLightInfer(const Params& params);
static inline GreenLightInferPtr makeGreenLightInfer(const Params& params) {
return std::make_unique<GreenLightInfer>(params);
}
std::vector<GreenLight> postProcess(
const cv::Mat& output_buffer,
const Eigen::Matrix<float, 3, 3>& transform_matrix
);
int getInputW() {
return params_.input_w;
}
int getInputH() {
return params_.input_h;
}
bool getUseNorm() {
return params_.use_norm;
}
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,162 @@
#include "guidance_detector_opencv.hpp"
bool initializing = true;
int lowH = 35, highH = 85;
int lowS = 50, highS = 255;
int lowV = 80, highV = 255;
static void onTrackbar(int, void*) {
if (initializing)
return; // 初始化阶段不更新
lowH = cv::getTrackbarPos("LowH", "mask");
highH = cv::getTrackbarPos("HighH", "mask");
lowS = cv::getTrackbarPos("LowS", "mask");
highS = cv::getTrackbarPos("HighS", "mask");
lowV = cv::getTrackbarPos("LowV", "mask");
highV = cv::getTrackbarPos("HighV", "mask");
}
void initGUI() {
cv::namedWindow("mask", cv::WINDOW_NORMAL); // 允许调整大小
cv::resizeWindow("mask", 400, 600); // 设置初始窗口大小
cv::createTrackbar("LowH", "mask", nullptr, 179, onTrackbar);
cv::createTrackbar("HighH", "mask", nullptr, 179, onTrackbar);
cv::createTrackbar("LowS", "mask", nullptr, 255, onTrackbar);
cv::createTrackbar("HighS", "mask", nullptr, 255, onTrackbar);
cv::createTrackbar("LowV", "mask", nullptr, 255, onTrackbar);
cv::createTrackbar("HighV", "mask", nullptr, 255, onTrackbar);
cv::setTrackbarPos("LowH", "mask", lowH);
cv::setTrackbarPos("HighH", "mask", highH);
cv::setTrackbarPos("LowS", "mask", lowS);
cv::setTrackbarPos("HighS", "mask", highS);
cv::setTrackbarPos("LowV", "mask", lowV);
cv::setTrackbarPos("HighV", "mask", highV);
initializing = false;
}
namespace wust_vision {
namespace auto_guidance {
struct GuidanceDetectorOpenCV::Impl {
public:
Impl(const YAML::Node& config_gobal, bool debug) {
debug_ = debug;
const auto config = config_gobal["opencv"];
lowH = config["HSV"]["lowH"].as<int>();
highH = config["HSV"]["highH"].as<int>();
lowS = config["HSV"]["lowS"].as<int>();
highS = config["HSV"]["highS"].as<int>();
highV = config["HSV"]["highV"].as<int>();
lowV = config["HSV"]["lowV"].as<int>();
max_area_ = config["contours"]["max_area"].as<double>();
min_area_ = config["contours"]["min_area"].as<double>();
min_aspect_ratio = config["contours"]["min_aspect_ratio"].as<double>();
min_fill_ratio_ = config["contours"]["min_fill_ratio"].as<double>();
use_gui_ = config["gui"].as<bool>();
if (debug_ && use_gui_) {
initGUI();
}
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
void processCallback(const CommonFrame& frame) {
std::vector<GreenLight> lights;
cv::Mat img = frame.img_frame.src_img.clone();
cv::Mat hsv;
cv::cvtColor(img, hsv, cv::COLOR_BGR2HSV);
cv::Scalar lower_green(lowH, lowS, lowV);
cv::Scalar upper_green(highH, highS, highV);
cv::Mat mask;
cv::inRange(hsv, lower_green, upper_green, mask);
cv::Mat kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(5, 5));
cv::morphologyEx(mask, mask, cv::MORPH_OPEN, kernel);
cv::morphologyEx(mask, mask, cv::MORPH_CLOSE, kernel);
std::vector<std::vector<cv::Point>> contours;
cv::findContours(mask, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
std::vector<int> valid_indices;
for (size_t i = 0; i < contours.size(); i++) {
const double area = cv::contourArea(contours[i]);
const double perimeter = cv::arcLength(contours[i], true);
if (perimeter == 0)
continue;
const double circularity = 4 * CV_PI * area / (perimeter * perimeter);
const cv::RotatedRect rRect = cv::minAreaRect(contours[i]);
const double width = rRect.size.width;
const double height = rRect.size.height;
if (width <= 0 || height <= 0)
continue;
const double rect_area = width * height;
const double fill_ratio = area / rect_area;
const double aspect_ratio = std::min(width, height) / std::max(width, height);
if (area > min_area_ && area < max_area_ && fill_ratio > min_fill_ratio_
&& aspect_ratio > min_aspect_ratio)
{
cv::Point2f center;
float radius;
cv::minEnclosingCircle(contours[i], center, radius);
GreenLight gl;
gl.center_point = center;
gl.box = cv::boundingRect(contours[i]);
gl.score = circularity;
gl.radius = radius;
lights.push_back(gl);
}
}
static auto last = std::chrono::steady_clock::now();
const auto now = std::chrono::steady_clock::now();
const double dt = std::chrono::duration<double, std::milli>(now - last).count();
if (debug_ && dt > 33.3 && use_gui_) { // 30Hz 刷新
cv::imshow("mask", mask);
cv::waitKey(1); // 非阻塞
last = now;
}
if (infer_callback_) {
infer_callback_(lights, frame);
}
}
void pushInput(CommonFrame& frame) {
frame.id = current_id_++;
processCallback(frame);
}
DetectorCallback infer_callback_;
int current_id_ = 0;
double min_area_ = 100;
double max_area_ = 10000;
double min_fill_ratio_ = 0.5;
double min_aspect_ratio = 0.7;
bool debug_ = false;
bool use_gui_ = false;
};
GuidanceDetectorOpenCV::GuidanceDetectorOpenCV(const YAML::Node& config_gobal, bool debug) {
_impl = std::make_unique<Impl>(config_gobal, debug);
}
GuidanceDetectorOpenCV::~GuidanceDetectorOpenCV() {
_impl.reset();
}
void GuidanceDetectorOpenCV::pushInput(CommonFrame& frame) {
_impl->pushInput(frame);
}
void GuidanceDetectorOpenCV::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,18 @@
#pragma once
#include "tasks/auto_guidance/guidance_detector/detector_base.hpp"
namespace wust_vision {
namespace auto_guidance {
class GuidanceDetectorOpenCV: public detector_base {
public:
GuidanceDetectorOpenCV(const YAML::Node& config, bool debug);
~GuidanceDetectorOpenCV();
void pushInput(CommonFrame& frame) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,116 @@
#ifdef USE_OPENVINO
#include "guidance_detector_openvino.hpp"
#include "tasks/auto_guidance/guidance_detector/green_light_infer.hpp"
#include "tasks/utils/utils.hpp"
#include "wust_vl/ml_net/openvino/openvino_net.hpp"
namespace wust_vision {
namespace auto_guidance {
struct GuidanceDetectorOpenVino::Impl {
public:
Impl(const YAML::Node& config_gobal) {
auto config = config_gobal["openvino"];
std::string model_path = utils::expandEnv(config["model_path"].as<std::string>());
std::string device_name = config["device_type"].as<std::string>();
int top_k = config["top_k"].as<int>();
float nms_threshold = config["nms_threshold"].as<float>();
float conf_threshold = config["conf_threshold"].as<float>();
green_light_infer_ = GreenLightInfer::makeGreenLightInfer(GreenLightInfer::Params {
.input_w = 640,
.input_h = 384,
.conf_threshold = conf_threshold,
.nms_threshold = nms_threshold,
.top_k = top_k,
.use_norm = true });
openvino_net_ = std::make_unique<wust_vl::ml_net::OpenvinoNet>();
auto ppp_init_fun = [this](ov::preprocess::PrePostProcessor& ppp) {
ppp.input()
.tensor()
.set_element_type(ov::element::u8)
.set_layout("NHWC")
.set_color_format(ov::preprocess::ColorFormat::BGR);
ppp.input()
.preprocess()
.convert_element_type(ov::element::f32)
.convert_color(ov::preprocess::ColorFormat::RGB)
.scale(255.f);
ppp.input().model().set_layout("NCHW");
ppp.output(0).tensor().set_element_type(ov::element::f32);
ppp.output(1).tensor().set_element_type(ov::element::f32);
ppp.output(2).tensor().set_element_type(ov::element::f32);
};
wust_vl::ml_net::OpenvinoNet::Params params;
params.model_path = model_path;
params.device_name = device_name;
params.mode = config["use_throughputmode"].as<bool>()
? ov::hint::PerformanceMode::THROUGHPUT
: ov::hint::PerformanceMode::LATENCY;
openvino_net_->init(params, ppp_init_fun);
}
~Impl() {
openvino_net_.reset();
green_light_infer_.reset();
}
void setCallback(DetectorCallback callback) {
infer_callback_ = callback;
}
void processCallback(const CommonFrame& frame) const {
Eigen::Matrix3f transform_matrix;
const cv::Mat resized_img = utils::letterbox(
frame.img_frame.src_img,
transform_matrix,
green_light_infer_->getInputW(),
green_light_infer_->getInputH()
);
const auto input_info = openvino_net_->getInputInfo();
const auto input_tensor =
ov::Tensor(input_info.first, input_info.second, resized_img.data);
auto infer_request = openvino_net_->createInferRequest();
infer_request.set_input_tensor(input_tensor);
infer_request.infer();
const auto output = infer_request.get_output_tensor(0);
// Process output data
const auto output_shape = output.get_shape();
const float* ptr = output.data<const float>();
cv::Mat
output_buffer(output_shape[1], output_shape[2], CV_32F, const_cast<float*>(ptr));
const auto objs_result =
green_light_infer_->postProcess(output_buffer, transform_matrix);
if (infer_callback_) {
infer_callback_(objs_result, frame);
}
}
void pushInput(CommonFrame& frame) {
frame.id = current_id_++;
processCallback(frame);
}
std::unique_ptr<wust_vl::ml_net::OpenvinoNet> openvino_net_;
std::unique_ptr<GreenLightInfer> green_light_infer_;
DetectorCallback infer_callback_;
int current_id_ = 0;
};
GuidanceDetectorOpenVino::GuidanceDetectorOpenVino(const YAML::Node& config_gobal) {
_impl = std::make_unique<Impl>(config_gobal);
}
GuidanceDetectorOpenVino::~GuidanceDetectorOpenVino() {
_impl.reset();
}
void GuidanceDetectorOpenVino::pushInput(CommonFrame& frame) {
_impl->pushInput(frame);
}
void GuidanceDetectorOpenVino::setCallback(DetectorCallback callback) {
_impl->setCallback(callback);
}
} // namespace auto_guidance
} // namespace wust_vision
#endif

View File

@@ -0,0 +1,19 @@
#pragma once
#include "tasks/auto_guidance/guidance_detector/detector_base.hpp"
namespace wust_vision {
namespace auto_guidance {
class GuidanceDetectorOpenVino: public detector_base {
public:
GuidanceDetectorOpenVino(const YAML::Node& config);
~GuidanceDetectorOpenVino();
void pushInput(CommonFrame& frame) override;
void setCallback(DetectorCallback callback) override;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,143 @@
#include "guidance_target.hpp"
namespace wust_vision {
namespace auto_guidance {
GuidanceTarget::GuidanceTarget() {
target_state_ = Eigen::VectorXd::Zero(imgbox_model::X_N);
}
GuidanceTarget::GuidanceTarget(const GreenLight& light, TargetConfig target_config) {
target_config_ = target_config;
auto yfv2 = imgbox_model::Predict(0.01);
auto yhv2 = imgbox_model::Measure();
auto yu_qv2 = [this]() { return computeProcessNoise(0.01); };
auto yu_rv2 = [this](const Eigen::Matrix<double, imgbox_model::Z_N, 1>& z) {
return this->computeMeasurementCovariance(z);
};
Eigen::DiagonalMatrix<double, imgbox_model::X_N> p0;
p0.diagonal() << 1000, 1000, 1000, 1000, 64000, 64000, 64000, 64000;
esekf_ = imgbox_model::BBox8ESEKF(yfv2, yhv2, yu_qv2, yu_rv2, p0);
esekf_.setResidualFunc([this](
const Eigen::Matrix<double, imgbox_model::Z_N, 1>& z_pred,
const Eigen::Matrix<double, imgbox_model::Z_N, 1>& z
) {
Eigen::Matrix<double, imgbox_model::Z_N, 1> r = z - z_pred;
return r;
});
esekf_.setIterationNum(target_config_.iter_num);
esekf_.setInjectFunc([this](
const Eigen::Matrix<double, imgbox_model::X_N, 1>& delta,
Eigen::Matrix<double, imgbox_model::X_N, 1>& nominal
) {
for (int i = 0; i < imgbox_model::X_N; i++) {
nominal[i] += delta[i];
}
});
double cx = light.center_point.x;
double cy = light.center_point.y;
double w = light.box.width;
double h = light.box.height;
target_state_ << cx, 0, cy, 0, w, 0, h, 0;
esekf_.setState(target_state_);
last_t_ = light.timestamp;
position_ = light.position;
timestamp_ = light.timestamp;
image_size_ = light.image_size;
is_inited_ = true;
}
Eigen::Matrix<double, imgbox_model::Z_N, imgbox_model::Z_N>
GuidanceTarget::computeMeasurementCovariance(
const Eigen::Matrix<double, imgbox_model::Z_N, 1>& z
) const {
Eigen::Matrix<double, imgbox_model::Z_N, imgbox_model::Z_N> r;
// clang-format off
r <<target_config_.xy_r, 0, 0, 0,
0, target_config_.xy_r , 0, 0,
0, 0, target_config_.wh_r, 0,
0, 0, 0,target_config_.wh_r;
// clang-format on
return r;
}
Eigen::Matrix<double, imgbox_model::X_N, imgbox_model::X_N>
GuidanceTarget::computeProcessNoise(double dt) const {
Eigen::Matrix<double, imgbox_model::X_N, imgbox_model::X_N> q;
double t = dt;
double q_pp = pow(t, 4) / 4.0 * target_config_.q_xy;
double q_pv = pow(t, 3) / 2.0 * target_config_.q_xy;
double q_vv = pow(t, 2) * target_config_.q_xy;
double q_ss = pow(t, 4) / 4.0 * target_config_.q_wh;
double q_sv = pow(t, 3) / 2.0 * target_config_.q_wh;
double q_vvs = pow(t, 2) * target_config_.q_wh;
// clang-format off
// cx vx cy vy w vw h vh
q << q_pp, q_pv, 0, 0, 0, 0, 0, 0,
q_pv, q_vv, 0, 0, 0, 0, 0, 0,
0, 0, q_pp, q_pv, 0, 0, 0, 0,
0, 0, q_pv, q_vv, 0, 0, 0, 0,
0, 0, 0, 0, q_ss, q_sv, 0, 0,
0, 0, 0, 0, q_sv, q_vvs, 0, 0,
0, 0, 0, 0, 0, 0, q_ss, q_sv,
0, 0, 0, 0, 0, 0, q_sv, q_vvs;
// clang-format on
return q;
}
void GuidanceTarget::predict(std::chrono::steady_clock::time_point t) {
double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
predict(dt);
last_t_ = t;
}
void GuidanceTarget::predict(double dt) {
dt_ = dt;
esekf_.setPredictFunc(imgbox_model::Predict { dt });
auto yu_qv2 = [dt, this]() { return computeProcessNoise(dt); };
esekf_.setUpdateQ(yu_qv2);
target_state_ = esekf_.predict();
}
bool GuidanceTarget::update(const GreenLights& lights) {
auto ls = lights.lights;
timestamp_ = lights.timestamp;
auto yu_rv2 = [this](const Eigen::Matrix<double, imgbox_model::Z_N, 1>& z) {
return this->computeMeasurementCovariance(z);
};
esekf_.setUpdateR(yu_rv2);
int best_id = -1;
double min_error = std::numeric_limits<double>::max();
for (int i = 0; i < ls.size(); i++) {
double centor_error = cv::norm(ls[i].center_point - center());
double pos_error = (ls[i].position - position_).norm();
if (centor_error < min_error && pos_error < target_config_.max_dis_diff) {
min_error = centor_error;
best_id = i;
}
}
if (best_id == -1) {
return false;
}
measurement_ = Eigen::Vector4d(
ls[best_id].center_point.x,
ls[best_id].center_point.y,
ls[best_id].box.width,
ls[best_id].box.height
);
esekf_.setMeasureFunc(imgbox_model::Measure());
target_state_ = esekf_.update(measurement_);
position_ = ls[best_id].position;
image_size_ = ls[best_id].image_size;
last_t_ = timestamp_;
return true;
}
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,89 @@
#pragma once
#include "motion_models/imgbox_model.hpp"
#include "tasks/auto_guidance/type.hpp"
#include <wust_vl/common/utils/timer.hpp>
#include <yaml-cpp/yaml.h>
namespace wust_vision {
namespace auto_guidance {
struct TargetConfig {
void load(const YAML::Node& config) {
xy_r = config["xy_r"].as<double>();
wh_r = config["wh_r"].as<double>();
q_xy = config["q_xy"].as<double>();
q_wh = config["q_wh"].as<double>();
iter_num = config["iter_num"].as<int>();
max_dis_diff = config["max_dis_diff"].as<double>();
}
double xy_r = 0.05;
double wh_r = 0.05;
double q_xy = 10;
double q_wh = 10;
int iter_num = 2;
double max_dis_diff = 2.0;
};
class GuidanceTarget {
public:
GuidanceTarget();
GuidanceTarget(const GreenLight& light, TargetConfig target_config);
GuidanceTarget& operator=(const GuidanceTarget&) = default;
bool update(const GreenLights& lights);
void predict(std::chrono::steady_clock::time_point t);
void predict(double dt);
Eigen::Matrix<double, imgbox_model::Z_N, imgbox_model::Z_N>
computeMeasurementCovariance(const Eigen::Matrix<double, imgbox_model::Z_N, 1>& z) const;
Eigen::Matrix<double, imgbox_model::X_N, imgbox_model::X_N> computeProcessNoise(double dt
) const;
std::chrono::steady_clock::time_point last_t_;
std::chrono::steady_clock::time_point timestamp_;
double dt_;
cv::Size2d image_size_;
imgbox_model::BBox8ESEKF esekf_;
Eigen::Matrix<double, imgbox_model::Z_N, 1> measurement_ =
Eigen::Matrix<double, imgbox_model::Z_N, 1>::Zero();
Eigen::Matrix<double, imgbox_model::X_N, 1> target_state_ =
Eigen::Matrix<double, imgbox_model::X_N, 1>::Zero();
Eigen::Vector3d position_;
TargetConfig target_config_;
bool is_inited_ = false;
bool is_tracking_ = false;
bool checkappear() {
return is_tracking_
&& wust_vl::common::utils::time_utils::durationSec(
timestamp_,
wust_vl::common::utils::time_utils::now()
)
< 3.0;
}
cv::Point2d center() const {
return cv::Point2d(target_state_(0), target_state_(2));
}
cv::Rect2d box() const {
return cv::Rect2d(
target_state_(0) - target_state_(4) / 2,
target_state_(2) - target_state_(6) / 2,
target_state_(4),
target_state_(6)
);
}
void draw(cv::Mat& img) const {
cv::rectangle(img, box(), cv::Scalar(255, 50, 0), 2);
cv::circle(img, center(), 3, cv::Scalar(255, 255, 255), -1);
cv::line(
img,
cv::Point(center().x, center().y),
cv::Point(img.cols / 2.0, center().y),
cv::Scalar(0, 0, 255),
2
);
cv::line(
img,
cv::Point2f(img.cols / 2.0, 0),
cv::Point2f(img.cols / 2.0, img.rows),
cv::Scalar(255, 255, 255),
2
);
}
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,109 @@
#include "guidance_tracker.hpp"
namespace wust_vision {
namespace auto_guidance {
struct GuidanceTracker::Impl {
public:
Impl(const YAML::Node& config) {
target_config_.load(config["target"]);
tracking_thres_ = config["tracking_thres"].as<int>(5);
lost_dt_ = config["lost_time_thres"].as<double>();
}
GuidanceTarget track(const GreenLights& lights) {
double dt = std::chrono::duration<double>(lights.timestamp - last_time_).count();
last_time_ = lights.timestamp;
lost_thres_ = std::abs(static_cast<int>(lost_dt_ / dt));
bool found;
if (tracker_state == LOST) {
found = initTarget(lights);
} else {
found = updateTarget(lights);
}
updateFsm(found);
return guidance_target_;
}
void updateFsm(bool found) {
if (tracker_state == DETECTING) {
if (found) {
detect_count_++;
if (detect_count_ > tracking_thres_) {
detect_count_ = 0;
tracker_state = TRACKING;
}
} else {
detect_count_ = 0;
tracker_state = LOST;
}
} else if (tracker_state == TRACKING) {
if (!found) {
tracker_state = TEMP_LOST;
lost_count_++;
}
} else if (tracker_state == TEMP_LOST) {
if (!found) {
lost_count_++;
if (lost_count_ > lost_thres_) {
lost_count_ = 0;
tracker_state = LOST;
}
} else {
tracker_state = TRACKING;
lost_count_ = 0;
}
}
if (tracker_state == LOST || tracker_state == DETECTING) {
guidance_target_.is_tracking_ = false;
} else {
guidance_target_.is_tracking_ = true;
}
}
bool initTarget(const GreenLights& lights) {
int best_id = -1;
double max_score = -1e9;
for (int i = 0; i < lights.lights.size(); i++) {
if (lights.lights[i].score > max_score) {
max_score = lights.lights[i].score;
best_id = i;
}
}
if (best_id == -1) {
return false;
}
tracker_state = DETECTING;
guidance_target_ = GuidanceTarget(lights.lights[best_id], target_config_);
return true;
}
bool updateTarget(const GreenLights& lights) {
guidance_target_.predict(lights.timestamp);
return guidance_target_.update(lights);
}
enum State {
LOST,
DETECTING,
TRACKING,
TEMP_LOST,
} tracker_state = LOST;
GuidanceTarget guidance_target_;
int tracking_thres_;
int lost_thres_;
int detect_count_ = 0;
int lost_count_ = 0;
double lost_dt_;
std::chrono::steady_clock::time_point last_time_;
TargetConfig target_config_;
};
GuidanceTracker::GuidanceTracker(const YAML::Node& config) {
_impl = std::make_unique<Impl>(config);
}
GuidanceTracker::~GuidanceTracker() {
_impl.reset();
}
GuidanceTarget GuidanceTracker::track(const GreenLights& lights) {
return _impl->track(lights);
}
std::chrono::steady_clock::time_point GuidanceTracker::getLastTime() const {
return _impl->last_time_;
}
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,21 @@
#pragma once
#include "tasks/auto_guidance/guidance_tracker/guidance_target.hpp"
namespace wust_vision {
namespace auto_guidance {
class GuidanceTracker {
public:
using Ptr = std::unique_ptr<GuidanceTracker>;
GuidanceTracker(const YAML::Node& config);
~GuidanceTracker();
static inline Ptr create(const YAML::Node& config) {
return std::make_unique<GuidanceTracker>(config);
}
GuidanceTarget track(const GreenLights& lights);
std::chrono::steady_clock::time_point getLastTime() const;
private:
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,45 @@
#pragma once
#include "KalmanHyLib/kalman_hybird_lib.hpp"
#include <ceres/ceres.h>
namespace imgbox_model {
static constexpr int X_N = 8; // cx vx cy vy w vw h vh
static constexpr int Z_N = 4; // measured cx cy w h
// ========================== Predict Model ==========================
struct Predict {
Predict() = default;
explicit Predict(double dt): dt_(dt) {}
template<typename T>
void operator()(const T x0[X_N], T x1[X_N]) {
for (int i = 0; i < X_N; i++) {
x1[i] = x0[i];
}
x1[0] += x0[1] * dt_; // cx
x1[2] += x0[3] * dt_; // cy
x1[4] += x0[5] * dt_; // w
x1[6] += x0[7] * dt_; // h
}
double dt_;
};
// ========================== Measurement Model ==========================
struct Measure {
Measure() = default;
template<typename T>
void operator()(const T x[X_N], T z[Z_N]) const {
z[0] = x[0]; // cx
z[1] = x[2]; // cy
z[2] = x[4]; // w
z[3] = x[6]; // h
}
};
using BBox8EKF = kalman_hybird_lib::ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
using BBox8ESEKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
} // namespace imgbox_model

View File

@@ -0,0 +1,107 @@
#pragma once
#include "Eigen/Dense"
#include "opencv2/opencv.hpp"
namespace wust_vision {
namespace auto_guidance {
struct GreenLight {
int id = -1;
double score = 0.;
cv::Rect2d box; // bounding box in pixel coordinates
cv::Point2d center_point; // center in pixel coordinates
double radius;
Eigen::Vector3d position;
std::chrono::steady_clock::time_point timestamp;
cv::Size2d image_size;
// PnP 估计位姿
bool solvePnP(const cv::Mat& K, const cv::Mat& distCoeffs) {
constexpr float half_w = 0.07;
// 真实世界点,单位米
std::vector<cv::Point3f> objectPoints = { { -half_w, -half_w, 0.f },
{ half_w, -half_w, 0.f },
{ half_w, half_w, 0.f },
{ -half_w, half_w, 0.f } };
// 像素点
std::vector<cv::Point2f> imagePoints = {
cv::Point2f(box.x, box.y),
cv::Point2f(box.x + box.width, box.y),
cv::Point2f(box.x + box.width, box.y + box.height),
cv::Point2f(box.x, box.y + box.height)
};
cv::Mat rvec, tvec;
bool ok = cv::solvePnP(
objectPoints,
imagePoints,
K,
distCoeffs,
rvec,
tvec,
false,
cv::SOLVEPNP_ITERATIVE
);
if (!ok)
return false;
// 转换到 Eigen 向量
position = Eigen::Vector3d(tvec.at<double>(0), tvec.at<double>(1), tvec.at<double>(2));
return true;
}
};
struct GreenLights {
public:
std::vector<GreenLight> lights;
std::chrono::steady_clock::time_point timestamp;
int id;
void drawFront(cv::Mat& img) const {
for (const auto& light: lights) {
cv::rectangle(img, light.box, cv::Scalar(0, 255, 255), 2);
cv::circle(img, light.center_point, light.radius, cv::Scalar(0, 255, 0), 2);
cv::circle(img, light.center_point, 3, cv::Scalar(255, 0, 0), -1);
cv::putText(
img,
std::to_string(light.score),
light.center_point
+ cv::Point2d(light.box.width / 2.0, -light.box.height / 2.0),
cv::FONT_HERSHEY_SIMPLEX,
0.5,
cv::Scalar(255, 0, 0),
2
);
cv::putText(
img,
std::to_string(light.position.norm()),
light.center_point + cv::Point2d(light.box.width / 2.0, light.box.height / 2.0),
cv::FONT_HERSHEY_SIMPLEX,
0.5,
cv::Scalar(255, 255, 255),
2
);
}
}
void drawBack(cv::Mat& img) const {
for (const auto& light: lights) {
cv::line(
img,
cv::Point(light.center_point.x, light.center_point.y),
cv::Point(img.cols / 2.0, light.center_point.y),
cv::Scalar(0, 0, 255),
2
);
}
cv::line(
img,
cv::Point2f(img.cols / 2.0, 0),
cv::Point2f(img.cols / 2.0, img.rows),
cv::Scalar(255, 255, 255),
2
);
}
};
} // namespace auto_guidance
} // namespace wust_vision

View File

@@ -0,0 +1,344 @@
#include "3rdparty/angles.h"
//#ifdef USE_ROS2
#ifdef FUCK
#include "auto_sniper.hpp"
#include "ros2/tf.hpp"
#include <Eigen/src/Core/Matrix.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <open3d/utility/Eigen.h>
#include <optional>
#include <rclcpp/logger.hpp>
#include <rclcpp/logging.hpp>
#include <thread>
#include <vector>
#include <Eigen/Dense>
#include <nav_msgs/msg/odometry.hpp>
#include <open3d/Open3D.h>
#include <rclcpp/node.hpp>
#include <rclcpp/rclcpp.hpp>
#include "k1_solver.hpp"
#include "offset_helper.hpp"
#include "tasks/type_common.hpp"
#include "tasks/utils/config.hpp"
#include "voxel_map.hpp"
#include <sensor_msgs/msg/point_cloud2.hpp>
#include <sensor_msgs/point_cloud2_iterator.hpp>
#include <visualization_msgs/msg/marker.hpp>
#include <yaml-cpp/yaml.h>
namespace wust_vision::auto_sniper {
struct AutoSniper::Impl {
Impl(
rclcpp::Node& node,
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<Motion, 1024>> motion_buffer
) {
node_ = &node;
motion_buffer_ = motion_buffer;
tf_ = TF::create(*node_);
auto config = YAML::LoadFile(AUTO_SNIPER_CONFIG);
auto map_config = config["map"];
auto min_pos_v = map_config["min_pos"].as<std::vector<double>>();
auto min_pos = Eigen::Vector3d(min_pos_v[0], min_pos_v[1], min_pos_v[2]);
auto max_pos_v = map_config["max_pos"].as<std::vector<double>>();
auto max_pos = Eigen::Vector3d(max_pos_v[0], max_pos_v[1], max_pos_v[2]);
voxel_map_ = std::make_shared<SlidingVoxelMap<3, Cell>>(
map_config["voxel_size"].as<double>(),
min_pos,
max_pos,
true
);
auto solver_config = config["solver"];
vis_cloud_ = std::make_shared<open3d::geometry::PointCloud>();
k1_solver_ = K1BallisticSolver::create(
solver_config["k1"].as<double>(),
solver_config["g"].as<double>()
);
target_armor_z_ = solver_config["target_armor_z"].as<double>();
offset_helper_ = OffsetHelper::create(config["offset_helper"]);
pointcloud_sub_ = node_->create_subscription<sensor_msgs::msg::PointCloud2>(
"/cloud_registered",
rclcpp::SensorDataQoS(),
std::bind(&AutoSniper::Impl::pointCloudCallback, this, std::placeholders::_1)
);
odometry_sub_ = node_->create_subscription<nav_msgs::msg::Odometry>(
"/Odometry",
10,
std::bind(&AutoSniper::Impl::odomCallback, this, std::placeholders::_1)
);
traj_pub_ =
node_->create_publisher<visualization_msgs::msg::Marker>("bullet_trajectory", 10);
}
void start() {
if (run_flag_) {
return;
}
run_flag_ = true;
vis_thread_ = wust_vl::common::concurrency::MonitoredThread::create(
"AutoSniperVisualizer",
[this](wust_vl::common::concurrency::MonitoredThread::Ptr self) {
this->visualizeLoop(self);
}
);
wust_vl::common::concurrency::ThreadManager::instance().registerThread(vis_thread_);
}
void doDebug() {}
void pushInput(CommonFrame& frame) {}
wust_vl::common::concurrency::MonitoredThread::Ptr getThread() {
return vis_thread_;
}
GimbalCmd solve(double bullet_speed) noexcept {
GimbalCmd cmd;
if (!target_pos_in_map_.has_value()) {
cmd.appear = false;
return cmd;
}
Eigen::Vector3d target_pos_in_self = self_in_map_.inverse() * target_pos_in_map_.value();
target_pos_in_self.z() = target_armor_z_;
auto pitch = k1_solver_->solvePitch(target_pos_in_self, bullet_speed);
if (!pitch.has_value()) {
cmd.appear = false;
std::cout << "no pitch" << std::endl;
return cmd;
}
double yaw =
angles::normalize_angle(std::atan2(target_pos_in_self.y(), target_pos_in_self.x()));
auto traj = k1_solver_->computeTrajectory(
self_in_map_.translation(),
target_pos_in_map_.value(),
bullet_speed,
0.01
);
double yaw_deg = angles::to_degrees(yaw);
double pitch_deg = angles::to_degrees(pitch.value());
publishTrajectoryMarker(traj);
auto control_pitch = pitch_deg + offset_helper_->getPitchOffset(target_pos_in_self.norm());
auto control_yaw = yaw_deg + offset_helper_->getYawOffset(target_pos_in_self.norm());
if (auto last_att = motion_buffer_->get_last()) {
control_yaw += angles::to_degrees(last_att->data.yaw);
}
cmd.appear = true;
cmd.target_pitch = control_pitch;
cmd.target_yaw = control_yaw;
cmd.appear = true;
cmd.yaw = control_yaw;
cmd.pitch = control_pitch;
cmd.distance = target_pos_in_self.norm();
cmd.enable_pitch_diff = 0.5;
cmd.enable_yaw_diff = 0.5;
static int count = 0;
count++;
if (count % 100 == 0) {
std::cout << cmd.yaw << " " << cmd.pitch << std::endl;
}
return cmd;
}
void publishTrajectoryMarker(const std::vector<Eigen::Vector3d>& traj) {
static auto last_pub_time = std::chrono::steady_clock::now();
auto now = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(now - last_pub_time);
if (elapsed.count() < 33) {
return;
}
last_pub_time = now;
if (traj.empty())
return;
visualization_msgs::msg::Marker marker;
marker.header.frame_id = target_frame_;
marker.header.stamp = node_->now();
marker.ns = "bullet";
marker.id = 0;
marker.type = visualization_msgs::msg::Marker::LINE_STRIP;
marker.action = visualization_msgs::msg::Marker::ADD;
marker.scale.x = 0.1;
marker.color.r = 0.0;
marker.color.g = 1.0;
marker.color.b = 1.0;
marker.color.a = 1.0;
marker.lifetime = rclcpp::Duration::from_seconds(1.0);
for (auto& p: traj) {
geometry_msgs::msg::Point pt;
pt.x = p.x();
pt.y = p.y();
pt.z = p.z();
marker.points.push_back(pt);
}
traj_pub_->publish(marker);
}
void odomCallback(const nav_msgs::msg::Odometry::SharedPtr msg) {
auto T = tf_->getTransform(target_frame_, "gimbal_yaw", msg->header.stamp);
if (!T.has_value()) {
return;
}
self_in_map_ = T.value().cast<double>();
// Eigen::Isometry3f
// Eigen::Vector4f p(
// msg->pose.pose.position.x,
// msg->pose.pose.position.y,
// msg->pose.pose.position.z,
// 1.0f
// );
// auto p_target = T.value() * p;
// self_pos_ = Eigen::Vector3d(p_target.x(), p_target.y(), p_target.z());
}
void visualizeLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) {
while (!self->isAlive()) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
auto vis_cloud = std::make_shared<open3d::geometry::PointCloud>();
std::atomic<bool> picking = false;
auto future = std::async(std::launch::async, [&, self]() {
while (self->isAlive() && run_flag_) {
picking = true;
pickBlocking(vis_cloud);
picking = false;
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
});
while (self->isAlive() && run_flag_) {
self->heartbeat();
if (!picking) {
std::lock_guard<std::mutex> lock(cloud_mutex_);
*vis_cloud = *vis_cloud_;
}
std::this_thread::sleep_for(std::chrono::milliseconds(30));
}
if (future.valid()) {
future.wait();
}
}
void pickBlocking(std::shared_ptr<open3d::geometry::PointCloud> pcd) {
open3d::visualization::VisualizerWithEditing vis;
vis.CreateVisualizerWindow("Pick points (close window when done)", 1280, 720);
vis.AddGeometry(pcd);
vis.Run();
auto picked = vis.GetPickedPoints();
if (!picked.empty()) {
Eigen::Vector3d picked_pos = pcd->points_[picked.back()];
target_pos_in_map_ = Eigen::Vector3d(picked_pos);
RCLCPP_INFO_STREAM(
rclcpp::get_logger("awm"),
" Target pos: " << picked_pos.transpose()
<< " self pos: " << self_in_map_.translation().transpose()
);
}
vis.DestroyVisualizerWindow();
}
void pointCloudCallback(const sensor_msgs::msg::PointCloud2::SharedPtr msg) {
auto T = tf_->getTransform(target_frame_, msg->header.frame_id, msg->header.stamp);
if (!T.has_value()) {
return;
}
const size_t size = msg->width * msg->height;
sensor_msgs::PointCloud2ConstIterator<float> iter_x(*msg, "x");
sensor_msgs::PointCloud2ConstIterator<float> iter_y(*msg, "y");
sensor_msgs::PointCloud2ConstIterator<float> iter_z(*msg, "z");
std::vector<Eigen::Vector3d> new_points;
new_points.reserve(size);
for (size_t i = 0; i < size; ++i) {
Eigen::Vector4f p(*iter_x, *iter_y, *iter_z, 1.0f);
p = T.value() * p;
if (std::isfinite(p.x()) && std::isfinite(p.y()) && std::isfinite(p.z())) {
new_points.emplace_back(p.head<3>().cast<double>());
}
++iter_x;
++iter_y;
++iter_z;
}
for (const auto& p: new_points) {
auto idx = voxel_map_->worldToIndex(p);
if (idx > 0) {
voxel_map_->grid[idx].v = 1;
}
}
std::vector<Eigen::Vector3d> pointcloud;
for (int i = 0; i < voxel_map_->grid.size(); i++) {
if (voxel_map_->grid[i].v == 1) {
pointcloud.emplace_back(voxel_map_->indexToWorld(i));
}
}
{
std::lock_guard<std::mutex> lock(cloud_mutex_);
vis_cloud_->points_ = pointcloud;
}
}
std::string target_frame_ = "map";
rclcpp::Node* node_;
std::optional<Eigen::Vector3d> target_pos_in_map_ = std::nullopt;
rclcpp::Subscription<sensor_msgs::msg::PointCloud2>::SharedPtr pointcloud_sub_;
rclcpp::Publisher<visualization_msgs::msg::Marker>::SharedPtr traj_pub_;
rclcpp::Subscription<nav_msgs::msg::Odometry>::SharedPtr odometry_sub_;
TF::Ptr tf_;
std::shared_ptr<open3d::geometry::PointCloud> vis_cloud_;
struct Cell {
uint8_t v = 0;
};
SlidingVoxelMap<3, Cell>::Ptr voxel_map_;
K1BallisticSolver::Ptr k1_solver_;
OffsetHelper::Ptr offset_helper_;
wust_vl::common::concurrency::MonitoredThread::Ptr vis_thread_;
bool run_flag_ = false;
std::mutex cloud_mutex_;
double target_armor_z_ = 0.0;
Eigen::Isometry3d self_in_map_ = Eigen::Isometry3d::Identity();
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<Motion, 1024>> motion_buffer_;
};
AutoSniper::AutoSniper(
rclcpp::Node& node,
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<Motion, 1024>> motion_buffer
) {
_impl = std::make_unique<Impl>(node, motion_buffer);
}
AutoSniper::~AutoSniper() {
_impl.reset();
}
void AutoSniper::start() {
_impl->start();
}
void AutoSniper::doDebug() {
_impl->doDebug();
}
void AutoSniper::pushInput(CommonFrame& frame) {
_impl->pushInput(frame);
}
wust_vl::common::concurrency::MonitoredThread::Ptr AutoSniper::getThread() {
return _impl->getThread();
}
GimbalCmd AutoSniper::solve(double bullet_speed) {
return _impl->solve(bullet_speed);
}
} // namespace wust_vision::auto_sniper
#endif

View File

@@ -0,0 +1,32 @@
#pragma once
#include "tasks/imodule.hpp"
#include "tasks/type_common.hpp"
#include <memory>
#include <rclcpp/node.hpp>
namespace wust_vision {
namespace auto_sniper {
class AutoSniper: public IModule {
public:
using Ptr = std::shared_ptr<AutoSniper>;
AutoSniper(
rclcpp::Node& node,
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<Motion, 1024>> motion_buffer
);
static Ptr create(
rclcpp::Node& node,
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<Motion, 1024>> motion_buffer
) {
return std::make_shared<AutoSniper>(node, motion_buffer);
}
~AutoSniper();
void start() override;
void doDebug() override;
void pushInput(CommonFrame& frame) override;
GimbalCmd solve(double bullet_speed) override;
wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override;
struct Impl;
std::unique_ptr<Impl> _impl;
};
} // namespace auto_sniper
} // namespace wust_vision

View File

@@ -0,0 +1,106 @@
#include <Eigen/Dense>
#include <cmath>
#include <memory>
#include <optional>
class K1BallisticSolver {
public:
using Ptr = std::unique_ptr<K1BallisticSolver>;
K1BallisticSolver(double k1 = 0.05, double g = 9.81): k1_(k1), g_(g) {}
static Ptr create(double k1 = 0.05, double g = 9.81) {
return std::make_unique<K1BallisticSolver>(k1, g);
}
std::optional<double> solvePitch(const Eigen::Vector3d& target_pos, double v0) const {
double x = std::hypot(target_pos.x(), target_pos.y());
double z = target_pos.z();
if (x < 1e-6)
return std::nullopt;
auto heightError = [&](double pitch) {
double cos_theta = std::cos(pitch);
double sin_theta = std::sin(pitch);
double denom = v0 * cos_theta;
if (denom <= 1e-6)
return 1e6;
double t = -std::log(1.0 - k1_ * x / denom) / k1_;
if (!std::isfinite(t))
return 1e6;
double z_pred =
((v0 * sin_theta + g_ / k1_) / k1_) * (1.0 - std::exp(-k1_ * t)) - (g_ / k1_) * t;
return z_pred - z;
};
double left = -0.3;
double right = 0.6;
for (int i = 0; i < 60; ++i) {
double mid = 0.5 * (left + right);
double err = heightError(mid);
if (err > 0)
right = mid;
else
left = mid;
}
return 0.5 * (left + right);
}
std::vector<Eigen::Vector3d> computeTrajectory(
const Eigen::Vector3d& start,
const Eigen::Vector3d& target,
double v0,
double dt = 0.01 // 每个离散点的时间间隔
) {
std::vector<Eigen::Vector3d> traj;
Eigen::Vector3d diff = target - start;
auto pitch_opt = solvePitch(diff, v0);
if (!pitch_opt.has_value())
return traj;
double pitch = pitch_opt.value();
double yaw = std::atan2(diff.y(), diff.x());
double k1 = k1_;
double g = g_;
double vx = v0 * std::cos(pitch) * std::cos(yaw);
double vy = v0 * std::cos(pitch) * std::sin(yaw);
double vz = v0 * std::sin(pitch);
double t = 0.0;
Eigen::Vector3d pos = start;
while (true) {
double exp_kt = std::exp(-k1 * t);
pos.x() = start.x() + (vx / k1) * (1 - exp_kt);
pos.y() = start.y() + (vy / k1) * (1 - exp_kt);
pos.z() = start.z() + ((vz + g / k1) / k1) * (1 - exp_kt) - (g / k1) * t;
traj.push_back(pos);
t += dt;
double dx = std::abs(pos.x() - start.x());
double dy = std::abs(pos.y() - start.y());
double horizontal_dist = std::sqrt(dx * dx + dy * dy);
double target_dist = std::sqrt(diff.x() * diff.x() + diff.y() * diff.y());
if (horizontal_dist >= target_dist) {
break;
}
}
return traj;
}
private:
double k1_;
double g_;
};

View File

@@ -0,0 +1,87 @@
#pragma once
#include <Eigen/Dense>
#include <vector>
#include <yaml-cpp/node/node.h>
#include <yaml-cpp/yaml.h>
namespace wust_vision::auto_sniper {
class OffsetHelper {
public:
using Ptr = std::shared_ptr<OffsetHelper>;
struct OffsetPoint {
double distance;
double yaw;
double pitch;
};
OffsetHelper(const YAML::Node& config) {
data_.clear();
for (auto& v: config["offset_table"]) {
OffsetPoint p;
p.distance = v["distance"].as<double>();
p.yaw = v["yaw"].as<double>();
p.pitch = v["pitch"].as<double>();
data_.push_back(p);
}
order_ = config["order"].as<int>();
yaw_base_offset = config["yaw_base_offset"].as<double>();
pitch_base_offset = config["pitch_base_offset"].as<double>();
fit();
}
static Ptr create(const YAML::Node& config) {
return std::make_shared<OffsetHelper>(config);
}
void fit() {
int n = data_.size();
Eigen::MatrixXd A(n, order_ + 1);
Eigen::VectorXd y_yaw(n);
Eigen::VectorXd y_pitch(n);
for (int i = 0; i < n; ++i) {
double x = data_[i].distance;
double v = 1;
for (int j = 0; j <= order_; ++j) {
A(i, j) = v;
v *= x;
}
y_yaw(i) = data_[i].yaw;
y_pitch(i) = data_[i].pitch;
}
yaw_coeff_ = A.colPivHouseholderQr().solve(y_yaw);
pitch_coeff_ = A.colPivHouseholderQr().solve(y_pitch);
}
double getYawOffset(double distance) const {
return yaw_base_offset + eval(yaw_coeff_, distance);
}
double getPitchOffset(double distance) const {
return pitch_base_offset + eval(pitch_coeff_, distance);
}
double eval(const Eigen::VectorXd& coeff, double x) const {
double y = 0;
double v = 1;
for (int i = 0; i < coeff.size(); ++i) {
y += coeff[i] * v;
v *= x;
}
return y;
}
std::vector<OffsetPoint> data_;
Eigen::VectorXd yaw_coeff_;
Eigen::VectorXd pitch_coeff_;
double yaw_base_offset = 0;
double pitch_base_offset = 0;
int order_ = 2;
};
} // namespace wust_vision::auto_sniper

View File

@@ -0,0 +1,348 @@
#pragma once
#include "3rdparty/ankerl/unordered_dense.h"
#include <Eigen/Dense>
#include <cmath>
#include <memory>
#include <vector>
namespace wust_vision::auto_sniper {
template<int Dim>
struct VoxelKey {
std::array<int, Dim> data {};
int& operator[](int i) noexcept {
return data[i];
}
const int& operator[](int i) const noexcept {
return data[i];
}
// ----- x -----
int& x() noexcept {
static_assert(Dim >= 1, "x requires Dim >= 1");
return data[0];
}
const int& x() const noexcept {
static_assert(Dim >= 1, "x requires Dim >= 1");
return data[0];
}
// ----- y -----
int& y() noexcept {
static_assert(Dim >= 2, "y requires Dim >= 2");
return data[1];
}
const int& y() const noexcept {
static_assert(Dim >= 2, "y requires Dim >= 2");
return data[1];
}
// ----- z -----
int& z() noexcept {
static_assert(Dim >= 3, "z requires Dim >= 3");
return data[2];
}
const int& z() const noexcept {
static_assert(Dim >= 3, "z requires Dim >= 3");
return data[2];
}
bool operator==(const VoxelKey& other) const noexcept {
for (int i = 0; i < Dim; ++i)
if (data[i] != other.data[i])
return false;
return true;
}
bool operator!=(const VoxelKey& other) const noexcept {
return !(*this == other);
}
VoxelKey& operator+=(const VoxelKey& other) noexcept {
for (int i = 0; i < Dim; ++i)
data[i] += other.data[i];
return *this;
}
VoxelKey& operator-=(const VoxelKey& other) noexcept {
for (int i = 0; i < Dim; ++i)
data[i] -= other.data[i];
return *this;
}
VoxelKey& operator*=(int scalar) noexcept {
for (int i = 0; i < Dim; ++i)
data[i] *= scalar;
return *this;
}
VoxelKey& operator/=(int scalar) noexcept {
for (int i = 0; i < Dim; ++i)
data[i] /= scalar;
return *this;
}
friend VoxelKey operator+(VoxelKey lhs, const VoxelKey& rhs) noexcept {
lhs += rhs;
return lhs;
}
friend VoxelKey operator-(VoxelKey lhs, const VoxelKey& rhs) noexcept {
lhs -= rhs;
return lhs;
}
friend VoxelKey operator*(VoxelKey lhs, int scalar) noexcept {
lhs *= scalar;
return lhs;
}
friend VoxelKey operator*(int scalar, VoxelKey rhs) noexcept {
rhs *= scalar;
return rhs;
}
friend VoxelKey operator/(VoxelKey lhs, int scalar) noexcept {
lhs /= scalar;
return lhs;
}
constexpr VoxelKey cwiseMin(const VoxelKey& other) const noexcept {
VoxelKey out {};
for (int i = 0; i < Dim; ++i)
out.data[i] = data[i] < other.data[i] ? data[i] : other.data[i];
return out;
}
constexpr VoxelKey cwiseMax(const VoxelKey& other) const noexcept {
VoxelKey out {};
for (int i = 0; i < Dim; ++i)
out.data[i] = data[i] > other.data[i] ? data[i] : other.data[i];
return out;
}
};
template<int Dim, typename Cell>
class SlidingVoxelMap {
static_assert(Dim == 2 || Dim == 3, "Dim must be 2 or 3");
public:
using Ptr = std::shared_ptr<SlidingVoxelMap>;
using Key = VoxelKey<Dim>;
using EigenPoint = Eigen::Matrix<double, Dim, 1>;
SlidingVoxelMap(double voxel_size_, const EigenPoint& size_, const EigenPoint& center_):
voxel_size(voxel_size_),
size(size_),
center(center_) {
EigenPoint half = size * 0.5f;
min_key = worldToKey(center - half);
max_key = worldToKey(center + half);
for (int i = 0; i < Dim; ++i) {
dims[i] = max_key[i] - min_key[i] + 1;
offset[i] = 0;
}
center_key = worldToKey(center);
size_t N = 1;
for (int i = 0; i < Dim; ++i)
N *= static_cast<size_t>(dims[i]);
grid.resize(N);
}
SlidingVoxelMap(
double voxel_size_,
const EigenPoint& min_pos,
const EigenPoint& max_pos,
bool /*dummy*/
):
voxel_size(voxel_size_) {
min_key = worldToKey(min_pos);
max_key = worldToKey(max_pos);
for (int i = 0; i < Dim; ++i) {
if (max_key[i] < min_key[i])
std::swap(max_key[i], min_key[i]);
dims[i] = max_key[i] - min_key[i] + 1;
offset[i] = 0;
}
for (int i = 0; i < Dim; ++i)
center_key[i] = (min_key[i] + max_key[i]) / 2;
EigenPoint min_world = keyToWorld(min_key);
EigenPoint max_world = keyToWorld(max_key);
for (int i = 0; i < Dim; ++i) {
center[i] = (min_world[i] + max_world[i]) * 0.5f;
size[i] = dims[i] * voxel_size;
}
size_t N = 1;
for (int i = 0; i < Dim; ++i)
N *= static_cast<size_t>(dims[i]);
grid.resize(N);
}
static Ptr create(double voxel_size, const EigenPoint& size, const EigenPoint& center) {
return std::make_shared<SlidingVoxelMap>(voxel_size, size, center);
}
size_t gridSize() const noexcept {
return grid.size();
}
inline int worldToIndex(const EigenPoint& p) const noexcept {
return keyToIndex(worldToKey(p));
}
inline Key worldToKey(const EigenPoint& p) const noexcept {
Key k;
const double inv = 1.0f / voxel_size;
for (int i = 0; i < Dim; ++i)
k.data[i] = static_cast<int>(std::floor(p[i] * inv + 1e-6f));
return k;
}
inline EigenPoint keyToWorld(const Key& k) const noexcept {
EigenPoint p;
for (int i = 0; i < Dim; ++i)
p[i] = (k.data[i] + 0.5f) * voxel_size;
return p;
}
inline EigenPoint indexToWorld(int idx) const noexcept {
return keyToWorld(indexToKey(idx));
}
inline int keyToIndex(const Key& k) const noexcept {
int idx = 0;
int stride = 1;
for (int d = Dim - 1; d >= 0; --d) {
int delta = k[d] - center_key[d] + (dims[d] >> 1);
if (delta < 0 || delta >= dims[d])
return -1;
int r = delta + offset[d];
if (r >= dims[d])
r -= dims[d];
else if (r < 0)
r += dims[d];
idx += r * stride;
stride *= dims[d];
}
return idx;
}
inline Key indexToKey(int idx) const noexcept {
Key k;
for (int d = Dim - 1; d >= 0; --d) {
int r = idx % dims[d];
idx /= dims[d];
int delta = r - offset[d];
if (delta < 0)
delta += dims[d];
else if (delta >= dims[d])
delta -= dims[d];
k[d] = center_key[d] + delta - (dims[d] >> 1);
}
return k;
}
template<typename ClearFunc>
void slideTo(const Key& new_center_key, ClearFunc clear_func) {
Key shift;
for (int d = 0; d < Dim; ++d)
shift[d] = new_center_key[d] - center_key[d];
for (int d = 0; d < Dim; ++d) {
if (std::abs(shift[d]) >= dims[d]) {
for (size_t i = 0; i < grid.size(); ++i)
clear_func(i);
offset = {};
center_key = new_center_key;
return;
}
}
for (int axis = 0; axis < Dim; ++axis) {
int s = shift[axis];
if (s == 0)
continue;
int steps = std::abs(s);
int dir = s > 0 ? 1 : -1;
for (int step = 0; step < steps; ++step) {
int slice = (offset[axis] + dir * step + dims[axis]) % dims[axis];
clearSlice(axis, slice, clear_func);
}
offset[axis] = (offset[axis] + s + dims[axis]) % dims[axis];
}
center_key = new_center_key;
EigenPoint half = size * 0.5f;
center = keyToWorld(center_key);
min_key = worldToKey(center - half);
max_key = worldToKey(center + half);
}
template<typename ClearFunc>
void clearSlice(int axis, int slice, ClearFunc clear_func) {
if constexpr (Dim == 3) {
int dx = dims[0];
int dy = dims[1];
int dz = dims[2];
if (axis == 0) {
for (int y = 0; y < dy; ++y)
for (int z = 0; z < dz; ++z) {
int idx = (slice * dy + y) * dz + z;
clear_func(idx);
}
} else if (axis == 1) {
for (int x = 0; x < dx; ++x)
for (int z = 0; z < dz; ++z) {
int idx = (x * dy + slice) * dz + z;
clear_func(idx);
}
} else {
for (int x = 0; x < dx; ++x)
for (int y = 0; y < dy; ++y) {
int idx = (x * dy + y) * dz + slice;
clear_func(idx);
}
}
}
}
public:
double voxel_size;
Key dims;
Key offset;
std::vector<Cell> grid;
Key center_key;
EigenPoint center;
EigenPoint size;
Key min_key;
Key max_key;
};
} // namespace wust_vision::auto_sniper

View File

@@ -0,0 +1,16 @@
#pragma once
#include "tasks/type_common.hpp"
#include "wust_vl/common/concurrency/monitored_thread.hpp"
#include <memory>
namespace wust_vision {
class IModule {
public:
using Ptr = std::shared_ptr<IModule>;
virtual void start() = 0;
virtual void pushInput(CommonFrame&) = 0;
virtual GimbalCmd solve(double bullet_speed) = 0;
virtual wust_vl::common::concurrency::MonitoredThread::Ptr getThread() = 0;
virtual void doDebug() = 0;
virtual ~IModule() = default;
};
} // namespace wust_vision

View File

@@ -0,0 +1,193 @@
#pragma once
#include <chrono>
#include <cstdint>
#include <fstream>
#include <mutex>
#include <nlohmann/json.hpp>
#include <sys/types.h>
namespace wust_vision {
constexpr uint8_t ID_ROBOT_CMD = 0x01;
constexpr uint8_t ID_NAV_CMD = 0x02;
constexpr uint8_t ID_AIM_INFO = 0X02;
constexpr uint8_t ID_REFEREE_INFO = 0X03;
constexpr const char* TARGET_TOPIC = "vision_target";
constexpr const char* NAV_STATE_TOPIC = "rose_state";
constexpr const char* MODE_TOPIC = "sentry_mode";
constexpr const char* ROBO_STATE_TOPIC = "robo_state";
constexpr const char* GOAL_TOPIC = "rose_goal";
struct ReceiveAimINFO {
uint8_t cmd_ID;
uint32_t time_stamp;
float yaw;
float pitch;
float roll;
float yaw_vel;
float pitch_vel;
float roll_vel;
float v_x;
float v_y;
float v_z;
float bullet_speed;
uint8_t detect_color; // 0 red 1 blue
} __attribute__((packed));
struct ReceiveReferee {
uint8_t cmd_ID;
uint32_t time_stamp;
float big_yaw_in_world;
int game_time;
int max_health;
int cur_health;
int cur_bullet;
uint8_t center_state;
} __attribute__((packed));
struct SendRobotCmdData {
uint8_t cmd_ID;
uint32_t time_stamp;
uint8_t appear;
uint8_t shoot_rate = 3;
float pitch;
float yaw;
float target_yaw;
float target_pitch;
float enable_yaw_diff;
float enable_pitch_diff;
float v_yaw;
float v_pitch;
float a_yaw;
float a_pitch;
uint8_t detect_color;
} __attribute__((packed));
constexpr uint8_t ID_NAV_CONTROL = 0;
struct NavRobotCmdData {
uint8_t cmd_ID;
uint32_t time_stamp;
uint8_t packet_type;
float vx;
float vy;
float wz;
} __attribute__((packed));
struct SerialLogBuffer {
std::mutex mtx;
nlohmann::json j;
bool dirty = false;
};
inline SerialLogBuffer& getLogBuffer() {
static SerialLogBuffer buf;
return buf;
}
inline void updateFPS(nlohmann::json& j) {
static int frame_count = 0;
static double fps = 0.0;
static auto last_time = std::chrono::steady_clock::now();
++frame_count;
auto now = std::chrono::steady_clock::now();
double elapsed = std::chrono::duration<double>(now - last_time).count();
if (elapsed >= 1.0) {
fps = frame_count / elapsed;
frame_count = 0;
last_time = now;
}
j["fps"] = fps;
}
inline void updateSerialLog(const ReceiveAimINFO& aim) {
auto& buf = getLogBuffer();
std::lock_guard<std::mutex> lock(buf.mtx);
auto& j = buf.j["aim"];
updateFPS(j);
j["timestamp"] = aim.time_stamp;
j["yaw"] = aim.yaw;
j["pitch"] = aim.pitch;
j["roll"] = aim.roll;
j["yaw_vel"] = aim.yaw_vel;
j["pitch_vel"] = aim.pitch_vel;
j["roll_vel"] = aim.roll_vel;
j["v_x"] = aim.v_x;
j["v_y"] = aim.v_y;
j["v_z"] = aim.v_z;
j["bullet_speed"] = aim.bullet_speed;
j["detect_color"] = (aim.detect_color == 0 ? "Red" : "Blue");
buf.dirty = true;
}
inline void updateSerialLog(const ReceiveReferee& ref) {
auto& buf = getLogBuffer();
std::lock_guard<std::mutex> lock(buf.mtx);
auto& j = buf.j["referee"];
updateFPS(j);
j["timestamp"] = ref.time_stamp;
j["big_yaw_in_world"] = ref.big_yaw_in_world;
j["game_time"] = ref.game_time;
j["max_health"] = ref.max_health;
j["cur_health"] = ref.cur_health;
j["cur_bullet"] = ref.cur_bullet;
j["center_state"] = ref.center_state;
buf.dirty = true;
}
inline void flushSerialLog() {
static auto last_flush = std::chrono::steady_clock::now();
auto& buf = getLogBuffer();
auto now = std::chrono::steady_clock::now();
double dt = std::chrono::duration<double>(now - last_flush).count();
// 控制写入频率(例如 20Hz
if (dt < 0.05)
return;
std::lock_guard<std::mutex> lock(buf.mtx);
if (!buf.dirty)
return;
// 更新 FPS
std::ofstream file("/dev/shm/serial_log.json");
if (file.is_open()) {
file << buf.j.dump(2);
buf.dirty = false;
}
last_flush = now;
}
} // namespace wust_vision

View File

@@ -0,0 +1,66 @@
#include "type_common.hpp"
namespace wust_vision {
std::string enemyColorToString(EnemyColor color) noexcept {
switch (color) {
case EnemyColor::RED:
return "RED";
break;
case EnemyColor::BLUE:
return "BLUE";
break;
case EnemyColor::WHITE:
return "WHITE";
break;
default:
return "UNKNOWN";
}
}
void AimTarget::tf(Eigen::Matrix4d T_camera_to_odom) noexcept {
const Eigen::Vector4d pos_camera(pos.x(), pos.y(), pos.z(), 1.0);
const Eigen::Vector4d pos_odom = T_camera_to_odom * pos_camera;
pos.x() = pos_odom.x();
pos.y() = pos_odom.y();
pos.z() = pos_odom.z();
const Eigen::Matrix3d R_camera_to_odom = T_camera_to_odom.block<3, 3>(0, 0);
const Eigen::Quaterniond q_camera(ori.w(), ori.x(), ori.y(), ori.z());
const Eigen::Matrix3d R_ori_camera = q_camera.normalized().toRotationMatrix();
const Eigen::Matrix3d R_ori_odom = R_camera_to_odom * R_ori_camera;
const Eigen::Quaterniond q_odom(R_ori_odom);
ori.w() = q_odom.w();
ori.x() = q_odom.x();
ori.y() = q_odom.y();
ori.z() = q_odom.z();
}
std::vector<cv::Point2f>
AimTarget::toPts(const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion) noexcept {
std::vector<cv::Point2f> pts;
if (pos.norm() < 0.5) {
return pts;
}
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << pos.x(), pos.y(), pos.z());
const Eigen::Matrix3d tf_rot = ori.toRotationMatrix();
const cv::Mat rot_mat =
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
tf_rot(0, 1),
tf_rot(0, 2),
tf_rot(1, 0),
tf_rot(1, 1),
tf_rot(1, 2),
tf_rot(2, 0),
tf_rot(2, 1),
tf_rot(2, 2));
cv::Mat rvec;
cv::Rodrigues(rot_mat, rvec);
cv::projectPoints(AIM_TARGET_BLOCK, rvec, tvec, camera_intrinsic, camera_distortion, pts);
return pts;
}
} // namespace wust_vision

View File

@@ -0,0 +1,310 @@
#pragma once
#include "3rdparty/angles.h"
#include "tasks/packet_typedef.hpp"
#include "wust_vl/common/utils/logger.hpp"
#include "wust_vl/common/utils/motion_buffer.hpp"
#include "wust_vl/common/utils/parameter.hpp"
#include "wust_vl/common/utils/trajectory_compensator.hpp"
#include "wust_vl/video/icamera.hpp"
#include <any>
#include <opencv2/opencv.hpp>
#include <wust_vl/video/camera.hpp>
namespace wust_vision {
struct CommonFrame {
wust_vl::video::ImageFrame img_frame;
int id;
int detect_color;
cv::Rect expanded;
cv::Point2f offset = cv::Point2f(0, 0);
std::any any_ctx;
};
enum class EnemyColor {
RED = 0,
BLUE = 1,
WHITE = 2,
};
std::string enemyColorToString(EnemyColor color) noexcept;
class InfantryMode {
public:
enum class AttackMode { ARMOR = 0, SMALL_RUNE, BIG_RUNE, UNKNOWN };
static AttackMode toAttackMode(int value) noexcept {
switch (value) {
case 0:
return AttackMode::ARMOR;
case 1:
return AttackMode::SMALL_RUNE;
case 2:
return AttackMode::BIG_RUNE;
default:
return AttackMode::UNKNOWN;
}
}
};
class HeroMode {
public:
enum class AttackMode { ARMOR = 0, SNIPER, UNKNOWN };
static AttackMode toAttackMode(int value) noexcept {
switch (value) {
case 0:
return AttackMode::ARMOR;
case 1:
return AttackMode::SNIPER;
default:
return AttackMode::UNKNOWN;
}
}
};
struct CarMotion {
double yaw, pitch, roll; // 欧拉角 (rad)
double vyaw, vpitch, vroll; // 角速度
double vx, vy, vz; // 线速度
static double unwrap_angle(double prev, double curr) noexcept {
double d = curr - prev;
while (d > M_PI) {
curr -= 2.0 * M_PI;
d -= 2.0 * M_PI;
}
while (d < -M_PI) {
curr += 2.0 * M_PI;
d += 2.0 * M_PI;
}
return curr;
}
// 角度插值wrap-safe
static double interp_angle(double a, double b, double t) noexcept {
double diff = b - a;
while (diff > M_PI)
diff -= 2.0 * M_PI;
while (diff < -M_PI)
diff += 2.0 * M_PI;
return a + diff * t;
}
};
struct BigYaw {
double big_yaw;
static double unwrap_angle(double prev, double curr) noexcept {
double d = curr - prev;
while (d > M_PI) {
curr -= 2.0 * M_PI;
d -= 2.0 * M_PI;
}
while (d < -M_PI) {
curr += 2.0 * M_PI;
d += 2.0 * M_PI;
}
return curr;
}
// 角度插值wrap-safe
static double interp_angle(double a, double b, double t) noexcept {
double diff = b - a;
while (diff > M_PI)
diff -= 2.0 * M_PI;
while (diff < -M_PI)
diff += 2.0 * M_PI;
return a + diff * t;
}
};
struct VisionCtx {
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<CarMotion, 1024>> motion_buffer;
std::shared_ptr<wust_vl::video::Camera> camera;
double communication_delay_μs;
int mode;
};
static std::vector<cv::Point3f> AIM_TARGET_BLOCK = {
{ -0.025f, -0.025f, -0.025f }, // 0: 左下前
{ 0.025f, -0.025f, -0.025f }, // 1: 右下前
{ 0.025f, 0.025f, -0.025f }, // 2: 右上前
{ -0.025f, 0.025f, -0.025f }, // 3: 左上前
{ -0.025f, -0.025f, 0.025f }, // 4: 左下后
{ 0.025f, -0.025f, 0.025f }, // 5: 右下后
{ 0.025f, 0.025f, 0.025f }, // 6: 右上后
{ -0.025f, 0.025f, 0.025f } // 7: 左上后
};
struct AimTarget {
bool valid;
Eigen::Vector3d pos = Eigen::Vector3d(0, 0, 0);
Eigen::Vector3d vel = Eigen::Vector3d(0, 0, 0);
Eigen::Quaterniond ori;
std::vector<Eigen::Vector4d> armor_posandyaw;
void tf(Eigen::Matrix4d T_camera_to_odom) noexcept;
std::vector<cv::Point2f>
toPts(const cv::Mat& camera_intrinsic, const cv::Mat& camera_distortion) noexcept;
};
struct GimbalCmd {
std::chrono::steady_clock::time_point timestamp;
double pitch = 0;
double yaw = 0;
double target_yaw = 0;
double target_pitch = 0;
double v_yaw = 0;
double v_pitch = 0;
double a_yaw = 0;
double a_pitch = 0;
double distance = -1;
bool fire_advice = false;
double enable_yaw_diff = 0;
double enable_pitch_diff = 0;
double fly_time = 0;
bool appear = false;
AimTarget aim_target;
inline bool isValid() const noexcept {
auto bad = [](double x) { return std::isnan(x) || std::isinf(x); };
if (bad(pitch) || bad(yaw) || bad(target_yaw) || bad(target_pitch) || bad(target_pitch)
|| bad(v_yaw) || bad(v_pitch) || bad(distance) || bad(enable_yaw_diff)
|| bad(enable_pitch_diff))
return false;
return true;
}
inline void noShoot() {
fire_advice = false;
enable_pitch_diff = 0;
enable_pitch_diff = 0;
}
};
struct AutoExposureCfg: wust_vl::common::utils::ParamGroup {
static constexpr const char* Logger = "Config: auto_exposure";
static constexpr const char* kKey = "auto_exposure";
const char* key() const override {
return kKey;
}
using Ptr = std::shared_ptr<AutoExposureCfg>;
AutoExposureCfg() {}
static Ptr create() {
return std::make_shared<AutoExposureCfg>();
}
GEN_PARAM(bool, enable);
GEN_PARAM(double, target_brightness);
GEN_PARAM(double, step_gain);
GEN_PARAM(double, decay_step);
GEN_PARAM(double, tolerance);
GEN_PARAM(double, exposure_min);
GEN_PARAM(double, exposure_max);
GEN_PARAM(double, control_interval_ms);
void loadSelf(const YAML::Node& node) override {
enable_param.load(node);
target_brightness_param.load(node);
step_gain_param.load(node);
decay_step_param.load(node);
tolerance_param.load(node);
exposure_min_param.load(node);
exposure_max_param.load(node);
control_interval_ms_param.load(node);
}
};
struct TFConfig: wust_vl::common::utils::ParamGroup {
public:
static constexpr const char* kKey = "tf";
static constexpr const char* Logger = "Config: common::tf";
const char* key() const override {
return kKey;
}
using Ptr = std::shared_ptr<TFConfig>;
TFConfig() {}
static Ptr create() {
return std::make_shared<TFConfig>();
}
bool first_load = false;
Eigen::Matrix3d R_camera2gimbal;
Eigen::Vector3d t_camera2gimbal;
void loadSelf(const YAML::Node& node) override {
if (!first_load) {
auto t_vec = node["t_camera2gimbal"].as<std::vector<double>>();
if (t_vec.size() != 3) {
throw std::runtime_error("YAML tf.t_camera2gimbal must have 3 elements");
}
t_camera2gimbal = Eigen::Vector3d(t_vec[0], t_vec[1], t_vec[2]);
auto R_vec = node["R_camera2gimbal"].as<std::vector<double>>();
if (R_vec.size() != 9) {
throw std::runtime_error("YAML tf.R_camera2gimbal must have 9 elements");
}
R_camera2gimbal =
Eigen::Map<const Eigen::Matrix<double, 3, 3, Eigen::RowMajor>>(R_vec.data());
first_load = true;
} else {
}
}
};
struct TrajectoryCompensatorConfig: public wust_vl::common::utils::ParamGroup {
static constexpr const char* Logger = "Config: auto_aim::trajectory_compensator";
static constexpr const char* kKey = "trajectory_compensator";
const char* key() const override {
return kKey;
}
using Ptr = std::shared_ptr<TrajectoryCompensatorConfig>;
TrajectoryCompensatorConfig() {}
static Ptr create() {
return std::make_shared<TrajectoryCompensatorConfig>();
}
std::shared_ptr<wust_vl::common::utils::TrajectoryCompensator> trajectory_compensator;
bool first_load = false;
void loadSelf(const YAML::Node& node) override
{
if (!first_load) {
std::string comp_type = node["compenstator_type"].as<std::string>("ideal");
trajectory_compensator =
wust_vl::common::utils::CompensatorFactory::createCompensator(comp_type);
trajectory_compensator->load(node);
first_load = true;
} else {
}
}
};
} // namespace wust_vision
template<>
struct wust_vl::common::utils::MotionTraits<wust_vision::CarMotion> {
static void unwrap(const wust_vision::CarMotion& prev, wust_vision::CarMotion& curr) noexcept {
curr.yaw = wust_vision::CarMotion::unwrap_angle(prev.yaw, curr.yaw);
curr.pitch = wust_vision::CarMotion::unwrap_angle(prev.pitch, curr.pitch);
curr.roll = wust_vision::CarMotion::unwrap_angle(prev.roll, curr.roll);
// 速度部分不需要 unwrap
}
static wust_vision::CarMotion interpolate(
const wust_vision::CarMotion& a,
const wust_vision::CarMotion& b,
double t
) noexcept {
wust_vision::CarMotion out;
// 欧拉角 wrap-safe 插值
out.yaw = wust_vision::CarMotion::interp_angle(a.yaw, b.yaw, t);
out.pitch = wust_vision::CarMotion::interp_angle(a.pitch, b.pitch, t);
out.roll = wust_vision::CarMotion::interp_angle(a.roll, b.roll, t);
// 角速度和线速度线性插值
out.vyaw = a.vyaw + (b.vyaw - a.vyaw) * t;
out.vpitch = a.vpitch + (b.vpitch - a.vpitch) * t;
out.vroll = a.vroll + (b.vroll - a.vroll) * t;
out.vx = a.vx + (b.vx - a.vx) * t;
out.vy = a.vy + (b.vy - a.vy) * t;
out.vz = a.vz + (b.vz - a.vz) * t;
return out;
}
};
template<>
struct wust_vl::common::utils::MotionTraits<wust_vision::BigYaw> {
static void unwrap(const wust_vision::BigYaw& prev, wust_vision::BigYaw& curr) noexcept {
curr.big_yaw = wust_vision::BigYaw::unwrap_angle(prev.big_yaw, curr.big_yaw);
}
static wust_vision::BigYaw
interpolate(const wust_vision::BigYaw& a, const wust_vision::BigYaw& b, double t) noexcept {
wust_vision::BigYaw out;
out.big_yaw = wust_vision::BigYaw::interp_angle(a.big_yaw, b.big_yaw, t);
return out;
}
};

View File

@@ -0,0 +1,64 @@
#pragma once
#include <array>
#include <cmath>
#include <iostream>
namespace wust_vision {
constexpr std::array ascii_banner = {
R"( _ ____ _____________ _ ___________ ________ _ __ )",
R"(| | / / / / / ___/_ __/ | | / / _/ ___// _/ __ \/ | / /)",
R"(| | /| / / / / /\__ \ / / | | / // / \__ \ / // / / / |/ / )",
R"(| |/ |/ / /_/ /___/ // / | |/ // / ___/ // // /_/ / /| / )",
R"(|__/|__/\____//____//_/ |___/___//____/___/\____/_/ |_/ )",
};
namespace {
struct RGB {
int r, g, b;
};
inline RGB hsv2rgb(float h, float s, float v) {
float c = v * s;
float x = c * (1 - std::fabs(std::fmod(h / 60.0f, 2) - 1));
float m = v - c;
float r = 0, g = 0, b = 0;
if (h < 60) {
r = c;
g = x;
} else if (h < 120) {
r = x;
g = c;
} else if (h < 180) {
g = c;
b = x;
} else if (h < 240) {
g = x;
b = c;
} else if (h < 300) {
r = x;
b = c;
} else {
r = c;
b = x;
}
return { int((r + m) * 255), int((g + m) * 255), int((b + m) * 255) };
}
} // namespace
inline void printBanner() {
constexpr const char* reset = "\033[0m";
for (const auto& line: ascii_banner) {
const int n = static_cast<int>(std::string_view(line).size());
for (int i = 0; i < n; ++i) {
// hue 从 0° → 360°
float hue = 360.0f * i / std::max(1, n - 1);
auto rgb = hsv2rgb(hue, 1.0f, 1.0f);
std::cout << "\033[38;2;" << rgb.r << ";" << rgb.g << ";" << rgb.b << "m" << line[i];
}
std::cout << reset << '\n';
}
}
} // namespace wust_vision

Some files were not shown because too many files have changed in this diff Show More