add wust typr mpc and mutipule x
This commit is contained in:
@@ -0,0 +1,43 @@
|
||||
add_library(tinympcstatic STATIC
|
||||
admm.cpp
|
||||
tiny_api.cpp
|
||||
codegen.cpp
|
||||
rho_benchmark.cpp
|
||||
)
|
||||
|
||||
set_property(TARGET tinympcstatic PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
# target_link_libraries(tinympcstatic PUBLIC Eigen)
|
||||
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
target_include_directories(tinympcstatic PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include/Eigen)
|
||||
|
||||
|
||||
|
||||
if(USING_CODEGEN) # Defined in top-level CMakeLists.txt
|
||||
|
||||
# Files that are needed for embedded code generation
|
||||
list( APPEND EMBEDDED_FILES
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/admm.cpp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/admm.hpp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.cpp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api.hpp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/types.hpp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/tiny_api_constants.hpp" )
|
||||
|
||||
|
||||
foreach( f ${EMBEDDED_FILES} )
|
||||
get_filename_component( fname ${f} NAME )
|
||||
|
||||
set( dest_file "${EMBEDDED_BUILD_TINYMPC_DIR}/${fname}" )
|
||||
list( APPEND EMBEDDED_BUILD_TINYMPC_FILES "${dest_file}" )
|
||||
|
||||
add_custom_command(OUTPUT ${dest_file}
|
||||
COMMAND ${CMAKE_COMMAND} -E copy "${f}" "${dest_file}"
|
||||
DEPENDS ${f}
|
||||
COMMENT "Copying ${fname}")
|
||||
endforeach()
|
||||
|
||||
add_custom_target( copy_codegen_tinympc_files DEPENDS ${EMBEDDED_BUILD_TINYMPC_FILES} )
|
||||
add_dependencies( copy_codegen_files copy_codegen_tinympc_files )
|
||||
|
||||
endif(USING_CODEGEN)
|
||||
399
wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.cpp
Normal file
399
wust_vision-main/tasks/auto_aim/armor_control/tinympc/admm.cpp
Normal file
@@ -0,0 +1,399 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "admm.hpp"
|
||||
#include "rho_benchmark.hpp"
|
||||
|
||||
#define DEBUG_MODULE "TINYALG"
|
||||
|
||||
extern "C" {
|
||||
|
||||
/**
|
||||
* Update linear terms from Riccati backward pass
|
||||
*/
|
||||
void backward_pass_grad(TinySolver* solver) {
|
||||
for (int i = solver->work->N - 2; i >= 0; i--) {
|
||||
(solver->work->d.col(i)).noalias() = solver->cache->Quu_inv
|
||||
* (solver->work->Bdyn.transpose() * solver->work->p.col(i + 1) + solver->work->r.col(i)
|
||||
+ solver->cache->BPf);
|
||||
(solver->work->p.col(i)).noalias() = solver->work->q.col(i)
|
||||
+ solver->cache->AmBKt.lazyProduct(solver->work->p.col(i + 1))
|
||||
- (solver->cache->Kinf.transpose()).lazyProduct(solver->work->r.col(i))
|
||||
+ solver->cache->APf;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Use LQR feedback policy to roll out trajectory
|
||||
*/
|
||||
void forward_pass(TinySolver* solver) {
|
||||
for (int i = 0; i < solver->work->N - 1; i++) {
|
||||
(solver->work->u.col(i)).noalias() =
|
||||
-solver->cache->Kinf.lazyProduct(solver->work->x.col(i)) - solver->work->d.col(i);
|
||||
(solver->work->x.col(i + 1)).noalias() =
|
||||
solver->work->Adyn.lazyProduct(solver->work->x.col(i))
|
||||
+ solver->work->Bdyn.lazyProduct(solver->work->u.col(i)) + solver->work->fdyn;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Project a vector s onto the second order cone defined by mu
|
||||
* @param s, mu
|
||||
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
|
||||
*/
|
||||
tinyVector project_soc(tinyVector s, float mu) {
|
||||
tinytype u0 = s(Eigen::placeholders::last) * mu;
|
||||
tinyVector u1 = s.head(s.rows() - 1);
|
||||
float a = u1.norm();
|
||||
tinyVector cone_origin(s.rows());
|
||||
cone_origin.setZero();
|
||||
|
||||
if (a <= -u0) { // below cone
|
||||
return cone_origin;
|
||||
} else if (a <= u0) { // in cone
|
||||
return s;
|
||||
} else if (a >= abs(u0)) { // outside cone
|
||||
Matrix<tinytype, 3, 1> u2(u1.size() + 1);
|
||||
u2 << u1, a / mu;
|
||||
return 0.5 * (1 + u0 / a) * u2;
|
||||
} else {
|
||||
return cone_origin;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Project a vector z onto a hyperplane defined by a^T z = b
|
||||
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ − b)/||a||² * a
|
||||
* @param z Vector to project
|
||||
* @param a Normal vector of the hyperplane
|
||||
* @param b Offset of the hyperplane
|
||||
* @return Projection of z onto the hyperplane
|
||||
*/
|
||||
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b) {
|
||||
tinytype dist = (a.dot(z) - b) / a.squaredNorm();
|
||||
return z - dist * a;
|
||||
}
|
||||
|
||||
/**
|
||||
* Project slack (auxiliary) variables into their feasible domain, defined by
|
||||
* projection functions related to each constraint
|
||||
* TODO: pass in meta information with each constraint assigning it to a
|
||||
* projection function
|
||||
*/
|
||||
void update_slack(TinySolver* solver) {
|
||||
// Update bound constraint slack variables for state
|
||||
solver->work->vnew = solver->work->x + solver->work->g;
|
||||
|
||||
// Update bound constraint slack variables for input
|
||||
solver->work->znew = solver->work->u + solver->work->y;
|
||||
|
||||
// Box constraints on state
|
||||
if (solver->settings->en_state_bound) {
|
||||
solver->work->vnew =
|
||||
solver->work->x_max.cwiseMin(solver->work->x_min.cwiseMax(solver->work->vnew));
|
||||
}
|
||||
|
||||
// Box constraints on input
|
||||
if (solver->settings->en_input_bound) {
|
||||
solver->work->znew =
|
||||
solver->work->u_max.cwiseMin(solver->work->u_min.cwiseMax(solver->work->znew));
|
||||
}
|
||||
|
||||
// Update second order cone slack variables for state
|
||||
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
|
||||
solver->work->vcnew = solver->work->x + solver->work->gc;
|
||||
}
|
||||
|
||||
// Update second order cone slack variables for input
|
||||
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
|
||||
solver->work->zcnew = solver->work->u + solver->work->yc;
|
||||
}
|
||||
|
||||
// Cone constraints on state
|
||||
if (solver->settings->en_state_soc) {
|
||||
for (int i = 0; i < solver->work->N; i++) {
|
||||
for (int k = 0; k < solver->work->numStateCones; k++) {
|
||||
int start = solver->work->Acx(k);
|
||||
int num_xs = solver->work->qcx(k);
|
||||
tinytype mu = solver->work->cx(k);
|
||||
tinyVector col = solver->work->vcnew.block(start, i, num_xs, 1);
|
||||
solver->work->vcnew.block(start, i, num_xs, 1) = project_soc(col, mu);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cone constraints on input
|
||||
if (solver->settings->en_input_soc) {
|
||||
for (int i = 0; i < solver->work->N - 1; i++) {
|
||||
for (int k = 0; k < solver->work->numInputCones; k++) {
|
||||
int start = solver->work->Acu(k);
|
||||
int num_us = solver->work->qcu(k);
|
||||
tinytype mu = solver->work->cu(k);
|
||||
tinyVector col = solver->work->zcnew.block(start, i, num_us, 1);
|
||||
solver->work->zcnew.block(start, i, num_us, 1) = project_soc(col, mu);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update linear constraint slack variables for state
|
||||
if (solver->settings->en_state_linear) {
|
||||
solver->work->vlnew = solver->work->x + solver->work->gl;
|
||||
}
|
||||
|
||||
// Update linear constraint slack variables for input
|
||||
if (solver->settings->en_input_linear) {
|
||||
solver->work->zlnew = solver->work->u + solver->work->yl;
|
||||
}
|
||||
|
||||
// Linear constraints on state
|
||||
if (solver->settings->en_state_linear) {
|
||||
for (int i = 0; i < solver->work->N; i++) {
|
||||
for (int k = 0; k < solver->work->numStateLinear; k++) {
|
||||
tinyVector a = solver->work->Alin_x.row(k);
|
||||
tinytype b = solver->work->blin_x(k);
|
||||
tinytype constraint_value = a.dot(solver->work->vlnew.col(i));
|
||||
if (constraint_value > b) { // Only project if constraint is violated
|
||||
solver->work->vlnew.col(i) =
|
||||
project_hyperplane(solver->work->vlnew.col(i), a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Linear constraints on input
|
||||
if (solver->settings->en_input_linear) {
|
||||
for (int i = 0; i < solver->work->N - 1; i++) {
|
||||
for (int k = 0; k < solver->work->numInputLinear; k++) {
|
||||
tinyVector a = solver->work->Alin_u.row(k);
|
||||
tinytype b = solver->work->blin_u(k);
|
||||
tinytype constraint_value = a.dot(solver->work->zlnew.col(i));
|
||||
if (constraint_value > b) { // Only project if constraint is violated
|
||||
solver->work->zlnew.col(i) =
|
||||
project_hyperplane(solver->work->zlnew.col(i), a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update next iteration of dual variables by performing the augmented
|
||||
* lagrangian multiplier update
|
||||
*/
|
||||
void update_dual(TinySolver* solver) {
|
||||
// Update bound constraint dual variables for state
|
||||
solver->work->g = solver->work->g + solver->work->x - solver->work->vnew;
|
||||
|
||||
// Update bound constraint dual variables for input
|
||||
solver->work->y = solver->work->y + solver->work->u - solver->work->znew;
|
||||
|
||||
// Update second order cone dual variables for state
|
||||
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
|
||||
solver->work->gc = solver->work->gc + solver->work->x - solver->work->vcnew;
|
||||
}
|
||||
|
||||
// Update second order cone dual variables for input
|
||||
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
|
||||
solver->work->yc = solver->work->yc + solver->work->u - solver->work->zcnew;
|
||||
}
|
||||
|
||||
// Update linear constraint dual variables for state
|
||||
if (solver->settings->en_state_linear) {
|
||||
solver->work->gl = solver->work->gl + solver->work->x - solver->work->vlnew;
|
||||
}
|
||||
|
||||
// Update linear constraint dual variables for input
|
||||
if (solver->settings->en_input_linear) {
|
||||
solver->work->yl = solver->work->yl + solver->work->u - solver->work->zlnew;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update linear control cost terms in the Riccati feedback using the changing
|
||||
* slack and dual variables from ADMM
|
||||
*/
|
||||
void update_linear_cost(TinySolver* solver) {
|
||||
// Update state cost terms
|
||||
solver->work->q = -(solver->work->Xref.array().colwise() * solver->work->Q.array());
|
||||
(solver->work->q).noalias() -= solver->cache->rho * (solver->work->vnew - solver->work->g);
|
||||
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
|
||||
(solver->work->q).noalias() -=
|
||||
solver->cache->rho * (solver->work->vcnew - solver->work->gc);
|
||||
}
|
||||
if (solver->settings->en_state_linear) {
|
||||
(solver->work->q).noalias() -=
|
||||
solver->cache->rho * (solver->work->vlnew - solver->work->gl);
|
||||
}
|
||||
|
||||
// Update input cost terms
|
||||
solver->work->r = -(solver->work->Uref.array().colwise() * solver->work->R.array());
|
||||
(solver->work->r).noalias() -= solver->cache->rho * (solver->work->znew - solver->work->y);
|
||||
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
|
||||
(solver->work->r).noalias() -=
|
||||
solver->cache->rho * (solver->work->zcnew - solver->work->yc);
|
||||
}
|
||||
if (solver->settings->en_input_linear) {
|
||||
(solver->work->r).noalias() -=
|
||||
solver->cache->rho * (solver->work->zlnew - solver->work->yl);
|
||||
}
|
||||
|
||||
// Update terminal cost
|
||||
solver->work->p.col(solver->work->N - 1) =
|
||||
-(solver->work->Xref.col(solver->work->N - 1).transpose().lazyProduct(solver->cache->Pinf));
|
||||
(solver->work->p.col(solver->work->N - 1)).noalias() -= solver->cache->rho
|
||||
* (solver->work->vnew.col(solver->work->N - 1) - solver->work->g.col(solver->work->N - 1));
|
||||
|
||||
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
|
||||
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
|
||||
* (solver->work->vcnew.col(solver->work->N - 1)
|
||||
- solver->work->gc.col(solver->work->N - 1));
|
||||
}
|
||||
if (solver->settings->en_state_linear) {
|
||||
solver->work->p.col(solver->work->N - 1) -= solver->cache->rho
|
||||
* (solver->work->vlnew.col(solver->work->N - 1)
|
||||
- solver->work->gl.col(solver->work->N - 1));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check for termination condition by evaluating whether the largest absolute
|
||||
* primal and dual residuals for states and inputs are below threhold.
|
||||
*/
|
||||
bool termination_condition(TinySolver* solver) {
|
||||
if (solver->work->iter % solver->settings->check_termination == 0) {
|
||||
solver->work->primal_residual_state =
|
||||
(solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
|
||||
solver->work->dual_residual_state =
|
||||
((solver->work->v - solver->work->vnew).cwiseAbs().maxCoeff()) * solver->cache->rho;
|
||||
solver->work->primal_residual_input =
|
||||
(solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
|
||||
solver->work->dual_residual_input =
|
||||
((solver->work->z - solver->work->znew).cwiseAbs().maxCoeff()) * solver->cache->rho;
|
||||
|
||||
if (solver->work->primal_residual_state < solver->settings->abs_pri_tol
|
||||
&& solver->work->primal_residual_input < solver->settings->abs_pri_tol
|
||||
&& solver->work->dual_residual_state < solver->settings->abs_dua_tol
|
||||
&& solver->work->dual_residual_input < solver->settings->abs_dua_tol)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int solve(TinySolver* solver) {
|
||||
// Initialize variables
|
||||
solver->solution->solved = 0;
|
||||
solver->solution->iter = 0;
|
||||
solver->work->status = 11; // TINY_UNSOLVED
|
||||
solver->work->iter = 0;
|
||||
|
||||
// Setup for adaptive rho
|
||||
RhoAdapter adapter;
|
||||
adapter.rho_min = solver->settings->adaptive_rho_min;
|
||||
adapter.rho_max = solver->settings->adaptive_rho_max;
|
||||
adapter.clip = solver->settings->adaptive_rho_enable_clipping;
|
||||
|
||||
RhoBenchmarkResult rho_result;
|
||||
|
||||
// Store previous values for residuals
|
||||
tinyMatrix v_prev = solver->work->vnew;
|
||||
tinyMatrix z_prev = solver->work->znew;
|
||||
|
||||
// Initialize SOC slack variables if needed
|
||||
if (solver->settings->en_state_soc && solver->work->numStateCones > 0) {
|
||||
solver->work->vcnew = solver->work->x;
|
||||
}
|
||||
|
||||
if (solver->settings->en_input_soc && solver->work->numInputCones > 0) {
|
||||
solver->work->zcnew = solver->work->u;
|
||||
}
|
||||
|
||||
// Initialize linear constraint slack variables if needed
|
||||
if (solver->settings->en_state_linear) {
|
||||
solver->work->vlnew = solver->work->x;
|
||||
}
|
||||
|
||||
if (solver->settings->en_input_linear) {
|
||||
solver->work->zlnew = solver->work->u;
|
||||
}
|
||||
|
||||
for (int i = 0; i < solver->settings->max_iter; i++) {
|
||||
// Solve linear system with Riccati and roll out to get new trajectory
|
||||
backward_pass_grad(solver);
|
||||
|
||||
forward_pass(solver);
|
||||
|
||||
// Project slack variables into feasible domain
|
||||
update_slack(solver);
|
||||
|
||||
// Compute next iteration of dual variables
|
||||
update_dual(solver);
|
||||
|
||||
// Update linear control cost terms using reference trajectory, duals, and slack variables
|
||||
update_linear_cost(solver);
|
||||
|
||||
solver->work->iter += 1;
|
||||
|
||||
// Handle adaptive rho if enabled
|
||||
if (solver->settings->adaptive_rho) {
|
||||
// Calculate residuals for adaptive rho
|
||||
tinytype pri_res_input = (solver->work->u - solver->work->znew).cwiseAbs().maxCoeff();
|
||||
tinytype pri_res_state = (solver->work->x - solver->work->vnew).cwiseAbs().maxCoeff();
|
||||
tinytype dua_res_input =
|
||||
solver->cache->rho * (solver->work->znew - z_prev).cwiseAbs().maxCoeff();
|
||||
tinytype dua_res_state =
|
||||
solver->cache->rho * (solver->work->vnew - v_prev).cwiseAbs().maxCoeff();
|
||||
|
||||
// Update rho every 5 iterations
|
||||
if (i > 0 && i % 5 == 0) {
|
||||
benchmark_rho_adaptation(
|
||||
&adapter,
|
||||
solver->work->x,
|
||||
solver->work->u,
|
||||
solver->work->vnew,
|
||||
solver->work->znew,
|
||||
solver->work->g,
|
||||
solver->work->y,
|
||||
solver->cache,
|
||||
solver->work,
|
||||
solver->work->N,
|
||||
&rho_result
|
||||
);
|
||||
|
||||
// Update matrices using Taylor expansion
|
||||
update_matrices_with_derivatives(solver->cache, rho_result.final_rho);
|
||||
}
|
||||
}
|
||||
|
||||
// Store previous values for next iteration
|
||||
z_prev = solver->work->znew;
|
||||
v_prev = solver->work->vnew;
|
||||
|
||||
// Check for whether cost is minimized by calculating residuals
|
||||
if (termination_condition(solver)) {
|
||||
solver->work->status = 1; // TINY_SOLVED
|
||||
|
||||
// Save solution
|
||||
solver->solution->iter = solver->work->iter;
|
||||
solver->solution->solved = 1;
|
||||
solver->solution->x = solver->work->vnew;
|
||||
solver->solution->u = solver->work->znew;
|
||||
|
||||
// std::cout << "Solver converged in " << solver->work->iter << " iterations" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Save previous slack variables
|
||||
solver->work->v = solver->work->vnew;
|
||||
solver->work->z = solver->work->znew;
|
||||
}
|
||||
|
||||
solver->solution->iter = solver->work->iter;
|
||||
solver->solution->solved = 0;
|
||||
solver->solution->x = solver->work->vnew;
|
||||
solver->solution->u = solver->work->znew;
|
||||
return 1;
|
||||
}
|
||||
|
||||
} /* extern "C" */
|
||||
@@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
|
||||
#include "types.hpp"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int solve(TinySolver* solver);
|
||||
|
||||
void update_primal(TinySolver* solver);
|
||||
void backward_pass_grad(TinySolver* solver);
|
||||
void forward_pass(TinySolver* solver);
|
||||
void update_slack(TinySolver* solver);
|
||||
void update_dual(TinySolver* solver);
|
||||
void update_linear_cost(TinySolver* solver);
|
||||
bool termination_condition(TinySolver* solver);
|
||||
|
||||
/**
|
||||
* Project a vector s onto the second order cone defined by mu
|
||||
* @param s, mu
|
||||
* @return projection onto cone if s is outside cone. Return s if s is inside cone.
|
||||
*/
|
||||
tinyVector project_soc(tinyVector s, float mu);
|
||||
|
||||
/**
|
||||
* Project a vector z onto a hyperplane defined by a^T z = b
|
||||
* Implements equation (21): ΠH(z) = z - (⟨z, a⟩ − b)/||a||² * a
|
||||
* @param z Vector to project
|
||||
* @param a Normal vector of the hyperplane
|
||||
* @param b Offset of the hyperplane
|
||||
* @return Projection of z onto the hyperplane
|
||||
*/
|
||||
tinyVector project_hyperplane(const tinyVector& z, const tinyVector& a, tinytype b);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,466 @@
|
||||
#include <ctype.h>
|
||||
#include <dirent.h>
|
||||
#include <stdio.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
//#include <error.h>
|
||||
#include "error.hpp"
|
||||
|
||||
#include <Eigen/Dense>
|
||||
#include <iostream>
|
||||
|
||||
// #include "types.hpp"
|
||||
#include "codegen.hpp"
|
||||
|
||||
#ifdef __MINGW32__
|
||||
#include <direct.h>
|
||||
inline int mkdir(const char* pathname, int flags) {
|
||||
return _mkdir(pathname);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* Define the maximum allowed length of the path (directory + filename + extension) */
|
||||
#define PATH_LENGTH 2048
|
||||
|
||||
using namespace Eigen;
|
||||
|
||||
static void print_matrix(FILE* f, MatrixXd mat, int num_elements) {
|
||||
// Check if matrix is uninitialized or too small
|
||||
if (mat.size() == 0 || mat.size() < num_elements) {
|
||||
// Print zeros for all elements
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
fprintf(f, "(tinytype)0.0000000000000000");
|
||||
if (i < num_elements - 1)
|
||||
fprintf(f, ",");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Matrix is properly initialized and has enough elements
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
fprintf(f, "(tinytype)%.16f", mat.reshaped<RowMajor>()[i]);
|
||||
if (i < num_elements - 1)
|
||||
fprintf(f, ",");
|
||||
}
|
||||
}
|
||||
|
||||
static void create_directory(const char* dir, int verbose) {
|
||||
// Attempt to create directory
|
||||
if (mkdir(dir, S_IRWXU | S_IRWXG | S_IROTH)) {
|
||||
if (errno == EEXIST) { // Skip if directory already exists
|
||||
if (verbose)
|
||||
std::cout << dir << " already exists, skipping." << std::endl;
|
||||
} else {
|
||||
ERROR_MSG(EXIT_FAILURE, "Failed to create directory %s", dir);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Make this fail if tiny_setup has not already been called
|
||||
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_codegen: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
int status = 0;
|
||||
status |= codegen_create_directories(output_dir, verbose);
|
||||
status |= codegen_data_header(output_dir, verbose);
|
||||
status |= codegen_data_source(solver, output_dir, verbose);
|
||||
status |= codegen_example(output_dir, verbose);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
int tiny_codegen_with_sensitivity(
|
||||
TinySolver* solver,
|
||||
const char* output_dir,
|
||||
tinyMatrix* dK,
|
||||
tinyMatrix* dP,
|
||||
tinyMatrix* dC1,
|
||||
tinyMatrix* dC2,
|
||||
int verbose
|
||||
) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_codegen_with_sensitivity: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Only store sensitivity matrices if adaptive rho is enabled
|
||||
if (solver->settings->adaptive_rho) {
|
||||
// Store the sensitivity matrices in the solver's cache
|
||||
solver->cache->dKinf_drho = *dK;
|
||||
solver->cache->dPinf_drho = *dP;
|
||||
solver->cache->dC1_drho = *dC1;
|
||||
solver->cache->dC2_drho = *dC2;
|
||||
}
|
||||
|
||||
// Call the regular codegen function which will now include the sensitivity matrices if adaptive_rho is enabled
|
||||
return tiny_codegen(solver, output_dir, verbose);
|
||||
}
|
||||
|
||||
// Create code generation folder structure in whichever directory the executable calling tiny_codegen was called
|
||||
int codegen_create_directories(const char* output_dir, int verbose) {
|
||||
// Create output folder (root folder for code generation)
|
||||
create_directory(output_dir, verbose);
|
||||
|
||||
// Create src folder
|
||||
char src_dir[PATH_LENGTH];
|
||||
sprintf(src_dir, "%s/src/", output_dir);
|
||||
create_directory(src_dir, verbose);
|
||||
|
||||
// Create tinympc folder
|
||||
char tinympc_dir[PATH_LENGTH];
|
||||
sprintf(tinympc_dir, "%s/tinympc/", output_dir);
|
||||
create_directory(tinympc_dir, verbose);
|
||||
|
||||
// // Create include folder
|
||||
// char inc_dir[PATH_LENGTH];
|
||||
// sprintf(inc_dir, "%s/include/", output_dir);
|
||||
// create_directory(inc_dir, verbose);
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
// Create inc/tiny_data.hpp file
|
||||
int codegen_data_header(const char* output_dir, int verbose) {
|
||||
char data_hpp_fname[PATH_LENGTH];
|
||||
FILE* data_hpp_f;
|
||||
|
||||
sprintf(data_hpp_fname, "%s/tinympc/tiny_data.hpp", output_dir);
|
||||
|
||||
// Open data header file
|
||||
data_hpp_f = fopen(data_hpp_fname, "w+");
|
||||
if (data_hpp_f == NULL)
|
||||
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_hpp_fname);
|
||||
|
||||
// Preamble
|
||||
time_t start_time;
|
||||
time(&start_time);
|
||||
fprintf(data_hpp_f, "/*\n");
|
||||
fprintf(data_hpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
|
||||
fprintf(data_hpp_f, " */\n\n");
|
||||
|
||||
fprintf(data_hpp_f, "#pragma once\n\n");
|
||||
|
||||
fprintf(data_hpp_f, "#include \"types.hpp\"\n\n");
|
||||
|
||||
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
|
||||
fprintf(data_hpp_f, "extern \"C\" {\n");
|
||||
fprintf(data_hpp_f, "#endif\n\n");
|
||||
|
||||
fprintf(data_hpp_f, "extern TinySolver tiny_solver;\n\n");
|
||||
|
||||
fprintf(data_hpp_f, "#ifdef __cplusplus\n");
|
||||
fprintf(data_hpp_f, "}\n");
|
||||
fprintf(data_hpp_f, "#endif\n");
|
||||
|
||||
// Close codegen data header file
|
||||
fclose(data_hpp_f);
|
||||
|
||||
if (verbose) {
|
||||
printf("Data header generated in %s\n", data_hpp_fname);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Create src/tiny_data.cpp file
|
||||
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose) {
|
||||
char data_cpp_fname[PATH_LENGTH];
|
||||
FILE* data_cpp_f;
|
||||
|
||||
int nx = solver->work->nx;
|
||||
int nu = solver->work->nu;
|
||||
int N = solver->work->N;
|
||||
|
||||
sprintf(data_cpp_fname, "%s/src/tiny_data.cpp", output_dir);
|
||||
|
||||
// Open data source file
|
||||
data_cpp_f = fopen(data_cpp_fname, "w+");
|
||||
if (data_cpp_f == NULL)
|
||||
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", data_cpp_fname);
|
||||
|
||||
// Preamble
|
||||
time_t start_time;
|
||||
time(&start_time);
|
||||
fprintf(data_cpp_f, "/*\n");
|
||||
fprintf(data_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
|
||||
fprintf(data_cpp_f, " */\n\n");
|
||||
|
||||
// Open extern C
|
||||
fprintf(data_cpp_f, "#include \"tinympc/tiny_data.hpp\"\n\n");
|
||||
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
|
||||
fprintf(data_cpp_f, "extern \"C\" {\n");
|
||||
fprintf(data_cpp_f, "#endif\n\n");
|
||||
|
||||
// Solution
|
||||
fprintf(data_cpp_f, "/* Solution */\n");
|
||||
fprintf(data_cpp_f, "TinySolution solution = {\n");
|
||||
|
||||
fprintf(data_cpp_f, "\t%d,\t\t// iter\n", solver->solution->iter);
|
||||
fprintf(data_cpp_f, "\t%d,\t\t// solved\n", solver->solution->solved);
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x solution
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// x\n"); // u solution
|
||||
|
||||
fprintf(data_cpp_f, "};\n\n");
|
||||
|
||||
// Cache
|
||||
fprintf(data_cpp_f, "/* Matrices that must be recomputed with changes in time step, rho */\n");
|
||||
fprintf(data_cpp_f, "TinyCache cache = {\n");
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// rho (step size/penalty)\n", solver->cache->rho);
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->Kinf, nu * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Kinf\n"); // Kinf
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->Pinf, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Pinf\n"); // Pinf
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nu);
|
||||
print_matrix(data_cpp_f, solver->cache->Quu_inv, nu * nu);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Quu_inv\n"); // Quu_inv
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->AmBKt, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// AmBKt\n"); // AmBKt
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->C1, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// C1\n"); // C1
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->C2, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished()"); // C2, no comma if no sensitivity matrices
|
||||
|
||||
// Only print sensitivity matrices if adaptive rho is enabled
|
||||
if (solver->settings->adaptive_rho) {
|
||||
fprintf(data_cpp_f, ",\t// C2\n"); // Add comma and comment for C2 if we have more matrices
|
||||
|
||||
// Add sensitivity matrices within the struct initialization
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->dKinf_drho, nu * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// dKinf_drho\n"); // dKinf_drho
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->dPinf_drho, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// dPinf_drho\n"); // dPinf_drho
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->dC1_drho, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// dC1_drho\n"); // dC1_drho
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->cache->dC2_drho, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished()\t// dC2_drho\n"); // dC2_drho
|
||||
} else {
|
||||
fprintf(data_cpp_f, "\t// C2\n"); // Just add comment for C2
|
||||
}
|
||||
|
||||
fprintf(data_cpp_f, "};\n\n");
|
||||
|
||||
// Settings
|
||||
fprintf(data_cpp_f, "/* User settings */\n");
|
||||
fprintf(data_cpp_f, "TinySettings settings = {\n");
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// primal tolerance\n", solver->settings->abs_pri_tol);
|
||||
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// dual tolerance\n", solver->settings->abs_dua_tol);
|
||||
fprintf(data_cpp_f, "\t%d,\t\t// max iterations\n", solver->settings->max_iter);
|
||||
fprintf(
|
||||
data_cpp_f,
|
||||
"\t%d,\t\t// iterations per termination check\n",
|
||||
solver->settings->check_termination
|
||||
);
|
||||
fprintf(data_cpp_f, "\t%d,\t\t// enable state constraints\n", solver->settings->en_state_bound);
|
||||
fprintf(data_cpp_f, "\t%d\t\t// enable input constraints\n", solver->settings->en_input_bound);
|
||||
|
||||
fprintf(data_cpp_f, "};\n\n");
|
||||
|
||||
// Workspace
|
||||
fprintf(data_cpp_f, "/* Problem variables */\n");
|
||||
fprintf(data_cpp_f, "TinyWorkspace work = {\n");
|
||||
|
||||
fprintf(data_cpp_f, "\t%d,\t// Number of states\n", nx);
|
||||
fprintf(data_cpp_f, "\t%d,\t// Number of control inputs\n", nu);
|
||||
fprintf(data_cpp_f, "\t%d,\t// Number of knotpoints in the horizon\n", N);
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// x\n"); // x
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// u\n"); // u
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// q\n"); // q
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// r\n"); // r
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// p\n"); // p
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// d\n"); // d
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// v\n"); // v
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// vnew\n"); // vnew
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// z\n"); // z
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// znew\n"); // znew
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// g\n"); // g
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// y\n"); // y
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nx);
|
||||
print_matrix(data_cpp_f, solver->work->Q, nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Q\n"); // Q
|
||||
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
|
||||
print_matrix(data_cpp_f, solver->work->R, nu);
|
||||
fprintf(data_cpp_f, ").finished(),\t// R\n"); // R
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nx);
|
||||
print_matrix(data_cpp_f, solver->work->Adyn, nx * nx);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Adyn\n"); // Adyn
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, nu);
|
||||
print_matrix(data_cpp_f, solver->work->Bdyn, nx * nu);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Bdyn\n"); // Bdyn
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, solver->work->x_min, nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// x_min\n"); // x_min
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, solver->work->x_max, nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// x_max\n"); // x_max
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, solver->work->u_min, nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// u_min\n"); // u_min
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, solver->work->u_max, nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// u_max\n"); // u_max
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nx, N);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nx, N), nx * N);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Xref\n"); // Xref
|
||||
fprintf(data_cpp_f, "\t(tinyMatrix(%d, %d) << ", nu, N - 1);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, N - 1), nu * (N - 1));
|
||||
fprintf(data_cpp_f, ").finished(),\t// Uref\n"); // Uref
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinyVector(%d) << ", nu);
|
||||
print_matrix(data_cpp_f, MatrixXd::Zero(nu, 1), nu);
|
||||
fprintf(data_cpp_f, ").finished(),\t// Qu\n"); // Qu
|
||||
|
||||
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state primal residual\n", 0.0);
|
||||
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input primal residual\n", 0.0);
|
||||
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// state dual residual\n", 0.0);
|
||||
fprintf(data_cpp_f, "\t(tinytype)%.16f,\t// input dual residual\n", 0.0);
|
||||
fprintf(data_cpp_f, "\t%d,\t// solve status\n", 0);
|
||||
fprintf(data_cpp_f, "\t%d,\t// solve iteration\n", 0);
|
||||
|
||||
fprintf(data_cpp_f, "};\n\n");
|
||||
|
||||
// Write solver struct definition to workspace file
|
||||
fprintf(data_cpp_f, "TinySolver tiny_solver = {&solution, &settings, &cache, &work};\n\n");
|
||||
|
||||
// Close extern C
|
||||
fprintf(data_cpp_f, "#ifdef __cplusplus\n");
|
||||
fprintf(data_cpp_f, "}\n");
|
||||
fprintf(data_cpp_f, "#endif\n\n");
|
||||
|
||||
// Close codegen data file
|
||||
fclose(data_cpp_f);
|
||||
if (verbose) {
|
||||
printf("Data generated in %s\n", data_cpp_fname);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int codegen_example(const char* output_dir, int verbose) {
|
||||
char example_cpp_fname[PATH_LENGTH];
|
||||
FILE* example_cpp_f;
|
||||
|
||||
sprintf(example_cpp_fname, "%s/src/tiny_main.cpp", output_dir);
|
||||
|
||||
// Open example file
|
||||
example_cpp_f = fopen(example_cpp_fname, "w+");
|
||||
if (example_cpp_f == NULL)
|
||||
ERROR_MSG(EXIT_FAILURE, "Failed to open file %s", example_cpp_fname);
|
||||
|
||||
// Preamble
|
||||
time_t start_time;
|
||||
time(&start_time);
|
||||
fprintf(example_cpp_f, "/*\n");
|
||||
fprintf(example_cpp_f, " * This file was autogenerated by TinyMPC on %s", ctime(&start_time));
|
||||
fprintf(example_cpp_f, " */\n\n");
|
||||
|
||||
fprintf(example_cpp_f, "#include <iostream>\n\n");
|
||||
|
||||
fprintf(example_cpp_f, "#include <tinympc/tiny_api.hpp>\n");
|
||||
fprintf(example_cpp_f, "#include <tinympc/tiny_data.hpp>\n\n");
|
||||
|
||||
fprintf(example_cpp_f, "using namespace Eigen;\n");
|
||||
fprintf(example_cpp_f, "IOFormat TinyFmt(4, 0, \", \", \"\\n\", \"[\", \"]\");\n\n");
|
||||
|
||||
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
|
||||
fprintf(example_cpp_f, "extern \"C\" {\n");
|
||||
fprintf(example_cpp_f, "#endif\n\n");
|
||||
|
||||
fprintf(example_cpp_f, "int main()\n");
|
||||
fprintf(example_cpp_f, "{\n");
|
||||
fprintf(example_cpp_f, "\tint exitflag = 1;\n");
|
||||
fprintf(example_cpp_f, "\t// Double check some data\n");
|
||||
fprintf(example_cpp_f, "\tstd::cout << \"rho: \" << tiny_solver.cache->rho << std::endl;\n");
|
||||
fprintf(
|
||||
example_cpp_f,
|
||||
"\tstd::cout << \"\\nmax iters: \" << tiny_solver.settings->max_iter << std::endl;\n"
|
||||
);
|
||||
fprintf(
|
||||
example_cpp_f,
|
||||
"\tstd::cout << \"\\nState transition matrix:\\n\" << tiny_solver.work->Adyn.format(TinyFmt) << std::endl;\n"
|
||||
);
|
||||
fprintf(
|
||||
example_cpp_f,
|
||||
"\tstd::cout << \"\\nInput/control matrix:\\n\" << tiny_solver.work->Bdyn.format(TinyFmt) << std::endl;\n\n"
|
||||
);
|
||||
|
||||
fprintf(
|
||||
example_cpp_f,
|
||||
"\t// Visit https://tinympc.org/ to see how to set the initial condition and update the reference trajectory.\n\n"
|
||||
);
|
||||
|
||||
fprintf(example_cpp_f, "\tstd::cout << \"\\nSolving...\\n\" << std::endl;\n\n");
|
||||
fprintf(example_cpp_f, "\texitflag = tiny_solve(&tiny_solver);\n\n");
|
||||
fprintf(example_cpp_f, "\tif (exitflag == 0) printf(\"Hooray! Solved with no error!\\n\");\n");
|
||||
fprintf(example_cpp_f, "\telse printf(\"Oops! Something went wrong!\\n\");\n");
|
||||
|
||||
fprintf(example_cpp_f, "\treturn 0;\n");
|
||||
fprintf(example_cpp_f, "}\n\n");
|
||||
|
||||
fprintf(example_cpp_f, "#ifdef __cplusplus\n");
|
||||
fprintf(example_cpp_f, "} /* extern \"C\" */\n");
|
||||
fprintf(example_cpp_f, "#endif\n");
|
||||
|
||||
// Close codegen example main file
|
||||
fclose(example_cpp_f);
|
||||
if (verbose) {
|
||||
printf("Example tinympc main generated in %s\n", example_cpp_fname);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "types.hpp"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int tiny_codegen(TinySolver* solver, const char* output_dir, int verbose);
|
||||
|
||||
int tiny_codegen_with_sensitivity(
|
||||
TinySolver* solver,
|
||||
const char* output_dir,
|
||||
tinyMatrix* dK,
|
||||
tinyMatrix* dP,
|
||||
tinyMatrix* dC1,
|
||||
tinyMatrix* dC2,
|
||||
int verbose
|
||||
);
|
||||
|
||||
int codegen_create_directories(const char* output_dir, int verbose);
|
||||
int codegen_data_header(const char* output_dir, int verbose);
|
||||
int codegen_data_source(TinySolver* solver, const char* output_dir, int verbose);
|
||||
int codegen_example(const char* output_dir, int verbose);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <errno.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
// #if defined(__linux__) || defined(__unix__)// Check if Linux
|
||||
// #include <error.h>
|
||||
// #define ERROR_MSG(exit_code, format, ...) error(exit_code, errno, format, ##__VA_ARGS__)
|
||||
|
||||
// #elif defined(__APPLE__) || defined(__MACH__) // Check if macOS
|
||||
#define ERROR_MSG(exit_code, format, ...) \
|
||||
{ \
|
||||
fprintf(stderr, format ": %s\n", ##__VA_ARGS__, strerror(errno)); \
|
||||
exit(exit_code); \
|
||||
}
|
||||
|
||||
// #else
|
||||
// #error "Unsupported operating system"
|
||||
// #endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,252 @@
|
||||
#include "rho_benchmark.hpp"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#ifdef ARDUINO
|
||||
#include <Arduino.h>
|
||||
#else
|
||||
// For non-Arduino platforms
|
||||
uint32_t micros() {
|
||||
return 0; // Replace with appropriate timing function
|
||||
}
|
||||
#endif
|
||||
|
||||
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N) {
|
||||
// Calculate dimensions
|
||||
int x_decision_size = nx * N + nu * (N - 1);
|
||||
int constraint_rows = (nx + nu) * (N - 1);
|
||||
|
||||
// Pre-allocate matrices
|
||||
adapter->A_matrix = tinyMatrix::Zero(constraint_rows, x_decision_size);
|
||||
adapter->z_vector = tinyMatrix::Zero(constraint_rows, 1);
|
||||
adapter->y_vector = tinyMatrix::Zero(constraint_rows, 1);
|
||||
adapter->x_decision = tinyMatrix::Zero(x_decision_size, 1);
|
||||
|
||||
// Pre-compute P matrix structure
|
||||
adapter->P_matrix = tinyMatrix::Zero(x_decision_size, x_decision_size);
|
||||
adapter->q_vector = tinyMatrix::Zero(x_decision_size, 1);
|
||||
|
||||
// Pre-allocate residual computation matrices
|
||||
adapter->Ax_vector = tinyMatrix::Zero(constraint_rows, 1);
|
||||
adapter->r_prim_vector = tinyMatrix::Zero(constraint_rows, 1);
|
||||
adapter->r_dual_vector = tinyMatrix::Zero(x_decision_size, 1);
|
||||
adapter->Px_vector = tinyMatrix::Zero(x_decision_size, 1);
|
||||
adapter->ATy_vector = tinyMatrix::Zero(x_decision_size, 1);
|
||||
|
||||
// Store dimensions
|
||||
adapter->format_nx = nx;
|
||||
adapter->format_nu = nu;
|
||||
adapter->format_N = N;
|
||||
|
||||
adapter->matrices_initialized = true;
|
||||
}
|
||||
|
||||
void format_matrices(
|
||||
RhoAdapter* adapter,
|
||||
const tinyMatrix& x_prev,
|
||||
const tinyMatrix& u_prev,
|
||||
const tinyMatrix& v_prev,
|
||||
const tinyMatrix& z_prev,
|
||||
const tinyMatrix& g_prev,
|
||||
const tinyMatrix& y_prev,
|
||||
TinyCache* cache,
|
||||
TinyWorkspace* work,
|
||||
int N
|
||||
) {
|
||||
if (!adapter->matrices_initialized) {
|
||||
initialize_format_matrices(adapter, x_prev.rows(), u_prev.rows(), N);
|
||||
}
|
||||
|
||||
int nx = adapter->format_nx;
|
||||
int nu = adapter->format_nu;
|
||||
|
||||
// Fill x_decision
|
||||
int x_idx = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
adapter->x_decision.block(x_idx, 0, nx, 1) = x_prev.col(i);
|
||||
x_idx += nx;
|
||||
if (i < N - 1) {
|
||||
adapter->x_decision.block(x_idx, 0, nu, 1) = u_prev.col(i);
|
||||
x_idx += nu;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear A matrix for reuse
|
||||
adapter->A_matrix.setZero();
|
||||
|
||||
// Fill A matrix with dynamics and input constraints
|
||||
for (int i = 0; i < N - 1; i++) {
|
||||
// Input constraints
|
||||
int row_start = i * nu;
|
||||
int col_start = i * (nx + nu) + nx;
|
||||
adapter->A_matrix.block(row_start, col_start, nu, nu) = tinyMatrix::Identity(nu, nu);
|
||||
|
||||
// Dynamics constraints
|
||||
row_start = (N - 1) * nu + i * nx;
|
||||
col_start = i * (nx + nu);
|
||||
adapter->A_matrix.block(row_start, col_start, nx, nx) = work->Adyn;
|
||||
adapter->A_matrix.block(row_start, col_start + nx, nx, nu) = work->Bdyn;
|
||||
|
||||
int next_state_idx = col_start + nx + nu;
|
||||
if (next_state_idx < adapter->A_matrix.cols()) {
|
||||
adapter->A_matrix.block(row_start, next_state_idx, nx, nx) =
|
||||
-tinyMatrix::Identity(nx, nx);
|
||||
}
|
||||
}
|
||||
|
||||
// Fill z and y vectors
|
||||
for (int i = 0; i < N - 1; i++) {
|
||||
adapter->z_vector.block(i * nu, 0, nu, 1) = z_prev.col(i);
|
||||
adapter->z_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = v_prev.col(i + 1);
|
||||
|
||||
adapter->y_vector.block(i * nu, 0, nu, 1) = y_prev.col(i);
|
||||
adapter->y_vector.block((N - 1) * nu + i * nx, 0, nx, 1) = g_prev.col(i + 1);
|
||||
}
|
||||
|
||||
// Build P matrix (cost matrix)
|
||||
adapter->P_matrix.setZero();
|
||||
|
||||
// Fill diagonal blocks
|
||||
x_idx = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
// State cost
|
||||
if (i == N - 1) {
|
||||
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = cache->Pinf;
|
||||
} else {
|
||||
adapter->P_matrix.block(x_idx, x_idx, nx, nx) = work->Q.asDiagonal();
|
||||
}
|
||||
x_idx += nx;
|
||||
|
||||
// Input cost
|
||||
if (i < N - 1) {
|
||||
adapter->P_matrix.block(x_idx, x_idx, nu, nu) = work->R.asDiagonal();
|
||||
x_idx += nu;
|
||||
}
|
||||
}
|
||||
|
||||
// Create q vector (linear cost vector)
|
||||
x_idx = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
// For simplicity, we'll use zero reference for now
|
||||
// In a real implementation, you'd use your reference trajectory
|
||||
tinyMatrix x_ref = tinyMatrix::Zero(nx, 1);
|
||||
tinyMatrix delta_x = x_prev.col(i) - x_ref;
|
||||
adapter->q_vector.block(x_idx, 0, nx, 1) = work->Q.asDiagonal() * delta_x;
|
||||
x_idx += nx;
|
||||
|
||||
if (i < N - 1) {
|
||||
// For simplicity, we'll use zero reference for now
|
||||
tinyMatrix u_ref = tinyMatrix::Zero(nu, 1);
|
||||
tinyMatrix delta_u = u_prev.col(i) - u_ref;
|
||||
adapter->q_vector.block(x_idx, 0, nu, 1) = work->R.asDiagonal() * delta_u;
|
||||
x_idx += nu;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute_residuals(
|
||||
RhoAdapter* adapter,
|
||||
tinytype* pri_res,
|
||||
tinytype* dual_res,
|
||||
tinytype* pri_norm,
|
||||
tinytype* dual_norm
|
||||
) {
|
||||
// Compute Ax
|
||||
adapter->Ax_vector = adapter->A_matrix * adapter->x_decision;
|
||||
|
||||
// Compute primal residual
|
||||
adapter->r_prim_vector = adapter->Ax_vector - adapter->z_vector;
|
||||
*pri_res = adapter->r_prim_vector.cwiseAbs().maxCoeff();
|
||||
*pri_norm =
|
||||
std::max(adapter->Ax_vector.cwiseAbs().maxCoeff(), adapter->z_vector.cwiseAbs().maxCoeff());
|
||||
|
||||
// Compute dual residual components
|
||||
adapter->Px_vector = adapter->P_matrix * adapter->x_decision;
|
||||
adapter->ATy_vector = adapter->A_matrix.transpose() * adapter->y_vector;
|
||||
|
||||
// Compute full dual residual
|
||||
adapter->r_dual_vector = adapter->Px_vector + adapter->q_vector + adapter->ATy_vector;
|
||||
*dual_res = adapter->r_dual_vector.cwiseAbs().maxCoeff();
|
||||
|
||||
// Compute normalization
|
||||
*dual_norm = std::max(
|
||||
std::max(
|
||||
adapter->Px_vector.cwiseAbs().maxCoeff(),
|
||||
adapter->ATy_vector.cwiseAbs().maxCoeff()
|
||||
),
|
||||
adapter->q_vector.cwiseAbs().maxCoeff()
|
||||
);
|
||||
}
|
||||
|
||||
tinytype predict_rho(
|
||||
RhoAdapter* adapter,
|
||||
tinytype pri_res,
|
||||
tinytype dual_res,
|
||||
tinytype pri_norm,
|
||||
tinytype dual_norm,
|
||||
tinytype current_rho
|
||||
) {
|
||||
const tinytype eps = 1e-10;
|
||||
|
||||
tinytype normalized_pri = pri_res / (pri_norm + eps);
|
||||
tinytype normalized_dual = dual_res / (dual_norm + eps);
|
||||
|
||||
tinytype ratio = normalized_pri / (normalized_dual + eps);
|
||||
|
||||
tinytype new_rho = current_rho * std::sqrt(ratio);
|
||||
|
||||
if (adapter->clip) {
|
||||
new_rho = std::min(std::max(new_rho, adapter->rho_min), adapter->rho_max);
|
||||
}
|
||||
|
||||
return new_rho;
|
||||
}
|
||||
|
||||
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho) {
|
||||
tinytype delta_rho = new_rho - cache->rho;
|
||||
|
||||
cache->Kinf = cache->Kinf + delta_rho * cache->dKinf_drho;
|
||||
cache->Pinf = cache->Pinf + delta_rho * cache->dPinf_drho;
|
||||
cache->C1 = cache->C1 + delta_rho * cache->dC1_drho;
|
||||
cache->C2 = cache->C2 + delta_rho * cache->dC2_drho;
|
||||
|
||||
cache->rho = new_rho;
|
||||
}
|
||||
|
||||
void benchmark_rho_adaptation(
|
||||
RhoAdapter* adapter,
|
||||
const tinyMatrix& x_prev,
|
||||
const tinyMatrix& u_prev,
|
||||
const tinyMatrix& v_prev,
|
||||
const tinyMatrix& z_prev,
|
||||
const tinyMatrix& g_prev,
|
||||
const tinyMatrix& y_prev,
|
||||
TinyCache* cache,
|
||||
TinyWorkspace* work,
|
||||
int N,
|
||||
RhoBenchmarkResult* result
|
||||
) {
|
||||
uint32_t start_time = micros();
|
||||
|
||||
// Format matrices
|
||||
format_matrices(adapter, x_prev, u_prev, v_prev, z_prev, g_prev, y_prev, cache, work, N);
|
||||
|
||||
// Compute residuals
|
||||
tinytype pri_res, dual_res, pri_norm, dual_norm;
|
||||
compute_residuals(adapter, &pri_res, &dual_res, &pri_norm, &dual_norm);
|
||||
|
||||
// Predict new rho
|
||||
tinytype new_rho = predict_rho(adapter, pri_res, dual_res, pri_norm, dual_norm, cache->rho);
|
||||
|
||||
// Update matrices
|
||||
update_matrices_with_derivatives(cache, new_rho);
|
||||
|
||||
// Store results
|
||||
result->time_us = micros() - start_time;
|
||||
result->initial_rho = cache->rho;
|
||||
result->final_rho = new_rho;
|
||||
result->pri_res = pri_res;
|
||||
result->dual_res = dual_res;
|
||||
result->pri_norm = pri_norm;
|
||||
result->dual_norm = dual_norm;
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
#pragma once
|
||||
#include "types.hpp"
|
||||
#include <cstdint>
|
||||
|
||||
struct RhoAdapter {
|
||||
tinytype rho_min;
|
||||
tinytype rho_max;
|
||||
bool clip;
|
||||
bool matrices_initialized;
|
||||
|
||||
// Pre-allocated matrices for formatting
|
||||
tinyMatrix A_matrix;
|
||||
tinyMatrix z_vector;
|
||||
tinyMatrix y_vector;
|
||||
tinyMatrix x_decision;
|
||||
tinyMatrix P_matrix;
|
||||
tinyMatrix q_vector;
|
||||
|
||||
// Pre-allocated matrices for residual computation
|
||||
tinyMatrix Ax_vector;
|
||||
tinyMatrix r_prim_vector;
|
||||
tinyMatrix r_dual_vector;
|
||||
tinyMatrix Px_vector;
|
||||
tinyMatrix ATy_vector;
|
||||
|
||||
// Dimensions
|
||||
int format_nx;
|
||||
int format_nu;
|
||||
int format_N;
|
||||
};
|
||||
|
||||
struct RhoBenchmarkResult {
|
||||
uint32_t time_us;
|
||||
tinytype initial_rho;
|
||||
tinytype final_rho;
|
||||
tinytype pri_res;
|
||||
tinytype dual_res;
|
||||
tinytype pri_norm;
|
||||
tinytype dual_norm;
|
||||
};
|
||||
|
||||
// Initialize matrices for formatting
|
||||
void initialize_format_matrices(RhoAdapter* adapter, int nx, int nu, int N);
|
||||
|
||||
// Format matrices for residual computation
|
||||
void format_matrices(
|
||||
RhoAdapter* adapter,
|
||||
const tinyMatrix& x_prev,
|
||||
const tinyMatrix& u_prev,
|
||||
const tinyMatrix& v_prev,
|
||||
const tinyMatrix& z_prev,
|
||||
const tinyMatrix& g_prev,
|
||||
const tinyMatrix& y_prev,
|
||||
TinyCache* cache,
|
||||
TinyWorkspace* work,
|
||||
int N
|
||||
);
|
||||
|
||||
// Compute residuals
|
||||
void compute_residuals(
|
||||
RhoAdapter* adapter,
|
||||
tinytype* pri_res,
|
||||
tinytype* dual_res,
|
||||
tinytype* pri_norm,
|
||||
tinytype* dual_norm
|
||||
);
|
||||
|
||||
// Predict new rho value
|
||||
tinytype predict_rho(
|
||||
RhoAdapter* adapter,
|
||||
tinytype pri_res,
|
||||
tinytype dual_res,
|
||||
tinytype pri_norm,
|
||||
tinytype dual_norm,
|
||||
tinytype current_rho
|
||||
);
|
||||
|
||||
// Update matrices using derivatives
|
||||
void update_matrices_with_derivatives(TinyCache* cache, tinytype new_rho);
|
||||
|
||||
// Main benchmark function
|
||||
void benchmark_rho_adaptation(
|
||||
RhoAdapter* adapter,
|
||||
const tinyMatrix& x_prev,
|
||||
const tinyMatrix& u_prev,
|
||||
const tinyMatrix& v_prev,
|
||||
const tinyMatrix& z_prev,
|
||||
const tinyMatrix& g_prev,
|
||||
const tinyMatrix& y_prev,
|
||||
TinyCache* cache,
|
||||
TinyWorkspace* work,
|
||||
int N,
|
||||
RhoBenchmarkResult* result
|
||||
);
|
||||
@@ -0,0 +1,876 @@
|
||||
#include "tiny_api.hpp"
|
||||
#include "tiny_api_constants.hpp"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
using namespace Eigen;
|
||||
IOFormat TinyApiFmt(4, 0, ", ", "\n", "[", "]");
|
||||
|
||||
static int
|
||||
check_dimension(std::string matrix_name, std::string rows_or_columns, int actual, int expected) {
|
||||
if (actual != expected) {
|
||||
std::cout << matrix_name << " has " << actual << " " << rows_or_columns << ". Expected "
|
||||
<< expected << "." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_setup(
|
||||
TinySolver** solverp,
|
||||
tinyMatrix Adyn,
|
||||
tinyMatrix Bdyn,
|
||||
tinyMatrix fdyn,
|
||||
tinyMatrix Q,
|
||||
tinyMatrix R,
|
||||
tinytype rho,
|
||||
int nx,
|
||||
int nu,
|
||||
int N,
|
||||
int verbose
|
||||
) {
|
||||
TinySolution* solution = new TinySolution();
|
||||
TinyCache* cache = new TinyCache();
|
||||
TinySettings* settings = new TinySettings();
|
||||
TinyWorkspace* work = new TinyWorkspace();
|
||||
TinySolver* solver = new TinySolver();
|
||||
|
||||
solver->solution = solution;
|
||||
solver->cache = cache;
|
||||
solver->settings = settings;
|
||||
solver->work = work;
|
||||
|
||||
*solverp = solver;
|
||||
|
||||
// Initialize solution
|
||||
solution->iter = 0;
|
||||
solution->solved = 0;
|
||||
solution->x = tinyMatrix::Zero(nx, N);
|
||||
solution->u = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
// Initialize settings
|
||||
tiny_set_default_settings(settings);
|
||||
|
||||
// Initialize workspace
|
||||
work->nx = nx;
|
||||
work->nu = nu;
|
||||
work->N = N;
|
||||
|
||||
// Make sure arguments are the correct shapes
|
||||
int status = 0;
|
||||
status |= check_dimension("State transition matrix (A)", "rows", Adyn.rows(), nx);
|
||||
status |= check_dimension("State transition matrix (A)", "columns", Adyn.cols(), nx);
|
||||
status |= check_dimension("Input matrix (B)", "rows", Bdyn.rows(), nx);
|
||||
status |= check_dimension("Input matrix (B)", "columns", Bdyn.cols(), nu);
|
||||
status |= check_dimension("Affine vector (f)", "rows", fdyn.rows(), nx);
|
||||
status |= check_dimension("Affine vector (f)", "columns", fdyn.cols(), 1);
|
||||
status |= check_dimension("State stage cost (Q)", "rows", Q.rows(), nx);
|
||||
status |= check_dimension("State stage cost (Q)", "columns", Q.cols(), nx);
|
||||
status |= check_dimension("State input cost (R)", "rows", R.rows(), nu);
|
||||
status |= check_dimension("State input cost (R)", "columns", R.cols(), nu);
|
||||
if (status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
work->x = tinyMatrix::Zero(nx, N);
|
||||
work->u = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
work->q = tinyMatrix::Zero(nx, N);
|
||||
work->r = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
work->p = tinyMatrix::Zero(nx, N);
|
||||
work->d = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
// Bound constraint slack variables
|
||||
work->v = tinyMatrix::Zero(nx, N);
|
||||
work->vnew = tinyMatrix::Zero(nx, N);
|
||||
work->z = tinyMatrix::Zero(nu, N - 1);
|
||||
work->znew = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
// Bound constraint dual variables
|
||||
work->g = tinyMatrix::Zero(nx, N);
|
||||
work->y = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
// Cone constraint slack variables
|
||||
work->vc = tinyMatrix::Zero(nx, N);
|
||||
work->vcnew = tinyMatrix::Zero(nx, N);
|
||||
work->zc = tinyMatrix::Zero(nu, N - 1);
|
||||
work->zcnew = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
// Cone constraint dual variables
|
||||
work->gc = tinyMatrix::Zero(nx, N);
|
||||
work->yc = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
// Linear constraint slack variables
|
||||
work->vl = tinyMatrix::Zero(nx, N);
|
||||
work->vlnew = tinyMatrix::Zero(nx, N);
|
||||
work->zl = tinyMatrix::Zero(nu, N - 1);
|
||||
work->zlnew = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
// Linear constraint dual variables
|
||||
work->gl = tinyMatrix::Zero(nx, N);
|
||||
work->yl = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
work->Q = (Q + rho * tinyMatrix::Identity(nx, nx)).diagonal();
|
||||
work->R = (R + rho * tinyMatrix::Identity(nu, nu)).diagonal();
|
||||
work->Adyn = Adyn; // State transition matrix
|
||||
work->Bdyn = Bdyn; // Input matrix
|
||||
work->fdyn = fdyn; // Affine offset vector
|
||||
|
||||
work->Xref = tinyMatrix::Zero(nx, N);
|
||||
work->Uref = tinyMatrix::Zero(nu, N - 1);
|
||||
|
||||
work->Qu = tinyVector::Zero(nu);
|
||||
|
||||
work->primal_residual_state = 0;
|
||||
work->primal_residual_input = 0;
|
||||
work->dual_residual_state = 0;
|
||||
work->dual_residual_input = 0;
|
||||
work->status = 0;
|
||||
work->iter = 0;
|
||||
|
||||
// Initialize cache
|
||||
status = tiny_precompute_and_set_cache(
|
||||
cache,
|
||||
Adyn,
|
||||
Bdyn,
|
||||
fdyn,
|
||||
work->Q.asDiagonal(),
|
||||
work->R.asDiagonal(),
|
||||
nx,
|
||||
nu,
|
||||
rho,
|
||||
verbose
|
||||
);
|
||||
if (status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Initialize sensitivity matrices for adaptive rho
|
||||
if (solver->settings->adaptive_rho) {
|
||||
tiny_initialize_sensitivity_matrices(solver);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_set_bound_constraints(
|
||||
TinySolver* solver,
|
||||
tinyMatrix x_min,
|
||||
tinyMatrix x_max,
|
||||
tinyMatrix u_min,
|
||||
tinyMatrix u_max
|
||||
) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_set_bound_constraints: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Make sure all bound constraint matrix sizes are self-consistent
|
||||
int status = 0;
|
||||
status |= check_dimension("Lower state bounds (x_min)", "rows", x_min.rows(), solver->work->nx);
|
||||
status |= check_dimension("Lower state bounds (x_min)", "cols", x_min.cols(), solver->work->N);
|
||||
status |= check_dimension("Lower state bounds (x_max)", "rows", x_max.rows(), solver->work->nx);
|
||||
status |= check_dimension("Lower state bounds (x_max)", "cols", x_max.cols(), solver->work->N);
|
||||
status |= check_dimension("Lower input bounds (u_min)", "rows", u_min.rows(), solver->work->nu);
|
||||
status |=
|
||||
check_dimension("Lower input bounds (u_min)", "cols", u_min.cols(), solver->work->N - 1);
|
||||
status |= check_dimension("Lower input bounds (u_max)", "rows", u_max.rows(), solver->work->nu);
|
||||
status |=
|
||||
check_dimension("Lower input bounds (u_max)", "cols", u_max.cols(), solver->work->N - 1);
|
||||
|
||||
solver->work->x_min = x_min;
|
||||
solver->work->x_max = x_max;
|
||||
solver->work->u_min = u_min;
|
||||
solver->work->u_max = u_max;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_set_cone_constraints(
|
||||
TinySolver* solver,
|
||||
VectorXi Acx,
|
||||
VectorXi qcx,
|
||||
tinyVector cx,
|
||||
VectorXi Acu,
|
||||
VectorXi qcu,
|
||||
tinyVector cu
|
||||
) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_set_cone_constraints: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Make sure all cone constraint vector sizes are self-consistent
|
||||
int num_state_cones = Acx.rows();
|
||||
int num_input_cones = Acu.rows();
|
||||
int status = 0;
|
||||
status |= check_dimension("Cone state size (qcx)", "rows", qcx.rows(), num_state_cones);
|
||||
status |= check_dimension("Cone mu value for state (cx)", "rows", cx.rows(), num_state_cones);
|
||||
status |= check_dimension("Cone input size (qcu)", "rows", qcu.rows(), num_input_cones);
|
||||
status |= check_dimension("Cone mu value for input (cu)", "rows", cu.rows(), num_input_cones);
|
||||
if (status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
solver->work->numStateCones = num_state_cones;
|
||||
solver->work->numInputCones = num_input_cones;
|
||||
|
||||
solver->work->Acx = Acx;
|
||||
solver->work->qcx = qcx;
|
||||
solver->work->cx = cx;
|
||||
|
||||
solver->work->Acu = Acu;
|
||||
solver->work->qcu = qcu;
|
||||
solver->work->cu = cu;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_set_linear_constraints(
|
||||
TinySolver* solver,
|
||||
tinyMatrix Alin_x,
|
||||
tinyVector blin_x,
|
||||
tinyMatrix Alin_u,
|
||||
tinyVector blin_u
|
||||
) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_set_linear_constraints: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Make sure all linear constraint matrix sizes are self-consistent
|
||||
int num_state_linear = Alin_x.rows();
|
||||
int num_input_linear = Alin_u.rows();
|
||||
int status = 0;
|
||||
|
||||
// Check state constraint dimensions
|
||||
if (num_state_linear > 0) {
|
||||
status |= check_dimension(
|
||||
"State linear constraint matrix (Alin_x)",
|
||||
"rows",
|
||||
Alin_x.rows(),
|
||||
num_state_linear
|
||||
);
|
||||
status |= check_dimension(
|
||||
"State linear constraint matrix (Alin_x)",
|
||||
"columns",
|
||||
Alin_x.cols(),
|
||||
solver->work->nx
|
||||
);
|
||||
status |= check_dimension(
|
||||
"State linear constraint vector (blin_x)",
|
||||
"rows",
|
||||
blin_x.rows(),
|
||||
num_state_linear
|
||||
);
|
||||
status |=
|
||||
check_dimension("State linear constraint vector (blin_x)", "columns", blin_x.cols(), 1);
|
||||
}
|
||||
|
||||
// Check input constraint dimensions
|
||||
if (num_input_linear > 0) {
|
||||
status |= check_dimension(
|
||||
"Input linear constraint matrix (Alin_u)",
|
||||
"rows",
|
||||
Alin_u.rows(),
|
||||
num_input_linear
|
||||
);
|
||||
status |= check_dimension(
|
||||
"Input linear constraint matrix (Alin_u)",
|
||||
"columns",
|
||||
Alin_u.cols(),
|
||||
solver->work->nu
|
||||
);
|
||||
status |= check_dimension(
|
||||
"Input linear constraint vector (blin_u)",
|
||||
"rows",
|
||||
blin_u.rows(),
|
||||
num_input_linear
|
||||
);
|
||||
status |=
|
||||
check_dimension("Input linear constraint vector (blin_u)", "columns", blin_u.cols(), 1);
|
||||
}
|
||||
|
||||
if (status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
solver->work->numStateLinear = num_state_linear;
|
||||
solver->work->numInputLinear = num_input_linear;
|
||||
|
||||
solver->work->Alin_x = Alin_x;
|
||||
solver->work->blin_x = blin_x;
|
||||
solver->work->Alin_u = Alin_u;
|
||||
solver->work->blin_u = blin_u;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_precompute_and_set_cache(
|
||||
TinyCache* cache,
|
||||
tinyMatrix Adyn,
|
||||
tinyMatrix Bdyn,
|
||||
tinyMatrix fdyn,
|
||||
tinyMatrix Q,
|
||||
tinyMatrix R,
|
||||
int nx,
|
||||
int nu,
|
||||
tinytype rho,
|
||||
int verbose
|
||||
) {
|
||||
if (!cache) {
|
||||
std::cout << "Error in tiny_precompute_and_set_cache: cache is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Update by adding rho * identity matrix to Q, R
|
||||
tinyMatrix Q1 = Q + rho * tinyMatrix::Identity(nx, nx);
|
||||
tinyMatrix R1 = R + rho * tinyMatrix::Identity(nu, nu);
|
||||
|
||||
// Printing
|
||||
if (verbose) {
|
||||
std::cout << "A = " << Adyn.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "B = " << Bdyn.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "Q = " << Q1.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "R = " << R1.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "rho = " << rho << std::endl;
|
||||
}
|
||||
|
||||
// Riccati recursion to get Kinf, Pinf
|
||||
tinyMatrix Ktp1 = tinyMatrix::Zero(nu, nx);
|
||||
tinyMatrix Ptp1 = rho * tinyMatrix::Ones(nx, 1).array().matrix().asDiagonal();
|
||||
tinyMatrix Kinf = tinyMatrix::Zero(nu, nx);
|
||||
tinyMatrix Pinf = tinyMatrix::Zero(nx, nx);
|
||||
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
Kinf = (R1 + Bdyn.transpose() * Ptp1 * Bdyn).inverse() * Bdyn.transpose() * Ptp1 * Adyn;
|
||||
Pinf = Q1 + Adyn.transpose() * Ptp1 * (Adyn - Bdyn * Kinf);
|
||||
// if Kinf converges, break
|
||||
if ((Kinf - Ktp1).cwiseAbs().maxCoeff() < 1e-5) {
|
||||
if (verbose) {
|
||||
std::cout << "Kinf converged after " << i + 1 << " iterations" << std::endl;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ktp1 = Kinf;
|
||||
Ptp1 = Pinf;
|
||||
}
|
||||
|
||||
// Compute cached matrices
|
||||
tinyMatrix Quu_inv = (R1 + Bdyn.transpose() * Pinf * Bdyn).inverse();
|
||||
tinyMatrix AmBKt = (Adyn - Bdyn * Kinf).transpose();
|
||||
|
||||
// Precomputation for affine term
|
||||
tinyVector APf = AmBKt * Pinf * fdyn;
|
||||
tinyVector BPf = Bdyn.transpose() * Pinf * fdyn;
|
||||
|
||||
if (verbose) {
|
||||
std::cout << "Kinf = " << Kinf.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "Pinf = " << Pinf.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "Quu_inv = " << Quu_inv.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "AmBKt = " << AmBKt.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "APf = " << APf.format(TinyApiFmt) << std::endl;
|
||||
std::cout << "BPf = " << BPf.format(TinyApiFmt) << std::endl;
|
||||
|
||||
std::cout << "\nPrecomputation finished!\n" << std::endl;
|
||||
}
|
||||
|
||||
cache->rho = rho;
|
||||
cache->Kinf = Kinf;
|
||||
cache->Pinf = Pinf;
|
||||
cache->Quu_inv = Quu_inv;
|
||||
cache->AmBKt = AmBKt;
|
||||
cache->C1 = Quu_inv;
|
||||
cache->C2 = AmBKt;
|
||||
cache->APf = APf;
|
||||
cache->BPf = BPf;
|
||||
|
||||
return 0; // return success
|
||||
}
|
||||
|
||||
int tiny_solve(TinySolver* solver) {
|
||||
return solve(solver);
|
||||
}
|
||||
|
||||
int tiny_update_settings(
|
||||
TinySettings* settings,
|
||||
tinytype abs_pri_tol,
|
||||
tinytype abs_dua_tol,
|
||||
int max_iter,
|
||||
int check_termination,
|
||||
int en_state_bound,
|
||||
int en_input_bound,
|
||||
int en_state_soc,
|
||||
int en_input_soc,
|
||||
int en_state_linear,
|
||||
int en_input_linear
|
||||
) {
|
||||
if (!settings) {
|
||||
std::cout << "Error in tiny_update_settings: settings is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
settings->abs_pri_tol = abs_pri_tol;
|
||||
settings->abs_dua_tol = abs_dua_tol;
|
||||
settings->max_iter = max_iter;
|
||||
settings->check_termination = check_termination;
|
||||
settings->en_state_bound = en_state_bound;
|
||||
settings->en_input_bound = en_input_bound;
|
||||
settings->en_state_soc = en_state_soc;
|
||||
settings->en_input_soc = en_input_soc;
|
||||
settings->en_state_linear = en_state_linear;
|
||||
settings->en_input_linear = en_input_linear;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_set_default_settings(TinySettings* settings) {
|
||||
if (!settings) {
|
||||
std::cout << "Error in tiny_set_default_settings: settings is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
settings->abs_pri_tol = TINY_DEFAULT_ABS_PRI_TOL;
|
||||
settings->abs_dua_tol = TINY_DEFAULT_ABS_DUA_TOL;
|
||||
settings->max_iter = TINY_DEFAULT_MAX_ITER;
|
||||
settings->check_termination = TINY_DEFAULT_CHECK_TERMINATION;
|
||||
|
||||
// Turn off constraints until they are set by tiny_set_bound_constraints or tiny_set_cone_constraints
|
||||
settings->en_state_bound = TINY_DEFAULT_EN_STATE_BOUND;
|
||||
settings->en_input_bound = TINY_DEFAULT_EN_INPUT_BOUND;
|
||||
settings->en_state_soc = TINY_DEFAULT_EN_STATE_SOC;
|
||||
settings->en_input_soc = TINY_DEFAULT_EN_INPUT_SOC;
|
||||
settings->en_state_linear = TINY_DEFAULT_EN_STATE_LINEAR;
|
||||
settings->en_input_linear = TINY_DEFAULT_EN_INPUT_LINEAR;
|
||||
|
||||
// Initialize adaptive rho settings
|
||||
// NOTE : Adaptive rho currently supports only quadrotor system
|
||||
settings->adaptive_rho = 0; // Disabled by default
|
||||
settings->adaptive_rho_min = 1.0;
|
||||
settings->adaptive_rho_max = 100.0;
|
||||
settings->adaptive_rho_enable_clipping = 1;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_set_x0(TinySolver* solver, tinyVector x0) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_set_x0: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
if (x0.rows() != solver->work->nx) {
|
||||
perror("Error in tiny_set_x0: x0 is not the correct length");
|
||||
}
|
||||
solver->work->x.col(0) = x0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_set_x_ref: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
int status = 0;
|
||||
status |= check_dimension(
|
||||
"State reference trajectory (x_ref)",
|
||||
"rows",
|
||||
x_ref.rows(),
|
||||
solver->work->nx
|
||||
);
|
||||
status |= check_dimension(
|
||||
"State reference trajectory (x_ref)",
|
||||
"columns",
|
||||
x_ref.cols(),
|
||||
solver->work->N
|
||||
);
|
||||
solver->work->Xref = x_ref;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref) {
|
||||
if (!solver) {
|
||||
std::cout << "Error in tiny_set_u_ref: solver is nullptr" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
int status = 0;
|
||||
status |= check_dimension(
|
||||
"Control/input reference trajectory (u_ref)",
|
||||
"rows",
|
||||
u_ref.rows(),
|
||||
solver->work->nu
|
||||
);
|
||||
status |= check_dimension(
|
||||
"Control/input reference trajectory (u_ref)",
|
||||
"columns",
|
||||
u_ref.cols(),
|
||||
solver->work->N - 1
|
||||
);
|
||||
solver->work->Uref = u_ref;
|
||||
return 0;
|
||||
}
|
||||
|
||||
void tiny_initialize_sensitivity_matrices(TinySolver* solver) {
|
||||
int nu = solver->work->nu;
|
||||
int nx = solver->work->nx;
|
||||
// Initialize matrices with zeros
|
||||
solver->cache->dKinf_drho = tinyMatrix::Zero(nu, nx);
|
||||
solver->cache->dPinf_drho = tinyMatrix::Zero(nx, nx);
|
||||
solver->cache->dC1_drho = tinyMatrix::Zero(nu, nu);
|
||||
solver->cache->dC2_drho = tinyMatrix::Zero(nx, nx);
|
||||
|
||||
const float dKinf_drho[4][12] = { { 0.0001,
|
||||
-0.0001,
|
||||
-0.0025,
|
||||
0.0003,
|
||||
0.0007,
|
||||
0.0050,
|
||||
0.0001,
|
||||
-0.0001,
|
||||
-0.0008,
|
||||
0.0000,
|
||||
0.0001,
|
||||
0.0008 },
|
||||
{ -0.0001,
|
||||
-0.0000,
|
||||
-0.0025,
|
||||
-0.0001,
|
||||
-0.0006,
|
||||
-0.0050,
|
||||
-0.0001,
|
||||
0.0000,
|
||||
-0.0008,
|
||||
-0.0000,
|
||||
-0.0001,
|
||||
-0.0008 },
|
||||
{ 0.0000,
|
||||
0.0000,
|
||||
-0.0025,
|
||||
0.0001,
|
||||
0.0004,
|
||||
0.0050,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0008,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0008 },
|
||||
{ -0.0000,
|
||||
0.0001,
|
||||
-0.0025,
|
||||
-0.0003,
|
||||
-0.0004,
|
||||
-0.0050,
|
||||
-0.0000,
|
||||
0.0001,
|
||||
-0.0008,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
-0.0008 } };
|
||||
|
||||
const float dPinf_drho[12][12] = { { 0.0494,
|
||||
-0.0045,
|
||||
-0.0000,
|
||||
0.0110,
|
||||
0.1300,
|
||||
-0.0283,
|
||||
0.0280,
|
||||
-0.0026,
|
||||
-0.0000,
|
||||
0.0004,
|
||||
0.0070,
|
||||
-0.0094 },
|
||||
{ -0.0045,
|
||||
0.0491,
|
||||
0.0000,
|
||||
-0.1320,
|
||||
-0.0111,
|
||||
0.0114,
|
||||
-0.0026,
|
||||
0.0279,
|
||||
0.0000,
|
||||
-0.0076,
|
||||
-0.0004,
|
||||
0.0038 },
|
||||
{ -0.0000,
|
||||
0.0000,
|
||||
2.4450,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
1.2593,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000 },
|
||||
{ 0.0110,
|
||||
-0.1320,
|
||||
0.0000,
|
||||
0.3913,
|
||||
0.0592,
|
||||
0.3108,
|
||||
0.0080,
|
||||
-0.0776,
|
||||
0.0000,
|
||||
0.0254,
|
||||
0.0068,
|
||||
0.0750 },
|
||||
{ 0.1300,
|
||||
-0.0111,
|
||||
-0.0000,
|
||||
0.0592,
|
||||
0.4420,
|
||||
0.7771,
|
||||
0.0797,
|
||||
-0.0081,
|
||||
-0.0000,
|
||||
0.0068,
|
||||
0.0350,
|
||||
0.1875 },
|
||||
{ -0.0283,
|
||||
0.0114,
|
||||
-0.0000,
|
||||
0.3108,
|
||||
0.7771,
|
||||
10.0441,
|
||||
0.0272,
|
||||
-0.0109,
|
||||
0.0000,
|
||||
0.0655,
|
||||
0.1639,
|
||||
2.6362 },
|
||||
{ 0.0280,
|
||||
-0.0026,
|
||||
-0.0000,
|
||||
0.0080,
|
||||
0.0797,
|
||||
0.0272,
|
||||
0.0163,
|
||||
-0.0016,
|
||||
-0.0000,
|
||||
0.0005,
|
||||
0.0047,
|
||||
0.0032 },
|
||||
{ -0.0026,
|
||||
0.0279,
|
||||
0.0000,
|
||||
-0.0776,
|
||||
-0.0081,
|
||||
-0.0109,
|
||||
-0.0016,
|
||||
0.0161,
|
||||
0.0000,
|
||||
-0.0046,
|
||||
-0.0005,
|
||||
-0.0013 },
|
||||
{ -0.0000,
|
||||
0.0000,
|
||||
1.2593,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.9232,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000 },
|
||||
{ 0.0004,
|
||||
-0.0076,
|
||||
0.0000,
|
||||
0.0254,
|
||||
0.0068,
|
||||
0.0655,
|
||||
0.0005,
|
||||
-0.0046,
|
||||
0.0000,
|
||||
0.0022,
|
||||
0.0017,
|
||||
0.0244 },
|
||||
{ 0.0070,
|
||||
-0.0004,
|
||||
0.0000,
|
||||
0.0068,
|
||||
0.0350,
|
||||
0.1639,
|
||||
0.0047,
|
||||
-0.0005,
|
||||
0.0000,
|
||||
0.0017,
|
||||
0.0054,
|
||||
0.0610 },
|
||||
{ -0.0094,
|
||||
0.0038,
|
||||
0.0000,
|
||||
0.0750,
|
||||
0.1875,
|
||||
2.6362,
|
||||
0.0032,
|
||||
-0.0013,
|
||||
0.0000,
|
||||
0.0244,
|
||||
0.0610,
|
||||
0.9869 } };
|
||||
|
||||
const float dC1_drho[4][4] = { { -0.0000, 0.0000, -0.0000, 0.0000 },
|
||||
{ 0.0000, -0.0000, 0.0000, -0.0000 },
|
||||
{ -0.0000, 0.0000, -0.0000, 0.0000 },
|
||||
{ 0.0000, -0.0000, 0.0000, -0.0000 } };
|
||||
|
||||
const float dC2_drho[12][12] = { { 0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000 },
|
||||
{ -0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000 },
|
||||
{ -0.0000,
|
||||
0.0000,
|
||||
0.0001,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000 },
|
||||
{ 0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0001,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000 },
|
||||
{ 0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0001,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000 },
|
||||
{ -0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0001,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000 },
|
||||
{ 0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000 },
|
||||
{ -0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000 },
|
||||
{ -0.0000,
|
||||
0.0000,
|
||||
0.0021,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0006,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
-0.0000 },
|
||||
{ 0.0002,
|
||||
-0.0027,
|
||||
-0.0000,
|
||||
0.0068,
|
||||
0.0005,
|
||||
-0.0005,
|
||||
0.0001,
|
||||
-0.0015,
|
||||
-0.0000,
|
||||
0.0004,
|
||||
0.0000,
|
||||
-0.0001 },
|
||||
{ 0.0027,
|
||||
-0.0002,
|
||||
0.0000,
|
||||
0.0005,
|
||||
0.0066,
|
||||
-0.0011,
|
||||
0.0015,
|
||||
-0.0001,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0004,
|
||||
-0.0002 },
|
||||
{ -0.0001,
|
||||
0.0001,
|
||||
0.0000,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0041,
|
||||
-0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0000,
|
||||
0.0006 } };
|
||||
|
||||
// Map arrays to Eigen matrices
|
||||
solver->cache->dKinf_drho = Map<const Matrix<float, 4, 12>>(dKinf_drho[0]).cast<tinytype>();
|
||||
solver->cache->dPinf_drho = Map<const Matrix<float, 12, 12>>(dPinf_drho[0]).cast<tinytype>();
|
||||
solver->cache->dC1_drho = Map<const Matrix<float, 4, 4>>(dC1_drho[0]).cast<tinytype>();
|
||||
solver->cache->dC2_drho = Map<const Matrix<float, 12, 12>>(dC2_drho[0]).cast<tinytype>();
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,118 @@
|
||||
#pragma once
|
||||
|
||||
#include "admm.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int tiny_setup(
|
||||
TinySolver** solverp,
|
||||
tinyMatrix Adyn,
|
||||
tinyMatrix Bdyn,
|
||||
tinyMatrix fdyn,
|
||||
tinyMatrix Q,
|
||||
tinyMatrix R,
|
||||
tinytype rho,
|
||||
int nx,
|
||||
int nu,
|
||||
int N,
|
||||
int verbose
|
||||
);
|
||||
int tiny_set_bound_constraints(
|
||||
TinySolver* solver,
|
||||
tinyMatrix x_min,
|
||||
tinyMatrix x_max,
|
||||
tinyMatrix u_min,
|
||||
tinyMatrix u_max
|
||||
);
|
||||
int tiny_set_cone_constraints(
|
||||
TinySolver* solver,
|
||||
VectorXi Acu,
|
||||
VectorXi qcu,
|
||||
tinyVector cu,
|
||||
VectorXi Acx,
|
||||
VectorXi qcx,
|
||||
tinyVector cx
|
||||
);
|
||||
int tiny_set_linear_constraints(
|
||||
TinySolver* solver,
|
||||
tinyMatrix Alin_x,
|
||||
tinyVector blin_x,
|
||||
tinyMatrix Alin_u,
|
||||
tinyVector blin_u
|
||||
);
|
||||
int tiny_precompute_and_set_cache(
|
||||
TinyCache* cache,
|
||||
tinyMatrix Adyn,
|
||||
tinyMatrix Bdyn,
|
||||
tinyMatrix fdyn,
|
||||
tinyMatrix Q,
|
||||
tinyMatrix R,
|
||||
int nx,
|
||||
int nu,
|
||||
tinytype rho,
|
||||
int verbose
|
||||
);
|
||||
|
||||
void compute_sensitivity_matrices(
|
||||
TinyCache* cache,
|
||||
tinyMatrix Adyn,
|
||||
tinyMatrix Bdyn,
|
||||
tinyMatrix Q,
|
||||
tinyMatrix R,
|
||||
int nx,
|
||||
int nu,
|
||||
tinytype rho,
|
||||
int verbose
|
||||
);
|
||||
|
||||
int tiny_update_matrices_with_derivatives(TinyCache* cache, tinytype delta_rho);
|
||||
int tiny_solve(TinySolver* solver);
|
||||
|
||||
int tiny_update_settings(
|
||||
TinySettings* settings,
|
||||
tinytype abs_pri_tol,
|
||||
tinytype abs_dua_tol,
|
||||
int max_iter,
|
||||
int check_termination,
|
||||
int en_state_bound,
|
||||
int en_input_bound,
|
||||
int en_state_soc,
|
||||
int en_input_soc,
|
||||
int en_state_linear,
|
||||
int en_input_linear
|
||||
);
|
||||
int tiny_set_default_settings(TinySettings* settings);
|
||||
|
||||
int tiny_set_x0(TinySolver* solver, tinyVector x0);
|
||||
int tiny_set_x_ref(TinySolver* solver, tinyMatrix x_ref);
|
||||
int tiny_set_u_ref(TinySolver* solver, tinyMatrix u_ref);
|
||||
|
||||
/**
|
||||
* Initialize sensitivity matrices for adaptive rho
|
||||
*
|
||||
* @param solver Pointer to solver
|
||||
*/
|
||||
void tiny_initialize_sensitivity_matrices(TinySolver* solver);
|
||||
|
||||
int tiny_setup_state_soc_constraints(
|
||||
TinySolver* solver,
|
||||
tinyVector Acx,
|
||||
tinyVector qcx,
|
||||
tinyVector cx,
|
||||
int numStateCones
|
||||
);
|
||||
|
||||
int tiny_setup_input_soc_constraints(
|
||||
TinySolver* solver,
|
||||
tinyVector Acu,
|
||||
tinyVector qcu,
|
||||
tinyVector cu,
|
||||
int numInputCones
|
||||
);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
// Default settings
|
||||
#define TINY_DEFAULT_ABS_PRI_TOL (1e-03)
|
||||
#define TINY_DEFAULT_ABS_DUA_TOL (1e-03)
|
||||
#define TINY_DEFAULT_MAX_ITER (1000)
|
||||
#define TINY_DEFAULT_CHECK_TERMINATION (1)
|
||||
#define TINY_DEFAULT_EN_STATE_BOUND (1)
|
||||
#define TINY_DEFAULT_EN_INPUT_BOUND (1)
|
||||
#define TINY_DEFAULT_EN_STATE_SOC (0)
|
||||
#define TINY_DEFAULT_EN_INPUT_SOC (0)
|
||||
#define TINY_DEFAULT_EN_STATE_LINEAR (0)
|
||||
#define TINY_DEFAULT_EN_INPUT_LINEAR (0)
|
||||
197
wust_vision-main/tasks/auto_aim/armor_control/tinympc/types.hpp
Normal file
197
wust_vision-main/tasks/auto_aim/armor_control/tinympc/types.hpp
Normal file
@@ -0,0 +1,197 @@
|
||||
#pragma once
|
||||
|
||||
#include <Eigen/Eigen>
|
||||
// #include <Eigen/Core>
|
||||
// #include <Eigen/LU>
|
||||
|
||||
using namespace Eigen;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef double tinytype; // should be double if you want to generate code
|
||||
typedef Matrix<tinytype, Dynamic, Dynamic> tinyMatrix;
|
||||
typedef Matrix<tinytype, Dynamic, 1> tinyVector;
|
||||
|
||||
// typedef Matrix<tinytype, NSTATES, 1> tiny_VectorNx;
|
||||
// typedef Matrix<tinytype, NINPUTS, 1> tiny_VectorNu;
|
||||
// typedef Matrix<tinytype, NSTATES, NSTATES> tiny_MatrixNxNx;
|
||||
// typedef Matrix<tinytype, NSTATES, NINPUTS> tiny_MatrixNxNu;
|
||||
// typedef Matrix<tinytype, NINPUTS, NSTATES> tiny_MatrixNuNx;
|
||||
// typedef Matrix<tinytype, NINPUTS, NINPUTS> tiny_MatrixNuNu;
|
||||
|
||||
// typedef Matrix<tinytype, NSTATES, NHORIZON> tiny_MatrixNxNh; // Nu x Nh
|
||||
// typedef Matrix<tinytype, NINPUTS, NHORIZON - 1> tiny_MatrixNuNhm1; // Nu x Nh-1
|
||||
|
||||
/**
|
||||
* Solution
|
||||
*/
|
||||
typedef struct {
|
||||
int iter;
|
||||
int solved;
|
||||
tinyMatrix x; // nx x N
|
||||
tinyMatrix u; // nu x N-1
|
||||
} TinySolution;
|
||||
|
||||
/**
|
||||
* Matrices that must be recomputed with changes in time step, rho
|
||||
*/
|
||||
typedef struct {
|
||||
tinytype rho;
|
||||
tinyMatrix Kinf; // nu x nx
|
||||
tinyMatrix Pinf; // nx x nx
|
||||
tinyMatrix Quu_inv; // nu x nu
|
||||
tinyMatrix AmBKt; // nx x nx
|
||||
tinyVector APf; // nx x 1
|
||||
tinyVector BPf; // nu x 1
|
||||
tinyMatrix C1; // From adaptive rho
|
||||
tinyMatrix C2; // From adaptive rho
|
||||
|
||||
// Sensitivity matrices for adaptive rho
|
||||
tinyMatrix dKinf_drho;
|
||||
tinyMatrix dPinf_drho;
|
||||
tinyMatrix dC1_drho;
|
||||
tinyMatrix dC2_drho;
|
||||
} TinyCache;
|
||||
/**
|
||||
* User settings
|
||||
*/
|
||||
typedef struct {
|
||||
tinytype abs_pri_tol;
|
||||
tinytype abs_dua_tol;
|
||||
int max_iter;
|
||||
int check_termination;
|
||||
int en_state_bound;
|
||||
int en_input_bound;
|
||||
int en_state_soc;
|
||||
int en_input_soc;
|
||||
int en_state_linear;
|
||||
int en_input_linear;
|
||||
|
||||
// Add adaptive rho parameters
|
||||
int adaptive_rho; // Enable/disable adaptive rho (1/0)
|
||||
tinytype adaptive_rho_min; // Minimum value for rho
|
||||
tinytype adaptive_rho_max; // Maximum value for rho
|
||||
int adaptive_rho_enable_clipping; // Enable/disable clipping of rho (1/0)
|
||||
} TinySettings;
|
||||
|
||||
/**
|
||||
* Problem variables
|
||||
*/
|
||||
typedef struct {
|
||||
int nx; // Number of states
|
||||
int nu; // Number of control inputs
|
||||
int N; // Number of knotpoints in the horizon
|
||||
|
||||
// State and input
|
||||
tinyMatrix x; // nx x N
|
||||
tinyMatrix u; // nu x N-1
|
||||
|
||||
// Linear control cost terms
|
||||
tinyMatrix q; // nx x N
|
||||
tinyMatrix r; // nu x N-1
|
||||
|
||||
// Linear Riccati backward pass terms
|
||||
tinyMatrix p; // nx x N
|
||||
tinyMatrix d; // nu x N-1
|
||||
|
||||
// Bound constraint variables
|
||||
// Slack variables
|
||||
tinyMatrix v; // nx x N
|
||||
tinyMatrix vnew; // nx x N
|
||||
tinyMatrix z; // nu x N-1
|
||||
tinyMatrix znew; // nu x N-1
|
||||
|
||||
// Dual variables
|
||||
tinyMatrix g; // nx x N
|
||||
tinyMatrix y; // nu x N-1
|
||||
|
||||
// State and input bounds
|
||||
tinyMatrix x_min; // nx x N
|
||||
tinyMatrix x_max; // nx x N
|
||||
tinyMatrix u_min; // nu x N-1
|
||||
tinyMatrix u_max; // nu x N-1
|
||||
|
||||
// Cone constraint variables
|
||||
// Variables to keep track of general cone information
|
||||
int numStateCones; // Number of cone constraints on states at each time step
|
||||
int numInputCones; // Number of cone constraints on inputs at each time step
|
||||
tinyVector cx; // One coefficient for each state cone
|
||||
tinyVector cu; // One coefficient for each input cone
|
||||
VectorXi Acx; // Start indices for each state cone
|
||||
VectorXi Acu; // Start indices for each input cone
|
||||
VectorXi qcx; // Dimension for each state cone
|
||||
VectorXi qcu; // Dimension for each input cone
|
||||
|
||||
// Slack variables
|
||||
tinyMatrix vc; // nx x N
|
||||
tinyMatrix vcnew; // nx x N
|
||||
tinyMatrix zc; // nu x N-1
|
||||
tinyMatrix zcnew; // nu x N-1
|
||||
|
||||
// Dual variables
|
||||
tinyMatrix gc; // nx x N
|
||||
tinyMatrix yc; // nu x N-1
|
||||
|
||||
// Linear constraint variables
|
||||
// Variables to keep track of general linear constraint information
|
||||
int numStateLinear; // Number of linear constraints on states at each time step
|
||||
int numInputLinear; // Number of linear constraints on inputs at each time step
|
||||
|
||||
// Constraint matrices and vectors
|
||||
tinyMatrix Alin_x; // Normal vectors for state linear constraints (numStateLinear x nx)
|
||||
tinyVector blin_x; // Offset values for state linear constraints (numStateLinear x 1)
|
||||
tinyMatrix Alin_u; // Normal vectors for input linear constraints (numInputLinear x nu)
|
||||
tinyVector blin_u; // Offset values for input linear constraints (numInputLinear x 1)
|
||||
|
||||
// Slack variables for linear constraints
|
||||
tinyMatrix vl; // nx x N
|
||||
tinyMatrix vlnew; // nx x N
|
||||
tinyMatrix zl; // nu x N-1
|
||||
tinyMatrix zlnew; // nu x N-1
|
||||
|
||||
// Dual variables for linear constraints
|
||||
tinyMatrix gl; // nx x N
|
||||
tinyMatrix yl; // nu x N-1
|
||||
|
||||
// Q, R, A, B, f given by user
|
||||
tinyVector Q; // nx x 1
|
||||
tinyVector R; // nu x 1
|
||||
tinyMatrix Adyn; // nx x nx (state transition matrix)
|
||||
tinyMatrix Bdyn; // nx x nu (control matrix)
|
||||
tinyVector fdyn; // nx x 1 (affine vector)
|
||||
|
||||
// Reference trajectory to track for one horizon
|
||||
tinyMatrix Xref; // nx x N
|
||||
tinyMatrix Uref; // nu x N-1
|
||||
|
||||
// Temporaries
|
||||
tinyVector Qu; // nu x 1
|
||||
|
||||
// Variables for keeping track of solve status
|
||||
tinytype primal_residual_state;
|
||||
tinytype primal_residual_input;
|
||||
tinytype dual_residual_state;
|
||||
tinytype dual_residual_input;
|
||||
int status;
|
||||
int iter;
|
||||
} TinyWorkspace;
|
||||
|
||||
/**
|
||||
* Main TinyMPC solver structure that holds all information.
|
||||
*/
|
||||
typedef struct {
|
||||
TinySolution* solution; // Solution
|
||||
TinySettings* settings; // Problem settings
|
||||
TinyCache* cache; // Problem cache
|
||||
TinyWorkspace* work; // Solver workspace
|
||||
} TinySolver;
|
||||
|
||||
// Add at the top with other definitions
|
||||
#define BENCH_NX 12
|
||||
#define BENCH_NU 4
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
111
wust_vision-main/tasks/auto_aim/armor_control/traj.hpp
Normal file
111
wust_vision-main/tasks/auto_aim/armor_control/traj.hpp
Normal file
@@ -0,0 +1,111 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <concepts>
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
|
||||
template<typename T>
|
||||
concept HasStaticLerp = requires(const T& a, const T& b, double t) {
|
||||
{
|
||||
T::lerp(a, b, t)
|
||||
} -> std::same_as<T>;
|
||||
};
|
||||
|
||||
template<HasStaticLerp PointT>
|
||||
class Trajectory {
|
||||
public:
|
||||
void reserve(size_t n) {
|
||||
cp_vec.reserve(n);
|
||||
dt_vec.reserve(n > 0 ? n - 1 : 0);
|
||||
prefix_time.reserve(n);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
cp_vec.clear();
|
||||
dt_vec.clear();
|
||||
prefix_time.clear();
|
||||
total_duration_ = 0.0;
|
||||
}
|
||||
|
||||
void push_back(const PointT& p, double dt = 0.0) {
|
||||
if (cp_vec.empty()) {
|
||||
cp_vec.push_back(p);
|
||||
prefix_time.push_back(0.0);
|
||||
total_duration_ = 0.0;
|
||||
return;
|
||||
}
|
||||
|
||||
assert(dt >= 0.0);
|
||||
|
||||
cp_vec.push_back(p);
|
||||
dt_vec.push_back(dt);
|
||||
|
||||
total_duration_ += dt;
|
||||
prefix_time.push_back(total_duration_);
|
||||
}
|
||||
|
||||
void set(const std::vector<PointT>& c, const std::vector<double>& t) {
|
||||
assert(!c.empty());
|
||||
assert(c.size() == t.size() + 1);
|
||||
|
||||
cp_vec = c;
|
||||
dt_vec = t;
|
||||
|
||||
prefix_time.resize(cp_vec.size());
|
||||
prefix_time[0] = 0.0;
|
||||
|
||||
for (size_t i = 0; i < dt_vec.size(); ++i)
|
||||
prefix_time[i + 1] = prefix_time[i] + dt_vec[i];
|
||||
|
||||
total_duration_ = prefix_time.back();
|
||||
}
|
||||
double getPrefixTimeAtIdx(int i) const {
|
||||
return prefix_time[i];
|
||||
}
|
||||
PointT getStateAtIdx(int i) const {
|
||||
return cp_vec[i];
|
||||
}
|
||||
|
||||
PointT getStateAtTime(double t) const {
|
||||
if (cp_vec.empty())
|
||||
return PointT {};
|
||||
|
||||
if (t <= 0.0)
|
||||
return cp_vec.front();
|
||||
|
||||
if (t >= total_duration_)
|
||||
return cp_vec.back();
|
||||
|
||||
auto it = std::lower_bound(prefix_time.begin(), prefix_time.end(), t);
|
||||
size_t i1 = std::distance(prefix_time.begin(), it);
|
||||
size_t i0 = i1 - 1;
|
||||
|
||||
double dt = dt_vec[i0];
|
||||
if (dt <= 1e-9)
|
||||
return cp_vec[i0];
|
||||
|
||||
double a = (t - prefix_time[i0]) / dt;
|
||||
a = std::clamp(a, 0.0, 1.0);
|
||||
|
||||
return PointT::lerp(cp_vec[i0], cp_vec[i1], a);
|
||||
}
|
||||
|
||||
double getTotalDuration() const {
|
||||
return total_duration_;
|
||||
}
|
||||
size_t size() const {
|
||||
return cp_vec.size();
|
||||
}
|
||||
|
||||
std::vector<PointT> cp_vec;
|
||||
std::vector<double> dt_vec;
|
||||
std::vector<double> prefix_time;
|
||||
double total_duration_ { 0.0 };
|
||||
};
|
||||
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
1355
wust_vision-main/tasks/auto_aim/armor_control/very_aimer.cpp
Normal file
1355
wust_vision-main/tasks/auto_aim/armor_control/very_aimer.cpp
Normal file
File diff suppressed because it is too large
Load Diff
27
wust_vision-main/tasks/auto_aim/armor_control/very_aimer.hpp
Normal file
27
wust_vision-main/tasks/auto_aim/armor_control/very_aimer.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
#include <memory>
|
||||
namespace wust_vl::common::utils {
|
||||
class Parameter;
|
||||
}
|
||||
using wust_vlParamterPtr = std::shared_ptr<wust_vl::common::utils::Parameter>;
|
||||
namespace wust_vision {
|
||||
struct GimbalCmd;
|
||||
}
|
||||
namespace wust_vision::auto_aim {
|
||||
enum class AutoAimFsm;
|
||||
|
||||
class Target;
|
||||
class VeryAimer {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<VeryAimer>;
|
||||
VeryAimer(wust_vlParamterPtr auto_aim_config_parameter);
|
||||
static Ptr create(wust_vlParamterPtr auto_aim_config_parameter) {
|
||||
return std::make_shared<VeryAimer>(auto_aim_config_parameter);
|
||||
};
|
||||
~VeryAimer();
|
||||
[[nodiscard]] GimbalCmd
|
||||
veryAim(Target target, double bullet_speed, const AutoAimFsm& auto_aim_fsm);
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace wust_vision::auto_aim
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
#include "tasks/type_common.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
struct LightParams {
|
||||
// width / height
|
||||
double min_ratio;
|
||||
double max_ratio;
|
||||
// vertical angle
|
||||
double max_angle;
|
||||
// judge color
|
||||
int color_diff_thresh;
|
||||
double max_angle_diff;
|
||||
int binary_thres;
|
||||
void load(const YAML::Node& config) {
|
||||
binary_thres = config["binary_thres"].as<int>();
|
||||
min_ratio = config["min_ratio"].as<double>();
|
||||
max_ratio = config["max_ratio"].as<double>();
|
||||
max_angle = config["max_angle"].as<double>();
|
||||
max_angle_diff = config["max_angle_diff"].as<double>();
|
||||
color_diff_thresh = config["color_diff_thresh"].as<int>();
|
||||
}
|
||||
};
|
||||
struct ArmorParams {
|
||||
double min_light_ratio;
|
||||
// light pairs distance
|
||||
double min_small_center_distance;
|
||||
double max_small_center_distance;
|
||||
double min_large_center_distance;
|
||||
double max_large_center_distance;
|
||||
// horizontal angle
|
||||
double max_angle;
|
||||
void load(const YAML::Node& config) {
|
||||
min_light_ratio = config["min_light_ratio"].as<double>();
|
||||
min_small_center_distance = config["min_small_center_distance"].as<double>();
|
||||
max_small_center_distance = config["max_small_center_distance"].as<double>();
|
||||
min_large_center_distance = config["min_large_center_distance"].as<double>();
|
||||
max_large_center_distance = config["max_large_center_distance"].as<double>();
|
||||
max_angle = config["max_angle"].as<double>();
|
||||
}
|
||||
};
|
||||
class ArmorDetectorBase {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorDetectorBase>;
|
||||
virtual ~ArmorDetectorBase() = default;
|
||||
|
||||
virtual void
|
||||
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) = 0;
|
||||
|
||||
using DetectorCallback =
|
||||
std::function<void(const std::vector<ArmorObject>&, const CommonFrame&)>;
|
||||
|
||||
virtual void setCallback(DetectorCallback cb) = 0;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,473 @@
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/number_classifier/factory.hpp"
|
||||
#include "tasks/utils/utils.hpp"
|
||||
#include "wust_vl/common/utils/timer.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
struct ArmorDetectorCommon::Impl {
|
||||
public:
|
||||
Impl(const YAML::Node& config) {
|
||||
params_.load(config);
|
||||
number_classifier_ = NumberClassifierFactory::createNumberClassifier(
|
||||
params_.classify_backend,
|
||||
params_.classify_model_path,
|
||||
params_.classify_label_path
|
||||
);
|
||||
}
|
||||
|
||||
bool extractNetImage(const cv::Mat& src, ArmorObject& armor) const noexcept {
|
||||
constexpr int light_length = 12;
|
||||
constexpr int warp_height = 28;
|
||||
constexpr int small_armor_width = 32;
|
||||
constexpr int large_armor_width = 54;
|
||||
const cv::Size roi_size(20, 28);
|
||||
|
||||
if (src.empty() || src.cols < 10 || src.rows < 10) {
|
||||
std::cerr << "[extractNetImage] input src is empty or too small!" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto ordered = armor.sortCorners(armor.pts);
|
||||
|
||||
const cv::Point2f& p0 = ordered[0];
|
||||
const cv::Point2f& p1 = ordered[1];
|
||||
const cv::Point2f& p2 = ordered[2];
|
||||
const cv::Point2f& p3 = ordered[3];
|
||||
|
||||
const float l1_len = cv::norm(p1 - p0);
|
||||
const float l2_len = cv::norm(p2 - p3);
|
||||
const cv::Point2f c1 = (p0 + p1) * 0.5f;
|
||||
const cv::Point2f c2 = (p2 + p3) * 0.5f;
|
||||
|
||||
const float avg_light_len = 0.5f * (l1_len + l2_len);
|
||||
const float center_dist =
|
||||
avg_light_len > 1e-3f ? cv::norm(c1 - c2) / avg_light_len : 0.f;
|
||||
|
||||
const bool is_large = center_dist > params_.armor_params.min_large_center_distance;
|
||||
|
||||
const cv::Rect bbox = cv::boundingRect(armor.pts);
|
||||
if (bbox.width <= 0 || bbox.height <= 0)
|
||||
return false;
|
||||
|
||||
if (bbox.width > src.cols || bbox.height > src.rows)
|
||||
return false;
|
||||
const int dw = static_cast<int>(bbox.width * (params_.expand_ratio_w - 1.f));
|
||||
const int dh = static_cast<int>(bbox.height * (params_.expand_ratio_h - 1.f));
|
||||
|
||||
int new_x = bbox.x - (dw >> 1);
|
||||
int new_y = bbox.y - (dh >> 1);
|
||||
new_x = std::max(new_x, 0);
|
||||
new_y = std::max(new_y, 0);
|
||||
|
||||
int new_w = std::min(bbox.width + dw, src.cols - new_x);
|
||||
int new_h = std::min(bbox.height + dh, src.rows - new_y);
|
||||
|
||||
if (new_w <= 0 || new_h <= 0)
|
||||
return false;
|
||||
|
||||
const cv::Rect expanded_rect(new_x, new_y, new_w, new_h);
|
||||
|
||||
cv::Mat litroi_color = src(expanded_rect);
|
||||
if (litroi_color.empty())
|
||||
return false;
|
||||
|
||||
cv::Mat litroi_gray;
|
||||
try {
|
||||
cv::cvtColor(litroi_color, litroi_gray, cv::COLOR_BGR2GRAY);
|
||||
} catch (...) {
|
||||
return false;
|
||||
}
|
||||
|
||||
armor.whole_gray_img = litroi_gray;
|
||||
if (params_.enable_cv) {
|
||||
cv::Mat litroi_binary;
|
||||
try {
|
||||
cv::threshold(
|
||||
litroi_gray,
|
||||
litroi_binary,
|
||||
params_.light_params.binary_thres,
|
||||
255,
|
||||
cv::THRESH_BINARY
|
||||
);
|
||||
armor.whole_binary_img = litroi_binary;
|
||||
} catch (...) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const cv::Point2f offset(static_cast<float>(new_x), static_cast<float>(new_y));
|
||||
|
||||
if (params_.enable_classify) {
|
||||
cv::Point2f src_vertices[4] = { armor.pts[1] - offset,
|
||||
armor.pts[0] - offset,
|
||||
armor.pts[3] - offset,
|
||||
armor.pts[2] - offset };
|
||||
|
||||
const int warp_width = is_large ? large_armor_width : small_armor_width;
|
||||
const int top_light_y = (warp_height - light_length) / 2 - 1;
|
||||
const int bottom_light_y = top_light_y + light_length;
|
||||
if (warp_width <= 0 || warp_height <= 0)
|
||||
return false;
|
||||
cv::Point2f dst_vertices[4] = {
|
||||
{ 0.f, static_cast<float>(bottom_light_y) },
|
||||
{ 0.f, static_cast<float>(top_light_y) },
|
||||
{ static_cast<float>(warp_width - 1), static_cast<float>(top_light_y) },
|
||||
{ static_cast<float>(warp_width - 1), static_cast<float>(bottom_light_y) }
|
||||
};
|
||||
|
||||
const cv::Mat warp_mat = cv::getPerspectiveTransform(src_vertices, dst_vertices);
|
||||
|
||||
cv::Mat number_image;
|
||||
cv::warpPerspective(
|
||||
litroi_gray,
|
||||
number_image,
|
||||
warp_mat,
|
||||
cv::Size(warp_width, warp_height),
|
||||
cv::INTER_LINEAR,
|
||||
cv::BORDER_CONSTANT,
|
||||
0
|
||||
);
|
||||
|
||||
const int roi_x = (warp_width - roi_size.width) >> 1;
|
||||
const cv::Rect num_roi(roi_x, 0, roi_size.width, roi_size.height);
|
||||
|
||||
if ((num_roi & cv::Rect(0, 0, warp_width, warp_height)) != num_roi)
|
||||
return false;
|
||||
|
||||
cv::Mat num_crop = number_image(num_roi);
|
||||
|
||||
cv::threshold(
|
||||
num_crop,
|
||||
armor.number_img,
|
||||
0,
|
||||
255,
|
||||
cv::THRESH_BINARY | cv::THRESH_OTSU
|
||||
);
|
||||
}
|
||||
|
||||
armor.whole_rgb_img = litroi_color;
|
||||
armor.local_offset = offset;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool refineLightsFromArmorPts(ArmorObject& armor) const noexcept {
|
||||
armor.center = (armor.pts[0] + armor.pts[1] + armor.pts[2] + armor.pts[3]) * 0.25f;
|
||||
|
||||
const int n_lights = static_cast<int>(armor.lights.size());
|
||||
if (n_lights < 2)
|
||||
return false;
|
||||
|
||||
const auto ordered = armor.sortCorners(armor.pts);
|
||||
|
||||
const cv::Point2f ref_centers[2] = { (ordered[0] + ordered[1]) * 0.5f,
|
||||
(ordered[2] + ordered[3]) * 0.5f };
|
||||
|
||||
int best0 = -1, best1 = -1;
|
||||
float best0_d2 = std::numeric_limits<float>::max();
|
||||
float best1_d2 = std::numeric_limits<float>::max();
|
||||
|
||||
for (int i = 0; i < n_lights; ++i) {
|
||||
const cv::Point2f& c = armor.lights[i].center;
|
||||
|
||||
const cv::Point2f d0 = c - ref_centers[0];
|
||||
const float dist0 = d0.dot(d0);
|
||||
if (dist0 < best0_d2) {
|
||||
best0_d2 = dist0;
|
||||
best0 = i;
|
||||
}
|
||||
|
||||
const cv::Point2f d1 = c - ref_centers[1];
|
||||
const float dist1 = d1.dot(d1);
|
||||
if (dist1 < best1_d2) {
|
||||
best1_d2 = dist1;
|
||||
best1 = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (best0 == best1) {
|
||||
best1 = -1;
|
||||
best1_d2 = std::numeric_limits<float>::max();
|
||||
|
||||
for (int i = 0; i < n_lights; ++i) {
|
||||
if (i == best0)
|
||||
continue;
|
||||
|
||||
const cv::Point2f d = armor.lights[i].center - ref_centers[1];
|
||||
const float dist = d.dot(d);
|
||||
if (dist < best1_d2) {
|
||||
best1_d2 = dist;
|
||||
best1 = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (best0 < 0 || best1 < 0)
|
||||
return false;
|
||||
|
||||
const auto& l0 = armor.lights[best0];
|
||||
const auto& l1 = armor.lights[best1];
|
||||
|
||||
if (l0.center.x < l1.center.x) {
|
||||
armor.lights[0] = l0;
|
||||
armor.lights[1] = l1;
|
||||
} else {
|
||||
armor.lights[0] = l1;
|
||||
armor.lights[1] = l0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<Light>
|
||||
findLights(const cv::Mat& color_img, const cv::Mat& binary_img, ArmorObject& armor)
|
||||
const noexcept {
|
||||
std::vector<std::vector<cv::Point>> contours;
|
||||
contours.reserve(64);
|
||||
|
||||
cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
|
||||
|
||||
std::vector<Light> all_lights;
|
||||
all_lights.reserve(contours.size());
|
||||
|
||||
for (const auto& contour: contours) {
|
||||
const int n = static_cast<int>(contour.size());
|
||||
if (n < 6)
|
||||
continue;
|
||||
|
||||
Light light(contour);
|
||||
if (!isLight(light))
|
||||
continue;
|
||||
|
||||
int sum_r = 0;
|
||||
int sum_b = 0;
|
||||
|
||||
for (const auto& pt: contour) {
|
||||
const cv::Vec3b* row = color_img.ptr<cv::Vec3b>(pt.y);
|
||||
const cv::Vec3b& pix = row[pt.x];
|
||||
sum_r += pix[0];
|
||||
sum_b += pix[2];
|
||||
}
|
||||
|
||||
const int avg_diff = std::abs(sum_r - sum_b) / n;
|
||||
if (avg_diff <= params_.light_params.color_diff_thresh)
|
||||
continue;
|
||||
|
||||
light.color = (sum_r > sum_b) ? 0 : 1; // 0=红, 1=蓝
|
||||
all_lights.emplace_back(std::move(light));
|
||||
}
|
||||
|
||||
std::sort(all_lights.begin(), all_lights.end(), [](const Light& a, const Light& b) {
|
||||
return a.center.x < b.center.x;
|
||||
});
|
||||
|
||||
armor.lights = all_lights;
|
||||
return all_lights;
|
||||
}
|
||||
|
||||
bool isLight(const Light& light) const noexcept {
|
||||
// width / length 比例
|
||||
const float ratio = light.width / light.length;
|
||||
|
||||
if (ratio <= params_.light_params.min_ratio || ratio >= params_.light_params.max_ratio)
|
||||
return false;
|
||||
|
||||
if (light.tilt_angle >= params_.light_params.max_angle)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
bool isArmor(const Light& l1, const Light& l2) const noexcept {
|
||||
const float len1 = l1.length;
|
||||
const float len2 = l2.length;
|
||||
if (len1 <= 1e-3f || len2 <= 1e-3f)
|
||||
return false;
|
||||
|
||||
const float min_len = (len1 < len2) ? len1 : len2;
|
||||
const float max_len = (len1 < len2) ? len2 : len1;
|
||||
if (min_len / max_len <= params_.armor_params.min_light_ratio)
|
||||
return false;
|
||||
|
||||
const cv::Point2f d = l1.center - l2.center;
|
||||
const float dist2 = d.dot(d);
|
||||
|
||||
const float avg_len = 0.5f * (len1 + len2);
|
||||
|
||||
const float min_small = params_.armor_params.min_small_center_distance * avg_len;
|
||||
const float max_small = params_.armor_params.max_small_center_distance * avg_len;
|
||||
const float min_large = params_.armor_params.min_large_center_distance * avg_len;
|
||||
const float max_large = params_.armor_params.max_large_center_distance * avg_len;
|
||||
|
||||
const float min_small2 = min_small * min_small;
|
||||
const float max_small2 = max_small * max_small;
|
||||
const float min_large2 = min_large * min_large;
|
||||
const float max_large2 = max_large * max_large;
|
||||
|
||||
const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2);
|
||||
const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2);
|
||||
|
||||
if (!(small_ok || large_ok))
|
||||
return false;
|
||||
|
||||
static const float tan_max_angle =
|
||||
std::tan(params_.armor_params.max_angle * CV_PI / 180.0f);
|
||||
|
||||
if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle)
|
||||
return false;
|
||||
|
||||
if (l1.color != l2.color)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<ArmorObject> detectNet(
|
||||
const cv::Mat& src_img,
|
||||
std::vector<ArmorObject>& objs_result,
|
||||
Eigen::Matrix3f transform_matrix,
|
||||
int detect_color,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) const noexcept {
|
||||
std::vector<ArmorObject> armors;
|
||||
|
||||
if (!src_img.data || src_img.empty()) {
|
||||
std::cout << "img data nullptr or empty" << std::endl;
|
||||
return armors;
|
||||
}
|
||||
|
||||
if (objs_result.empty()) {
|
||||
return armors;
|
||||
}
|
||||
|
||||
for (auto& armor_in: objs_result) {
|
||||
ArmorObject armor = armor_in;
|
||||
|
||||
if ((detect_color == 0 && armor.color == ArmorColor::BLUE)
|
||||
|| (detect_color == 1 && armor.color == ArmorColor::RED))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if (params_.enable_classify || params_.enable_cv) {
|
||||
bool ok = false;
|
||||
|
||||
ok = extractNetImage(src_img, armor);
|
||||
|
||||
if (!ok)
|
||||
continue;
|
||||
}
|
||||
|
||||
if (params_.enable_classify) {
|
||||
number_classifier_->classifyNumber(armor);
|
||||
if (armor.confidence < params_.classifier_threshold)
|
||||
continue;
|
||||
}
|
||||
|
||||
if (target_number.has_value()) {
|
||||
if (!isSameTarget(target_number.value(), armor.number)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (armor.color == ArmorColor::NONE || armor.color == ArmorColor::PURPLE) {
|
||||
armor.is_ok = false;
|
||||
armor.transform(transform_matrix);
|
||||
armors.emplace_back(armor);
|
||||
continue;
|
||||
}
|
||||
if (params_.enable_cv) {
|
||||
findLights(armor.whole_rgb_img, armor.whole_binary_img, armor);
|
||||
|
||||
if (refineLightsFromArmorPts(armor)) {
|
||||
if (isArmor(armor.lights[0], armor.lights[1])) {
|
||||
armor.is_ok = true;
|
||||
for (auto& light: armor.lights) {
|
||||
light.addOffset(armor.local_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (armor.is_ok) {
|
||||
armor.is_ok = armor.checkOkptsRight(params_.max_pts_error);
|
||||
}
|
||||
}
|
||||
|
||||
if (!armor.is_ok) {
|
||||
auto ordered = armor.sortCorners(armor.pts);
|
||||
Light l1, l2;
|
||||
l1.length = cv::norm(ordered[1] - ordered[0]);
|
||||
l1.center = (ordered[0] + ordered[1]) / 2.0;
|
||||
l2.length = cv::norm(ordered[2] - ordered[3]);
|
||||
l2.center = (ordered[2] + ordered[3]) / 2.0;
|
||||
if (!isArmor(l1, l2)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
armor.transform(transform_matrix);
|
||||
|
||||
armors.emplace_back(armor);
|
||||
}
|
||||
|
||||
return armors;
|
||||
}
|
||||
std::unique_ptr<NumberClassifierBase> number_classifier_;
|
||||
struct ArmorDetectCommonParams {
|
||||
std::string classify_backend = "opencv";
|
||||
std::string classify_model_path;
|
||||
std::string classify_label_path;
|
||||
double classifier_threshold = 0.5;
|
||||
LightParams light_params;
|
||||
ArmorParams armor_params;
|
||||
float expand_ratio_w = 1.1f;
|
||||
float expand_ratio_h = 1.1f;
|
||||
double max_pts_error = 20.0;
|
||||
bool enable_cv = false;
|
||||
bool enable_classify = true;
|
||||
void load(const YAML::Node& config) {
|
||||
expand_ratio_w = config["cv"]["light"]["expand_ratio_w"].as<float>(1.1);
|
||||
expand_ratio_h = config["cv"]["light"]["expand_ratio_h"].as<float>(1.1);
|
||||
max_pts_error = config["cv"]["light"]["max_pts_error"].as<double>(20.0);
|
||||
enable_cv = config["cv"]["enable"].as<bool>();
|
||||
light_params.load(config["cv"]["light"]);
|
||||
armor_params.load(config["cv"]["armor"]);
|
||||
enable_classify = config["classify"]["enable"].as<bool>();
|
||||
classify_model_path =
|
||||
utils::expandEnv(config["classify"]["model_path"].as<std::string>());
|
||||
classify_label_path =
|
||||
utils::expandEnv(config["classify"]["label_path"].as<std::string>());
|
||||
classify_backend = config["classify"]["backend"].as<std::string>();
|
||||
classifier_threshold = config["classify"]["threshold"].as<double>();
|
||||
}
|
||||
} params_;
|
||||
};
|
||||
ArmorDetectorCommon::ArmorDetectorCommon(const YAML::Node& config) {
|
||||
_impl = std::make_unique<Impl>(config);
|
||||
}
|
||||
ArmorDetectorCommon::~ArmorDetectorCommon() {
|
||||
_impl.reset();
|
||||
}
|
||||
std::vector<ArmorObject> ArmorDetectorCommon::detectNet(
|
||||
const cv::Mat& src_img,
|
||||
std::vector<ArmorObject>& objs_result,
|
||||
Eigen::Matrix3f transform_matrix,
|
||||
int detect_color,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) {
|
||||
return _impl
|
||||
->detectNet(src_img, objs_result, transform_matrix, detect_color, target_number);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorDetectorCommon {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorDetectorCommon>;
|
||||
ArmorDetectorCommon(const YAML::Node& config);
|
||||
static Ptr create(const YAML::Node& config) {
|
||||
return std::make_unique<ArmorDetectorCommon>(config);
|
||||
}
|
||||
~ArmorDetectorCommon();
|
||||
|
||||
std::vector<ArmorObject> detectNet(
|
||||
const cv::Mat& src_img,
|
||||
std::vector<ArmorObject>& objs_result,
|
||||
Eigen::Matrix3f transform_matrix,
|
||||
int detect_color,
|
||||
const std::optional<ArmorNumber>& target_number = std::nullopt
|
||||
);
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,134 @@
|
||||
// Copyright 2025 XiaoJian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "armor_detector_base.hpp"
|
||||
#include "tasks/utils/config.hpp"
|
||||
#include <string>
|
||||
#include <yaml-cpp/yaml.h>
|
||||
|
||||
#ifdef USE_OPENVINO
|
||||
#include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp"
|
||||
#endif
|
||||
#ifdef USE_TRT
|
||||
#include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp"
|
||||
#endif
|
||||
#ifdef USE_NCNN
|
||||
#include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp"
|
||||
#endif
|
||||
#ifdef USE_ORT
|
||||
#include "tasks/auto_aim/armor_detect/onnxruntime/armor_detector_onnxruntime.hpp"
|
||||
#endif
|
||||
#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp"
|
||||
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
|
||||
class DetectorFactory {
|
||||
public:
|
||||
static ArmorDetectorBase::Ptr createArmorDetector(
|
||||
const std::string& backend,
|
||||
bool use_armor_detect_common,
|
||||
std::string cv_config_path = OPENCV_CONFIG,
|
||||
std::string ml_config_path = ML_CONFIG
|
||||
) {
|
||||
// 检查编译时是否支持
|
||||
auto isBackendEnabled = [&backend]() -> bool {
|
||||
#ifdef USE_OPENVINO
|
||||
if (backend == "openvino")
|
||||
return true;
|
||||
#endif
|
||||
#ifdef USE_TRT
|
||||
if (backend == "tensorrt")
|
||||
return true;
|
||||
#endif
|
||||
#ifdef USE_NCNN
|
||||
if (backend == "ncnn")
|
||||
return true;
|
||||
#endif
|
||||
#ifdef USE_ORT
|
||||
if (backend == "onnxruntime")
|
||||
return true;
|
||||
#endif
|
||||
if (backend == "opencv")
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
if (!isBackendEnabled()) {
|
||||
std::cout << "Backend " << backend << " is not enabled at compile time."
|
||||
<< std::endl;
|
||||
throw std::runtime_error("Backend " + backend + " is not enabled at compile time.");
|
||||
}
|
||||
|
||||
auto getConfigPath = [&](const std::string& backend) -> std::string {
|
||||
if (backend == "opencv")
|
||||
return cv_config_path;
|
||||
else
|
||||
return ml_config_path;
|
||||
};
|
||||
|
||||
std::string config_path = getConfigPath(backend);
|
||||
if (config_path.empty()) {
|
||||
std::cout << "No config path for backend: " << backend << std::endl;
|
||||
throw std::runtime_error("No config path for backend: " + backend);
|
||||
}
|
||||
|
||||
YAML::Node armor_detect_config = YAML::LoadFile(config_path);
|
||||
|
||||
// 创建对应后端实例
|
||||
#if defined(USE_OPENVINO)
|
||||
if (backend == "openvino") {
|
||||
return ArmorDetectorOpenVino::create(
|
||||
armor_detect_config["armor_detector"],
|
||||
use_armor_detect_common
|
||||
);
|
||||
}
|
||||
#endif
|
||||
#if defined(USE_TRT)
|
||||
if (backend == "tensorrt") {
|
||||
return ArmorDetectorTrt::create(
|
||||
armor_detect_config["armor_detector"],
|
||||
use_armor_detect_common
|
||||
);
|
||||
}
|
||||
#endif
|
||||
#if defined(USE_NCNN)
|
||||
if (backend == "ncnn") {
|
||||
return ArmorDetectorNCNN::create(
|
||||
armor_detect_config["armor_detector"],
|
||||
use_armor_detect_common
|
||||
);
|
||||
}
|
||||
#endif
|
||||
#if defined(USE_ORT)
|
||||
if (backend == "onnxruntime") {
|
||||
return ArmorDetectorOnnxRuntime::create(
|
||||
armor_detect_config["armor_detector"],
|
||||
use_armor_detect_common
|
||||
);
|
||||
}
|
||||
#endif
|
||||
if (backend == "opencv") {
|
||||
return ArmorDetectorOpenCV::create(armor_detect_config["armor_detector"]);
|
||||
}
|
||||
std::cout << "Unsupported armor detector backend (or not compiled): " << backend
|
||||
<< std::endl;
|
||||
throw std::runtime_error(
|
||||
"Unsupported armor detector backend (or not compiled): " + backend
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
198
wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.cpp
Normal file
198
wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.cpp
Normal file
@@ -0,0 +1,198 @@
|
||||
#include "armor_infer.hpp"
|
||||
namespace wust_vision::auto_aim::armor_infer {
|
||||
struct GridAndStride {
|
||||
int grid0;
|
||||
int grid1;
|
||||
int stride;
|
||||
};
|
||||
[[nodiscard]] static inline std::vector<GridAndStride>
|
||||
generate_grids_and_stride(int target_w, int target_h, const std::vector<int>& strides) noexcept {
|
||||
std::vector<GridAndStride> grid_strides;
|
||||
for (int stride: strides) {
|
||||
const int num_w = target_w / stride;
|
||||
const int num_h = target_h / stride;
|
||||
grid_strides.reserve(grid_strides.size() + num_w * num_h);
|
||||
for (int gy = 0; gy < num_h; ++gy) {
|
||||
for (int gx = 0; gx < num_w; ++gx) {
|
||||
grid_strides.push_back(GridAndStride { gx, gy, stride });
|
||||
}
|
||||
}
|
||||
}
|
||||
return grid_strides;
|
||||
}
|
||||
std::vector<ArmorObject> ArmorInfer::postProcessTUP_impl(const cv::Mat& out) const {
|
||||
static std::optional<std::vector<GridAndStride>> _grid_strides;
|
||||
if (!_grid_strides) {
|
||||
_grid_strides = generate_grids_and_stride(inputW(), inputH(), { 8, 16, 32 });
|
||||
}
|
||||
const auto& grid_strides = _grid_strides.value();
|
||||
std::vector<ArmorObject> out_objs;
|
||||
const int num_anchors =
|
||||
static_cast<int>(std::min<size_t>(grid_strides.size(), static_cast<size_t>(out.rows)));
|
||||
for (int a = 0; a < num_anchors; ++a) {
|
||||
const float confidence = out.at<float>(a, 8);
|
||||
if (confidence < conf_threshold_)
|
||||
continue;
|
||||
|
||||
const auto& gs = grid_strides[a];
|
||||
const int gx = gs.grid0, gy = gs.grid1, stride = gs.stride;
|
||||
|
||||
// color & class
|
||||
const int color_offset = 9;
|
||||
const int num_colors = ModelTraits<Mode::TUP>::NUM_COLORS;
|
||||
const int num_classes = ModelTraits<Mode::TUP>::NUM_CLASSES;
|
||||
|
||||
cv::Mat color_scores = out.row(a).colRange(color_offset, color_offset + num_colors);
|
||||
cv::Mat class_scores =
|
||||
out.row(a).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
|
||||
|
||||
double max_color, max_class;
|
||||
cv::Point color_id, class_id;
|
||||
cv::minMaxLoc(color_scores, nullptr, &max_color, nullptr, &color_id);
|
||||
cv::minMaxLoc(class_scores, nullptr, &max_class, nullptr, &class_id);
|
||||
|
||||
const float x1 = (out.at<float>(a, 0) + gx) * stride;
|
||||
const float y1 = (out.at<float>(a, 1) + gy) * stride;
|
||||
const float x2 = (out.at<float>(a, 2) + gx) * stride;
|
||||
const float y2 = (out.at<float>(a, 3) + gy) * stride;
|
||||
const float x3 = (out.at<float>(a, 4) + gx) * stride;
|
||||
const float y3 = (out.at<float>(a, 5) + gy) * stride;
|
||||
const float x4 = (out.at<float>(a, 6) + gx) * stride;
|
||||
const float y4 = (out.at<float>(a, 7) + gy) * stride;
|
||||
|
||||
ArmorObject obj;
|
||||
obj.pts = { cv::Point2f(x1, y1),
|
||||
cv::Point2f(x2, y2),
|
||||
cv::Point2f(x3, y3),
|
||||
cv::Point2f(x4, y4) };
|
||||
obj.box = cv::boundingRect(obj.pts);
|
||||
obj.color = static_cast<ArmorColor>(color_id.x);
|
||||
obj.number = static_cast<ArmorNumber>(class_id.x);
|
||||
obj.confidence = confidence;
|
||||
out_objs.push_back(std::move(obj));
|
||||
}
|
||||
return topKAndNms(out_objs, top_k_, nms_threshold_);
|
||||
}
|
||||
|
||||
std::vector<ArmorObject> ArmorInfer::postProcessRP_impl(const cv::Mat& out) const {
|
||||
std::vector<ArmorObject> out_objs;
|
||||
const int rows = out.rows;
|
||||
const int color_offset = 9;
|
||||
const int num_colors = ModelTraits<Mode::RP>::NUM_COLORS;
|
||||
const int num_classes = ModelTraits<Mode::RP>::NUM_CLASSES;
|
||||
|
||||
for (int r = 0; r < rows; ++r) {
|
||||
float conf_raw = out.at<float>(r, 8);
|
||||
const float confidence = static_cast<float>(sigmoid(conf_raw));
|
||||
if (confidence < conf_threshold_)
|
||||
continue;
|
||||
|
||||
cv::Mat color_scores = out.row(r).colRange(color_offset, color_offset + num_colors);
|
||||
cv::Mat class_scores =
|
||||
out.row(r).colRange(color_offset + num_colors, color_offset + num_colors + num_classes);
|
||||
|
||||
double max_color_score, max_class_score;
|
||||
cv::Point color_id, class_id;
|
||||
cv::minMaxLoc(color_scores, nullptr, &max_color_score, nullptr, &color_id);
|
||||
cv::minMaxLoc(class_scores, nullptr, &max_class_score, nullptr, &class_id);
|
||||
|
||||
ArmorObject obj;
|
||||
obj.pts.resize(4);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const float x = out.at<float>(r, 0 + k * 2);
|
||||
const float y = out.at<float>(r, 1 + k * 2);
|
||||
obj.pts[k] = cv::Point2f(x, y);
|
||||
}
|
||||
obj.box = cv::boundingRect(obj.pts);
|
||||
|
||||
obj.color = static_cast<ArmorColor>(color_id.x);
|
||||
|
||||
obj.number = static_cast<ArmorNumber>(class_id.x);
|
||||
obj.confidence = confidence;
|
||||
out_objs.push_back(std::move(obj));
|
||||
}
|
||||
|
||||
return topKAndNms(out_objs, top_k_, nms_threshold_);
|
||||
}
|
||||
|
||||
std::vector<ArmorObject> ArmorInfer::postProcessAT_impl(const cv::Mat& out) const {
|
||||
std::vector<ArmorObject> out_objs;
|
||||
|
||||
constexpr int nkpt = ModelTraits<Mode::AT>::NUM_KPTS;
|
||||
constexpr int nk = nkpt * 2; // keypoints flattened
|
||||
auto max_det = out.rows;
|
||||
auto det_dim = out.cols;
|
||||
auto output_ptr = out.ptr<float>();
|
||||
for (int i = 0; i < max_det; ++i) {
|
||||
const float* row = output_ptr + i * det_dim;
|
||||
float conf = row[4];
|
||||
if (!std::isfinite(conf) || conf < conf_threshold_)
|
||||
continue;
|
||||
float x1 = row[0];
|
||||
float y1 = row[1];
|
||||
float x2 = row[2];
|
||||
float y2 = row[3];
|
||||
int cls = static_cast<int>(row[5]);
|
||||
|
||||
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|
||||
|| x2 <= x1 || y2 <= y1)
|
||||
continue;
|
||||
|
||||
ArmorObject obj;
|
||||
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
|
||||
obj.confidence = conf;
|
||||
auto color_num = ModelTraits<Mode::AT>::CLASSES[cls];
|
||||
obj.color = color_num.first;
|
||||
obj.number = color_num.second;
|
||||
|
||||
obj.pts.reserve(nkpt);
|
||||
for (int k = 0; k < nkpt; ++k) {
|
||||
float kx = row[6 + 2 * k];
|
||||
float ky = row[6 + 2 * k + 1];
|
||||
obj.pts.emplace_back(kx, ky);
|
||||
}
|
||||
|
||||
out_objs.emplace_back(std::move(obj));
|
||||
}
|
||||
|
||||
return out_objs;
|
||||
}
|
||||
std::vector<ArmorObject> ArmorInfer::postProcessBOX_impl(const cv::Mat& out) const {
|
||||
std::vector<ArmorObject> out_objs;
|
||||
auto max_det = out.rows;
|
||||
auto det_dim = out.cols;
|
||||
auto output_ptr = out.ptr<float>();
|
||||
for (int i = 0; i < max_det; ++i) {
|
||||
const float* row = output_ptr + i * det_dim;
|
||||
float conf = row[4];
|
||||
if (!std::isfinite(conf) || conf < conf_threshold_)
|
||||
continue;
|
||||
float x1 = row[0];
|
||||
float y1 = row[1];
|
||||
float x2 = row[2];
|
||||
float y2 = row[3];
|
||||
int cls = static_cast<int>(row[5]);
|
||||
|
||||
if (!std::isfinite(x1) || !std::isfinite(y1) || !std::isfinite(x2) || !std::isfinite(y2)
|
||||
|| x2 <= x1 || y2 <= y1)
|
||||
continue;
|
||||
|
||||
ArmorObject obj;
|
||||
obj.box = cv::Rect2f(x1, y1, x2 - x1, y2 - y1);
|
||||
obj.confidence = conf;
|
||||
auto color_num = ModelTraits<Mode::BOX>::CLASSES[cls];
|
||||
obj.color = color_num.first;
|
||||
obj.number = color_num.second;
|
||||
std::vector<cv::Point2f> pts;
|
||||
pts.resize(4);
|
||||
pts[0] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y + obj.box.height); // 右下
|
||||
pts[1] = cv::Point2f(obj.box.x + obj.box.width, obj.box.y); // 右上
|
||||
pts[2] = cv::Point2f(obj.box.x, obj.box.y); // 左上
|
||||
pts[3] = cv::Point2f(obj.box.x, obj.box.y + obj.box.height); // 左下
|
||||
obj.pts = std::move(pts);
|
||||
|
||||
out_objs.emplace_back(std::move(obj));
|
||||
}
|
||||
return out_objs;
|
||||
}
|
||||
} // namespace wust_vision::auto_aim::armor_infer
|
||||
324
wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.hpp
Normal file
324
wust_vision-main/tasks/auto_aim/armor_detect/armor_infer.hpp
Normal file
@@ -0,0 +1,324 @@
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
|
||||
namespace wust_vision::auto_aim::armor_infer {
|
||||
|
||||
static constexpr float MERGE_CONF_ERROR = 0.15f;
|
||||
static constexpr float MERGE_MIN_IOU = 0.9f;
|
||||
|
||||
enum class Mode { TUP, RP, AT, BOX416, BOX320, BOX };
|
||||
|
||||
inline Mode modeFromString(const std::string& m) {
|
||||
if (m == "tup" || m == "TUP")
|
||||
return Mode::TUP;
|
||||
if (m == "rp" || m == "RP")
|
||||
return Mode::RP;
|
||||
if (m == "at" || m == "AT")
|
||||
return Mode::AT;
|
||||
if (m == "box416" || m == "BOX416")
|
||||
return Mode::BOX416;
|
||||
if (m == "box320" || m == "BOX320")
|
||||
return Mode::BOX320;
|
||||
return Mode::TUP;
|
||||
}
|
||||
|
||||
// ------------------------- model traits -------------------------
|
||||
template<Mode M>
|
||||
struct ModelTraits; // declare
|
||||
// TUP
|
||||
template<>
|
||||
struct ModelTraits<Mode::TUP> {
|
||||
static constexpr int INPUT_W = 416;
|
||||
static constexpr int INPUT_H = 416;
|
||||
static constexpr int NUM_CLASSES = 8;
|
||||
static constexpr int NUM_COLORS = 4;
|
||||
static constexpr bool USE_NORM = false;
|
||||
static constexpr bool INPUT_RGB = true;
|
||||
};
|
||||
|
||||
// RP
|
||||
template<>
|
||||
struct ModelTraits<Mode::RP> {
|
||||
static constexpr int INPUT_W = 640;
|
||||
static constexpr int INPUT_H = 640;
|
||||
static constexpr int NUM_CLASSES = 9;
|
||||
static constexpr int NUM_COLORS = 4;
|
||||
static constexpr bool USE_NORM = true;
|
||||
static constexpr bool INPUT_RGB = false;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ModelTraits<Mode::AT> {
|
||||
static constexpr int INPUT_W = 640;
|
||||
static constexpr int INPUT_H = 640;
|
||||
static constexpr int NUM_KPTS = 4;
|
||||
static constexpr bool USE_NORM = true;
|
||||
static constexpr bool INPUT_RGB = false;
|
||||
static constexpr std::array<std::pair<ArmorColor, ArmorNumber>, 64> CLASSES = { {
|
||||
{ ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE },
|
||||
{ ArmorColor::BLUE, ArmorNumber::SENTRY }, { ArmorColor::BLUE, ArmorNumber::NO1 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO2 }, { ArmorColor::BLUE, ArmorNumber::NO3 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO4 }, { ArmorColor::BLUE, ArmorNumber::NO5 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::OUTPOST }, { ArmorColor::BLUE, ArmorNumber::BASE },
|
||||
{ ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 },
|
||||
{ ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE },
|
||||
{ ArmorColor::RED, ArmorNumber::SENTRY }, { ArmorColor::RED, ArmorNumber::NO1 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO2 }, { ArmorColor::RED, ArmorNumber::NO3 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO4 }, { ArmorColor::RED, ArmorNumber::NO5 },
|
||||
{ ArmorColor::RED, ArmorNumber::OUTPOST }, { ArmorColor::RED, ArmorNumber::BASE },
|
||||
{ ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 },
|
||||
{ ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 },
|
||||
{ ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 },
|
||||
{ ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE },
|
||||
{ ArmorColor::NONE, ArmorNumber::SENTRY }, { ArmorColor::NONE, ArmorNumber::NO1 },
|
||||
{ ArmorColor::NONE, ArmorNumber::NO2 }, { ArmorColor::NONE, ArmorNumber::NO3 },
|
||||
{ ArmorColor::NONE, ArmorNumber::NO4 }, { ArmorColor::NONE, ArmorNumber::NO5 },
|
||||
{ ArmorColor::NONE, ArmorNumber::OUTPOST }, { ArmorColor::NONE, ArmorNumber::BASE },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::SENTRY }, { ArmorColor::PURPLE, ArmorNumber::NO1 },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::NO2 }, { ArmorColor::PURPLE, ArmorNumber::NO3 },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::NO4 }, { ArmorColor::PURPLE, ArmorNumber::NO5 },
|
||||
{ ArmorColor::PURPLE, ArmorNumber::OUTPOST }, { ArmorColor::PURPLE, ArmorNumber::BASE },
|
||||
} };
|
||||
};
|
||||
template<>
|
||||
struct ModelTraits<Mode::BOX416> {
|
||||
static constexpr int INPUT_W = 416;
|
||||
static constexpr int INPUT_H = 416;
|
||||
};
|
||||
template<>
|
||||
struct ModelTraits<Mode::BOX320> {
|
||||
static constexpr int INPUT_W = 320;
|
||||
static constexpr int INPUT_H = 320;
|
||||
};
|
||||
template<>
|
||||
struct ModelTraits<Mode::BOX> {
|
||||
static constexpr int INPUT_W = 416;
|
||||
static constexpr int INPUT_H = 416;
|
||||
static constexpr bool INPUT_RGB = false;
|
||||
static constexpr bool USE_NORM = true;
|
||||
static constexpr std::array<std::pair<ArmorColor, ArmorNumber>, 12> CLASSES = { {
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO1 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO2 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO3 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO4 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::NO5 },
|
||||
{ ArmorColor::BLUE, ArmorNumber::SENTRY },
|
||||
{ ArmorColor::RED, ArmorNumber::NO1 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO2 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO3 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO4 },
|
||||
{ ArmorColor::RED, ArmorNumber::NO5 },
|
||||
{ ArmorColor::RED, ArmorNumber::SENTRY },
|
||||
|
||||
} };
|
||||
};
|
||||
|
||||
[[nodiscard]] inline double sigmoid(double x) noexcept {
|
||||
return x >= 0 ? 1.0 / (1.0 + std::exp(-x)) : std::exp(x) / (1.0 + std::exp(x));
|
||||
}
|
||||
|
||||
[[nodiscard]] inline float rectIoU(const cv::Rect2f& a, const cv::Rect2f& b) noexcept {
|
||||
const cv::Rect2f inter = a & b;
|
||||
const float inter_area = inter.area();
|
||||
const float union_area = a.area() + b.area() - inter_area;
|
||||
if (union_area <= 0.f || std::isnan(union_area))
|
||||
return 0.f;
|
||||
return inter_area / union_area;
|
||||
}
|
||||
|
||||
// Merge / NMS helpers that mimic original intent but clearer
|
||||
inline void nms_merge_sorted_bboxes(
|
||||
std::vector<ArmorObject>& objs,
|
||||
std::vector<int>& out_indices,
|
||||
float nms_threshold
|
||||
) {
|
||||
out_indices.clear();
|
||||
const size_t n = objs.size();
|
||||
std::vector<float> areas(n);
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
areas[i] = objs[i].box.area();
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
ArmorObject& a = objs[i];
|
||||
bool keep = true;
|
||||
for (int idx: out_indices) {
|
||||
ArmorObject& b = objs[idx];
|
||||
const float iou = rectIoU(a.box, b.box);
|
||||
if (std::isnan(iou) || iou > nms_threshold) {
|
||||
keep = false;
|
||||
if (a.number == b.number && a.color == b.color && iou > MERGE_MIN_IOU
|
||||
&& std::abs(a.confidence - b.confidence) < MERGE_CONF_ERROR)
|
||||
{
|
||||
// accumulate points for later averaging
|
||||
for (const auto& pt: a.pts)
|
||||
b.pts.push_back(pt);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (keep)
|
||||
out_indices.push_back(static_cast<int>(i));
|
||||
}
|
||||
}
|
||||
|
||||
inline std::vector<ArmorObject>
|
||||
topKAndNms(std::vector<ArmorObject>& objs, int top_k, float nms_threshold) {
|
||||
std::sort(objs.begin(), objs.end(), [](const ArmorObject& a, const ArmorObject& b) {
|
||||
return a.confidence > b.confidence;
|
||||
});
|
||||
if (static_cast<int>(objs.size()) > top_k)
|
||||
objs.resize(static_cast<size_t>(top_k));
|
||||
|
||||
std::vector<int> indices;
|
||||
nms_merge_sorted_bboxes(objs, indices, nms_threshold);
|
||||
|
||||
std::vector<ArmorObject> result;
|
||||
result.reserve(indices.size());
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
result.push_back(std::move(objs[indices[i]]));
|
||||
// average merged extra points if any
|
||||
auto& ro = result.back();
|
||||
if (ro.pts.size() >= 8) {
|
||||
const size_t npts = ro.pts.size();
|
||||
std::array<cv::Point2f, 4> accum { { { 0, 0 }, { 0, 0 }, { 0, 0 }, { 0, 0 } } };
|
||||
for (size_t j = 0; j < npts; ++j)
|
||||
accum[j % 4] += ro.pts[j];
|
||||
ro.pts.resize(4);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
float denom = static_cast<float>(npts) / 4.0f;
|
||||
ro.pts[k].x = accum[k].x / denom;
|
||||
ro.pts[k].y = accum[k].y / denom;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// However providing full modern unified class below that delegates using templated helpers.
|
||||
|
||||
class ArmorInfer {
|
||||
public:
|
||||
ArmorInfer(
|
||||
Mode mode = Mode::TUP,
|
||||
float conf_threshold = 0.25f,
|
||||
float nms_threshold = 0.45f,
|
||||
int top_k = 100
|
||||
) noexcept:
|
||||
mode_(mode),
|
||||
conf_threshold_(conf_threshold),
|
||||
nms_threshold_(nms_threshold),
|
||||
top_k_(top_k) {
|
||||
setMode(mode_);
|
||||
}
|
||||
|
||||
void setMode(Mode m) noexcept {
|
||||
mode_ = m;
|
||||
switch (mode_) {
|
||||
case Mode::TUP: {
|
||||
input_w_ = ModelTraits<Mode::TUP>::INPUT_W;
|
||||
input_h_ = ModelTraits<Mode::TUP>::INPUT_H;
|
||||
use_norm_ = ModelTraits<Mode::TUP>::USE_NORM;
|
||||
input_rgb_ = ModelTraits<Mode::TUP>::INPUT_RGB;
|
||||
break;
|
||||
}
|
||||
case Mode::RP: {
|
||||
input_w_ = ModelTraits<Mode::RP>::INPUT_W;
|
||||
input_h_ = ModelTraits<Mode::RP>::INPUT_H;
|
||||
use_norm_ = ModelTraits<Mode::RP>::USE_NORM;
|
||||
input_rgb_ = ModelTraits<Mode::RP>::INPUT_RGB;
|
||||
break;
|
||||
}
|
||||
case Mode::AT: {
|
||||
input_w_ = ModelTraits<Mode::AT>::INPUT_W;
|
||||
input_h_ = ModelTraits<Mode::AT>::INPUT_H;
|
||||
use_norm_ = ModelTraits<Mode::AT>::USE_NORM;
|
||||
input_rgb_ = ModelTraits<Mode::AT>::INPUT_RGB;
|
||||
break;
|
||||
}
|
||||
case Mode::BOX416: {
|
||||
input_w_ = ModelTraits<Mode::BOX416>::INPUT_W;
|
||||
input_h_ = ModelTraits<Mode::BOX416>::INPUT_H;
|
||||
use_norm_ = ModelTraits<Mode::BOX>::USE_NORM;
|
||||
input_rgb_ = ModelTraits<Mode::BOX>::INPUT_RGB;
|
||||
break;
|
||||
}
|
||||
case Mode::BOX320: {
|
||||
input_w_ = ModelTraits<Mode::BOX320>::INPUT_W;
|
||||
input_h_ = ModelTraits<Mode::BOX320>::INPUT_H;
|
||||
use_norm_ = ModelTraits<Mode::BOX>::USE_NORM;
|
||||
input_rgb_ = ModelTraits<Mode::BOX>::INPUT_RGB;
|
||||
break;
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void setConfThreshold(float t) noexcept {
|
||||
conf_threshold_ = t;
|
||||
}
|
||||
void setNmsThreshold(float t) noexcept {
|
||||
nms_threshold_ = t;
|
||||
}
|
||||
void setTopK(int k) noexcept {
|
||||
top_k_ = k;
|
||||
}
|
||||
|
||||
int inputW() const noexcept {
|
||||
return input_w_;
|
||||
}
|
||||
int inputH() const noexcept {
|
||||
return input_h_;
|
||||
}
|
||||
bool useNorm() const noexcept {
|
||||
return use_norm_;
|
||||
}
|
||||
bool inputRGB() const noexcept {
|
||||
return input_rgb_;
|
||||
}
|
||||
|
||||
// main dispatching entry (keeps original signature)
|
||||
[[nodiscard]] std::vector<ArmorObject> postProcess(const cv::Mat& output_buffer) const {
|
||||
switch (mode_) {
|
||||
case Mode::TUP:
|
||||
return postProcessTUP_impl(output_buffer);
|
||||
case Mode::RP:
|
||||
return postProcessRP_impl(output_buffer);
|
||||
case Mode::AT:
|
||||
return postProcessAT_impl(output_buffer);
|
||||
case Mode::BOX416:
|
||||
return postProcessBOX_impl(output_buffer);
|
||||
case Mode::BOX320:
|
||||
return postProcessBOX_impl(output_buffer);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<ArmorObject> postProcessTUP_impl(const cv::Mat& out) const;
|
||||
|
||||
std::vector<ArmorObject> postProcessRP_impl(const cv::Mat& out) const;
|
||||
|
||||
std::vector<ArmorObject> postProcessAT_impl(const cv::Mat& out) const;
|
||||
|
||||
std::vector<ArmorObject> postProcessBOX_impl(const cv::Mat& out) const;
|
||||
|
||||
private:
|
||||
Mode mode_;
|
||||
int input_w_ { 0 };
|
||||
int input_h_ { 0 };
|
||||
float conf_threshold_ { 0.25f };
|
||||
float nms_threshold_ { 0.45f };
|
||||
int top_k_ { 100 };
|
||||
bool use_norm_ { false };
|
||||
bool input_rgb_ { false };
|
||||
};
|
||||
|
||||
} // namespace wust_vision::auto_aim::armor_infer
|
||||
@@ -0,0 +1,222 @@
|
||||
|
||||
#ifdef USE_NCNN
|
||||
#include "tasks/auto_aim/armor_detect/ncnn/armor_detector_ncnn.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
|
||||
#include "wust_vl/ml_net/ncnn/ncnn_net.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
|
||||
struct ArmorDetectorNCNN::Impl {
|
||||
public:
|
||||
Impl(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
if (use_armor_detect_common) {
|
||||
armor_detect_common_ = ArmorDetectorCommon::create(config);
|
||||
}
|
||||
std::string model_type = config["ncnn"]["model_type"].as<std::string>();
|
||||
auto model = armor_infer::modeFromString(model_type);
|
||||
float conf_threshold = config["ncnn"]["conf_threshold"].as<float>();
|
||||
int top_k = config["ncnn"]["top_k"].as<int>();
|
||||
float nms_threshold = config["ncnn"]["nms_threshold"].as<float>();
|
||||
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
|
||||
model,
|
||||
conf_threshold,
|
||||
nms_threshold,
|
||||
top_k
|
||||
);
|
||||
std::string model_path_param =
|
||||
utils::expandEnv(config["ncnn"]["model_path_param"].as<std::string>());
|
||||
std::string model_path_bin =
|
||||
utils::expandEnv(config["ncnn"]["model_path_bin"].as<std::string>());
|
||||
bool use_gpu = config["ncnn"]["use_gpu"].as<bool>();
|
||||
int cpu_threads = config["ncnn"]["cpu_threads"].as<int>();
|
||||
bool use_lightmode = config["ncnn"]["use_lightmode"].as<bool>();
|
||||
auto input_name = config["ncnn"]["input_name"].as<std::string>();
|
||||
auto output_name = config["ncnn"]["output_name"].as<std::string>();
|
||||
int device_id = config["ncnn"]["device_id"].as<int>();
|
||||
wust_vl::ml_net::NCNNNet::Params params;
|
||||
params.model_path_param = model_path_param;
|
||||
params.model_path_bin = model_path_bin;
|
||||
params.input_name = input_name;
|
||||
params.output_name = output_name;
|
||||
params.use_vulkan = use_gpu;
|
||||
params.device_id = device_id;
|
||||
params.use_light_mode = use_lightmode;
|
||||
params.cpu_threads = cpu_threads;
|
||||
ncnn_net_ = std::make_unique<wust_vl::ml_net::NCNNNet>();
|
||||
ncnn_net_->init(params);
|
||||
}
|
||||
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
return std::make_unique<ArmorDetectorNCNN>(config, use_armor_detect_common);
|
||||
}
|
||||
~Impl() {
|
||||
armor_detect_common_.reset();
|
||||
ncnn_net_.reset();
|
||||
}
|
||||
cv::Mat ncnnMatToCvMat(const ncnn::Mat& m) {
|
||||
cv::Mat img(m.h, m.w, CV_8UC3);
|
||||
m.to_pixels(img.data, ncnn::Mat::PIXEL_RGB2BGR);
|
||||
|
||||
return img;
|
||||
}
|
||||
void setCallback(DetectorCallback callback) {
|
||||
infer_callback_ = callback;
|
||||
}
|
||||
static ncnn::Mat letterbox_to_ncnn(
|
||||
const cv::Mat& img,
|
||||
Eigen::Matrix3f& transform_matrix,
|
||||
int out_w,
|
||||
int out_h,
|
||||
float norm,
|
||||
bool swap_rb = true
|
||||
) {
|
||||
const int img_w = img.cols;
|
||||
const int img_h = img.rows;
|
||||
|
||||
float scale = std::min(out_w * 1.0f / img_w, out_h * 1.0f / img_h);
|
||||
int resize_w = static_cast<int>(round(img_w * scale));
|
||||
int resize_h = static_cast<int>(round(img_h * scale));
|
||||
|
||||
int pad_w = out_w - resize_w;
|
||||
int pad_h = out_h - resize_h;
|
||||
int pad_left = static_cast<int>(round(pad_w / 2.0f - 0.1f));
|
||||
int pad_top = static_cast<int>(round(pad_h / 2.0f - 0.1f));
|
||||
|
||||
transform_matrix << 1.0f / scale, 0, -pad_left / scale, 0, 1.0f / scale,
|
||||
-pad_top / scale, 0, 0, 1;
|
||||
ncnn::Mat out;
|
||||
if (swap_rb) {
|
||||
out = ncnn::Mat::from_pixels_resize(
|
||||
img.data,
|
||||
ncnn::Mat::PIXEL_BGR2RGB,
|
||||
img_w,
|
||||
img_h,
|
||||
resize_w,
|
||||
resize_h
|
||||
);
|
||||
} else {
|
||||
out = ncnn::Mat::from_pixels_resize(
|
||||
img.data,
|
||||
ncnn::Mat::PIXEL_RGB,
|
||||
img_w,
|
||||
img_h,
|
||||
resize_w,
|
||||
resize_h
|
||||
);
|
||||
}
|
||||
|
||||
int pad_right = out_w - resize_w - pad_left;
|
||||
int pad_bottom = out_h - resize_h - pad_top;
|
||||
|
||||
ncnn::Mat padded;
|
||||
ncnn::copy_make_border(
|
||||
out,
|
||||
padded,
|
||||
pad_top,
|
||||
pad_bottom,
|
||||
pad_left,
|
||||
pad_right,
|
||||
ncnn::BORDER_CONSTANT,
|
||||
114.f
|
||||
);
|
||||
|
||||
std::array<float, 3> mean_vals;
|
||||
std::array<float, 3> norm_vals;
|
||||
|
||||
mean_vals = { 0.f, 0.f, 0.f };
|
||||
norm_vals = { norm, norm, norm };
|
||||
|
||||
padded.substract_mean_normalize(mean_vals.data(), norm_vals.data());
|
||||
|
||||
return padded;
|
||||
}
|
||||
void
|
||||
processCallback(const CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
|
||||
// Eigen::Matrix3f transform_matrix;
|
||||
// cv::Mat resized_img = letterbox(frame.src_img, transform_matrix);
|
||||
// ncnn::Mat in =
|
||||
// ncnn::Mat::from_pixels(resized_img.data, ncnn::Mat::PIXEL_BGR2RGB, INPUT_W, INPUT_H);
|
||||
Eigen::Matrix3f transform_matrix;
|
||||
auto roi = frame.img_frame.src_img(frame.expanded);
|
||||
const bool swap_rb = armor_infer_->inputRGB()
|
||||
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
|
||||
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
|
||||
ncnn::Mat in = letterbox_to_ncnn(
|
||||
roi.clone(),
|
||||
transform_matrix,
|
||||
armor_infer_->inputW(),
|
||||
armor_infer_->inputH(),
|
||||
scale,
|
||||
swap_rb
|
||||
);
|
||||
cv::Mat resized_img = ncnnMatToCvMat(in);
|
||||
|
||||
auto out = ncnn_net_->infer(in);
|
||||
|
||||
cv::Mat output_buffer(out.h, out.w, CV_32F, out.data);
|
||||
|
||||
// Parse YOLO output
|
||||
auto objs_result = armor_infer_->postProcess(output_buffer);
|
||||
std::vector<ArmorObject> armors;
|
||||
if (armor_detect_common_) {
|
||||
armors = armor_detect_common_->detectNet(
|
||||
resized_img,
|
||||
objs_result,
|
||||
transform_matrix,
|
||||
frame.detect_color,
|
||||
target_number
|
||||
);
|
||||
// Call callback function
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
for (auto obj: objs_result) {
|
||||
auto detect_color = frame.detect_color;
|
||||
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
|
||||
continue;
|
||||
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
|
||||
continue;
|
||||
}
|
||||
obj.transform(transform_matrix);
|
||||
armors.push_back(obj);
|
||||
}
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
|
||||
frame.id = current_id_++;
|
||||
processCallback(frame, target_number);
|
||||
}
|
||||
|
||||
private:
|
||||
DetectorCallback infer_callback_;
|
||||
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
|
||||
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
|
||||
int current_id_ = 0;
|
||||
std::unique_ptr<wust_vl::ml_net::NCNNNet> ncnn_net_;
|
||||
};
|
||||
ArmorDetectorNCNN::ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
|
||||
}
|
||||
ArmorDetectorNCNN::~ArmorDetectorNCNN() {
|
||||
_impl.reset();
|
||||
}
|
||||
void ArmorDetectorNCNN::setCallback(DetectorCallback callback) {
|
||||
_impl->setCallback(callback);
|
||||
}
|
||||
void ArmorDetectorNCNN::pushInput(
|
||||
CommonFrame& frame,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) {
|
||||
_impl->pushInput(frame, target_number);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
#endif
|
||||
@@ -0,0 +1,36 @@
|
||||
// Copyright 2025 XiaoJian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorDetectorNCNN: public ArmorDetectorBase {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorDetectorNCNN>;
|
||||
explicit ArmorDetectorNCNN(const YAML::Node& config, bool use_armor_detect_common);
|
||||
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
return std::make_unique<ArmorDetectorNCNN>(config, use_armor_detect_common);
|
||||
}
|
||||
~ArmorDetectorNCNN();
|
||||
void setCallback(DetectorCallback callback) override;
|
||||
void
|
||||
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class NumberClassifierBase {
|
||||
public:
|
||||
virtual ~NumberClassifierBase() = default;
|
||||
virtual bool classifyNumber(ArmorObject& armor) = 0;
|
||||
virtual void initNumberClassifier() = 0;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
#include "number_classifier.hpp"
|
||||
#ifdef USE_TRT
|
||||
#include "number_classifier_trt.hpp"
|
||||
#endif
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class NumberClassifierFactory {
|
||||
public:
|
||||
static std::unique_ptr<NumberClassifierBase> createNumberClassifier(
|
||||
const std::string& backend,
|
||||
const std::string& classify_model_path,
|
||||
const std::string& classify_label_path
|
||||
) {
|
||||
#if defined(USE_TRT)
|
||||
if (backend == "tensorrt") {
|
||||
return std::make_unique<NumberClassifierTRT>(
|
||||
classify_model_path,
|
||||
classify_label_path
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (backend == "opencv") {
|
||||
return std::make_unique<NumberClassifier>(classify_model_path, classify_label_path);
|
||||
}
|
||||
|
||||
throw std::runtime_error(
|
||||
"Unsupported number classifier backend (or not compiled): " + backend
|
||||
);
|
||||
}
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,116 @@
|
||||
// Copyright Chen Jun 2023. Licensed under the MIT License.
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "number_classifier.hpp"
|
||||
#include <wust_vl/common/utils/logger.hpp>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
NumberClassifier::NumberClassifier(
|
||||
const std::string& classify_model_path,
|
||||
const std::string& classify_label_path
|
||||
):
|
||||
classify_model_path_(classify_model_path),
|
||||
classify_label_path_(classify_label_path) {
|
||||
initNumberClassifier();
|
||||
}
|
||||
void NumberClassifier::initNumberClassifier() {
|
||||
const std::string model_path = classify_model_path_;
|
||||
std::unique_ptr<cv::dnn::Net> number_net_ =
|
||||
std::make_unique<cv::dnn::Net>(cv::dnn::readNetFromONNX(model_path));
|
||||
|
||||
if (number_net_->empty()) {
|
||||
WUST_ERROR("number_classifier")
|
||||
<< "Failed to load number classifier model from " << model_path;
|
||||
std::exit(EXIT_FAILURE);
|
||||
} else {
|
||||
WUST_INFO("number_classifier")
|
||||
<< "Successfully loaded number classifier model from " << model_path;
|
||||
}
|
||||
|
||||
const std::string label_path = classify_label_path_;
|
||||
std::ifstream label_file(label_path);
|
||||
std::string line;
|
||||
|
||||
class_names_.clear();
|
||||
|
||||
while (std::getline(label_file, line)) {
|
||||
class_names_.push_back(line);
|
||||
}
|
||||
|
||||
if (class_names_.empty()) {
|
||||
WUST_ERROR("number_classifier") << "Failed to load labels from " << label_path;
|
||||
std::exit(EXIT_FAILURE);
|
||||
} else {
|
||||
WUST_INFO("number_classifier")
|
||||
<< "Successfully loaded " << class_names_.size() << " labels from " << label_path;
|
||||
}
|
||||
number_net_.reset();
|
||||
}
|
||||
bool NumberClassifier::classifyNumber(ArmorObject& armor) {
|
||||
static thread_local std::unique_ptr<cv::dnn::Net> thread_net;
|
||||
if (armor.number_img.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!thread_net) {
|
||||
thread_net =
|
||||
std::make_unique<cv::dnn::Net>(cv::dnn::readNetFromONNX(classify_model_path_));
|
||||
WUST_DEBUG("number_classifier") << "Loaded number classifier model for this thread";
|
||||
if (thread_net->empty()) {
|
||||
WUST_ERROR("number_classifier")
|
||||
<< "Failed to load thread-local number classifier model.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const cv::Mat image = armor.number_img;
|
||||
|
||||
cv::Mat blob;
|
||||
cv::dnn::blobFromImage(image, blob, 1.0 / 255.0);
|
||||
|
||||
thread_net->setInput(blob);
|
||||
cv::Mat outputs = thread_net->forward();
|
||||
double max_val;
|
||||
cv::minMaxLoc(outputs, nullptr, &max_val);
|
||||
|
||||
cv::Mat prob;
|
||||
cv::exp(outputs - max_val, prob);
|
||||
prob /= cv::sum(prob)[0];
|
||||
|
||||
double confidence;
|
||||
cv::Point class_id;
|
||||
cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id);
|
||||
|
||||
const int label_id = class_id.x;
|
||||
const double raw_conf = armor.confidence;
|
||||
armor.confidence = confidence;
|
||||
|
||||
static const std::map<int, ArmorNumber> label_to_armor_number = {
|
||||
{ 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 },
|
||||
{ 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST },
|
||||
{ 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE }
|
||||
};
|
||||
|
||||
if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) {
|
||||
armor.number = label_to_armor_number.at(label_id);
|
||||
|
||||
return true;
|
||||
} else {
|
||||
armor.confidence = raw_conf;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright Chen Jun 2023. Licensed under the MIT License.
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "base.hpp"
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class NumberClassifier: public NumberClassifierBase {
|
||||
public:
|
||||
NumberClassifier(
|
||||
const std::string& classify_model_path,
|
||||
const std::string& classify_label_path
|
||||
);
|
||||
void initNumberClassifier() override;
|
||||
bool classifyNumber(ArmorObject& armor) override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> class_names_;
|
||||
std::string classify_model_path_;
|
||||
std::string classify_label_path_;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,118 @@
|
||||
// Copyright Chen Jun 2023. Licensed under the MIT License.
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#ifdef USE_TRT
|
||||
#include "number_classifier_trt.hpp"
|
||||
#include <wust_vl/common/utils/logger.hpp>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
NumberClassifierTRT::NumberClassifierTRT(
|
||||
const std::string& classify_model_path,
|
||||
const std::string& classify_label_path
|
||||
):
|
||||
classify_model_path_(classify_model_path),
|
||||
classify_label_path_(classify_label_path) {
|
||||
initNumberClassifier();
|
||||
}
|
||||
void NumberClassifierTRT::initNumberClassifier() {
|
||||
const std::string model_path = classify_model_path_;
|
||||
trt_net_ = std::make_unique<wust_vl::ml_net::TensorRTNet>();
|
||||
wust_vl::ml_net::TensorRTNet::Params trt_params;
|
||||
trt_params.model_path = model_path;
|
||||
trt_params.input_dims = nvinfer1::Dims4 { 1, 1, 20, 28 };
|
||||
trt_net_->init(trt_params);
|
||||
const auto input_output_dims = trt_net_->getInputOutputDims();
|
||||
input_dims_ = std::get<0>(input_output_dims);
|
||||
output_dims_ = std::get<1>(input_output_dims);
|
||||
|
||||
const std::string label_path = classify_label_path_;
|
||||
std::ifstream label_file(label_path);
|
||||
std::string line;
|
||||
|
||||
class_names_.clear();
|
||||
|
||||
while (std::getline(label_file, line)) {
|
||||
class_names_.push_back(line);
|
||||
}
|
||||
|
||||
if (class_names_.empty()) {
|
||||
WUST_ERROR("number_classifier_trt") << "Failed to load labels from " << label_path;
|
||||
std::exit(EXIT_FAILURE);
|
||||
} else {
|
||||
WUST_INFO("number_classifier_trt")
|
||||
<< "Successfully loaded " << class_names_.size() << " labels from " << label_path;
|
||||
}
|
||||
}
|
||||
bool NumberClassifierTRT::classifyNumber(ArmorObject& armor) {
|
||||
static thread_local std::unique_ptr<nvinfer1::IExecutionContext> ctx;
|
||||
if (armor.number_img.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctx) {
|
||||
auto c = trt_net_->getAContext();
|
||||
ctx = std::unique_ptr<nvinfer1::IExecutionContext>(c);
|
||||
WUST_DEBUG("number_classifier_trt") << "Loaded number classifier model for this thread";
|
||||
if (!ctx) {
|
||||
WUST_ERROR("number_classifier_trt")
|
||||
<< "Failed to load thread-local number classifier model.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const cv::Mat image = armor.number_img;
|
||||
cv::Mat blob;
|
||||
cv::dnn::blobFromImage(image, blob, 1.0 / 255.0, cv::Size(28, 20));
|
||||
trt_net_->input2Device(blob.ptr<float>());
|
||||
void* input_tensor_ptr = trt_net_->getInputTensorPtr();
|
||||
trt_net_->infer(input_tensor_ptr, ctx.get());
|
||||
|
||||
const float* out = static_cast<float*>(trt_net_->output2Host());
|
||||
|
||||
cv::Mat outputs(1, 9, CV_32F);
|
||||
std::memcpy(outputs.data, out, 9 * sizeof(float));
|
||||
|
||||
double max_val;
|
||||
cv::minMaxLoc(outputs, nullptr, &max_val);
|
||||
|
||||
cv::Mat prob;
|
||||
cv::exp(outputs - max_val, prob);
|
||||
prob /= cv::sum(prob)[0];
|
||||
|
||||
double confidence;
|
||||
cv::Point class_id;
|
||||
cv::minMaxLoc(prob, nullptr, &confidence, nullptr, &class_id);
|
||||
|
||||
const int label_id = class_id.x;
|
||||
const double raw_conf = armor.confidence;
|
||||
armor.confidence = confidence;
|
||||
|
||||
static const std::map<int, ArmorNumber> label_to_armor_number = {
|
||||
{ 0, ArmorNumber::NO1 }, { 1, ArmorNumber::NO2 }, { 2, ArmorNumber::NO3 },
|
||||
{ 3, ArmorNumber::NO4 }, { 4, ArmorNumber::NO5 }, { 5, ArmorNumber::OUTPOST },
|
||||
{ 6, ArmorNumber::SENTRY }, { 7, ArmorNumber::BASE }
|
||||
};
|
||||
|
||||
if (label_id < 8 && label_to_armor_number.find(label_id) != label_to_armor_number.end()) {
|
||||
armor.number = label_to_armor_number.at(label_id);
|
||||
|
||||
return true;
|
||||
} else {
|
||||
armor.confidence = raw_conf;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
#endif
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright Chen Jun 2023. Licensed under the MIT License.
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "base.hpp"
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
#include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class NumberClassifierTRT: public NumberClassifierBase {
|
||||
public:
|
||||
NumberClassifierTRT(
|
||||
const std::string& classify_model_path,
|
||||
const std::string& classify_label_path
|
||||
);
|
||||
void initNumberClassifier() override;
|
||||
bool classifyNumber(ArmorObject& armor) override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> class_names_;
|
||||
std::string classify_model_path_;
|
||||
std::string classify_label_path_;
|
||||
std::unique_ptr<wust_vl::ml_net::TensorRTNet> trt_net_;
|
||||
nvinfer1::Dims input_dims_;
|
||||
nvinfer1::Dims output_dims_;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,141 @@
|
||||
#ifdef USE_ORT
|
||||
|
||||
#include "armor_detector_onnxruntime.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
|
||||
#include "tasks/utils/utils.hpp"
|
||||
#include "wust_vl/ml_net/onnxruntime/onnxruntime_net.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
struct ArmorDetectorOnnxRuntime::Impl {
|
||||
public:
|
||||
Impl(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
if (use_armor_detect_common) {
|
||||
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
|
||||
}
|
||||
std::string model_type = config["onnxruntime"]["model_type"].as<std::string>();
|
||||
auto model = armor_infer::modeFromString(model_type);
|
||||
float conf_threshold = config["onnxruntime"]["conf_threshold"].as<float>();
|
||||
int top_k = config["onnxruntime"]["top_k"].as<int>();
|
||||
float nms_threshold = config["onnxruntime"]["nms_threshold"].as<float>();
|
||||
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
|
||||
model,
|
||||
conf_threshold,
|
||||
nms_threshold,
|
||||
top_k
|
||||
);
|
||||
std::string provider = config["onnxruntime"]["provider"].as<std::string>("CPU");
|
||||
provider_ = wust_vl::ml_net::string2OrtProvider(provider);
|
||||
onnxruntime_net_ = std::make_unique<wust_vl::ml_net::OnnxRuntimeNet>();
|
||||
wust_vl::ml_net::OnnxRuntimeNet::Params params;
|
||||
std::string model_path =
|
||||
utils::expandEnv(config["onnxruntime"]["model_path"].as<std::string>());
|
||||
params.model_path = model_path;
|
||||
params.provider = provider_;
|
||||
onnxruntime_net_->init(params);
|
||||
}
|
||||
|
||||
~Impl() {
|
||||
onnxruntime_net_.reset();
|
||||
armor_detect_common_.reset();
|
||||
}
|
||||
void setCallback(DetectorCallback callback) {
|
||||
infer_callback_ = callback;
|
||||
}
|
||||
void
|
||||
processCallback(const CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
|
||||
Eigen::Matrix3f transform_matrix;
|
||||
auto roi = frame.img_frame.src_img(frame.expanded);
|
||||
cv::Mat resized_img = utils::letterbox(
|
||||
roi,
|
||||
transform_matrix,
|
||||
armor_infer_->inputW(),
|
||||
armor_infer_->inputH()
|
||||
);
|
||||
const bool swap_rb = armor_infer_->inputRGB()
|
||||
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
|
||||
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
|
||||
cv::Mat blob = cv::dnn::blobFromImage(
|
||||
resized_img,
|
||||
scale,
|
||||
cv::Size(armor_infer_->inputW(), armor_infer_->inputH()),
|
||||
cv::Scalar(0, 0, 0),
|
||||
swap_rb
|
||||
);
|
||||
|
||||
auto output_data = onnxruntime_net_->infer(blob.ptr<float>(), blob.total());
|
||||
|
||||
auto output_shape = onnxruntime_net_->getOutputShape();
|
||||
int rows = static_cast<int>(output_shape[1]);
|
||||
int cols = static_cast<int>(output_shape[2]);
|
||||
cv::Mat output_buffer(rows, cols, CV_32F, output_data);
|
||||
|
||||
// Parsed variable
|
||||
std::vector<ArmorObject> objs_result;
|
||||
objs_result = armor_infer_->postProcess(output_buffer);
|
||||
std::vector<ArmorObject> armors;
|
||||
if (armor_detect_common_) {
|
||||
armors = armor_detect_common_->detectNet(
|
||||
resized_img,
|
||||
objs_result,
|
||||
transform_matrix,
|
||||
frame.detect_color,
|
||||
target_number
|
||||
);
|
||||
// Call callback function
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
for (auto obj: objs_result) {
|
||||
auto detect_color = frame.detect_color;
|
||||
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
|
||||
continue;
|
||||
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
|
||||
continue;
|
||||
}
|
||||
obj.transform(transform_matrix);
|
||||
armors.push_back(obj);
|
||||
}
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
|
||||
frame.id = current_id_++;
|
||||
processCallback(frame, target_number);
|
||||
}
|
||||
|
||||
wust_vl::ml_net::OrtProvider provider_ = wust_vl::ml_net::OrtProvider::CPU;
|
||||
DetectorCallback infer_callback_;
|
||||
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
|
||||
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
|
||||
int current_id_ = 0;
|
||||
std::unique_ptr<wust_vl::ml_net::OnnxRuntimeNet> onnxruntime_net_;
|
||||
};
|
||||
ArmorDetectorOnnxRuntime::ArmorDetectorOnnxRuntime(
|
||||
const YAML::Node& config,
|
||||
bool use_armor_detect_common
|
||||
) {
|
||||
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
|
||||
}
|
||||
ArmorDetectorOnnxRuntime::~ArmorDetectorOnnxRuntime() {
|
||||
_impl.reset();
|
||||
}
|
||||
void ArmorDetectorOnnxRuntime::setCallback(DetectorCallback callback) {
|
||||
_impl->setCallback(callback);
|
||||
}
|
||||
void ArmorDetectorOnnxRuntime::pushInput(
|
||||
CommonFrame& frame,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) {
|
||||
_impl->pushInput(frame, target_number);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
#endif
|
||||
@@ -0,0 +1,36 @@
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorDetectorOnnxRuntime: public ArmorDetectorBase {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorDetectorOnnxRuntime>;
|
||||
|
||||
explicit ArmorDetectorOnnxRuntime(const YAML::Node& config, bool use_armor_detect_common);
|
||||
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
return std::make_unique<ArmorDetectorOnnxRuntime>(config, use_armor_detect_common);
|
||||
}
|
||||
~ArmorDetectorOnnxRuntime();
|
||||
void
|
||||
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
|
||||
void setCallback(DetectorCallback callback) override;
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,382 @@
|
||||
// Copyright Chen Jun 2023. Licensed under the MIT License.
|
||||
//
|
||||
// Additional modifications and features by Chengfu Zou, Labor. Licensed under
|
||||
// Apache License 2.0.
|
||||
//
|
||||
// Copyright (C) FYT Vision Group. All rights reserved.
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tasks/auto_aim/armor_detect/opencv/armor_detector_opencv.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/number_classifier/number_classifier.hpp"
|
||||
#include "tasks/utils/utils.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
struct ArmorDetectorOpenCV::Impl {
|
||||
public:
|
||||
Impl(const YAML::Node& config) {
|
||||
auto classify_model_path =
|
||||
utils::expandEnv(config["classify"]["model_path"].as<std::string>());
|
||||
auto classify_label_path =
|
||||
utils::expandEnv(config["classify"]["label_path"].as<std::string>());
|
||||
double classify_threshold = config["classify"]["threshold"].as<double>();
|
||||
number_classifier_ =
|
||||
std::make_unique<NumberClassifier>(classify_model_path, classify_label_path);
|
||||
light_params_.load(config["light"]);
|
||||
armor_params_.load(config["armor"]);
|
||||
}
|
||||
|
||||
std::vector<ArmorObject> detect(
|
||||
const cv::Mat& input,
|
||||
int detect_color,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) noexcept {
|
||||
if (input.empty())
|
||||
return {};
|
||||
|
||||
std::vector<Light> lights_;
|
||||
|
||||
cv::Mat binary_img, gray_img;
|
||||
std::tie(binary_img, gray_img) = preprocessImage(input);
|
||||
lights_ = findLights(input, binary_img);
|
||||
std::vector<ArmorObject> armors = matchLights(lights_, detect_color);
|
||||
std::vector<ArmorObject> valid_armors;
|
||||
|
||||
for (auto& armor: armors) {
|
||||
try {
|
||||
armor.number_img = extractNumber(gray_img, armor);
|
||||
if (armor.number_img.empty())
|
||||
continue;
|
||||
|
||||
if (!number_classifier_->classifyNumber(armor))
|
||||
continue;
|
||||
if (target_number.has_value()) {
|
||||
if (!isSameTarget(target_number.value(), armor.number)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (armor.confidence < classifier_threshold_)
|
||||
continue;
|
||||
|
||||
if (armor.number != ArmorNumber::NO1 && armor.number != ArmorNumber::BASE
|
||||
&& armor.type == ArmorType::LARGE)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
valid_armors.push_back(armor);
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "[detect] Exception: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
return valid_armors;
|
||||
}
|
||||
|
||||
std::tuple<cv::Mat, cv::Mat> preprocessImage(const cv::Mat& img) noexcept {
|
||||
cv::Mat gray_img;
|
||||
cv::Mat binary_img;
|
||||
|
||||
if (img.empty()) {
|
||||
return { binary_img, gray_img }; // 空图直接返回空
|
||||
}
|
||||
|
||||
if (img.channels() == 3) {
|
||||
cv::cvtColor(img, gray_img, cv::COLOR_RGB2GRAY);
|
||||
} else if (img.channels() == 1) {
|
||||
cv::cvtColor(img, gray_img, cv::COLOR_BayerRG2GRAY);
|
||||
} else {
|
||||
return { binary_img, gray_img };
|
||||
}
|
||||
|
||||
cv::threshold(gray_img, binary_img, light_params_.binary_thres, 255, cv::THRESH_BINARY);
|
||||
|
||||
return { binary_img, gray_img };
|
||||
}
|
||||
|
||||
std::vector<Light> findLights(const cv::Mat& img, const cv::Mat& binary_img) noexcept {
|
||||
std::vector<std::vector<cv::Point>> contours;
|
||||
contours.reserve(64);
|
||||
cv::findContours(binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
|
||||
cv::Mat color_img;
|
||||
if (img.channels() == 3) {
|
||||
color_img = img;
|
||||
} else if (img.channels() == 1) {
|
||||
cv::cvtColor(img, color_img, cv::COLOR_BayerRG2BGR);
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
std::vector<Light> lights;
|
||||
lights.reserve(contours.size());
|
||||
|
||||
for (const auto& contour: contours) {
|
||||
const int n = static_cast<int>(contour.size());
|
||||
if (n < 6)
|
||||
continue;
|
||||
Light light(contour);
|
||||
if (!isLight(light))
|
||||
continue;
|
||||
int sum_r = 0;
|
||||
int sum_b = 0;
|
||||
for (const auto& pt: contour) {
|
||||
const cv::Vec3b& pix = color_img.at<cv::Vec3b>(pt);
|
||||
sum_r += pix[0];
|
||||
sum_b += pix[2];
|
||||
}
|
||||
|
||||
const int avg_diff = std::abs(sum_r - sum_b) / n;
|
||||
if (avg_diff <= light_params_.color_diff_thresh)
|
||||
continue;
|
||||
|
||||
light.color = (sum_r > sum_b) ? 0 : 1;
|
||||
lights.emplace_back(std::move(light));
|
||||
}
|
||||
|
||||
std::sort(lights.begin(), lights.end(), [](const Light& a, const Light& b) {
|
||||
return a.center.x < b.center.x;
|
||||
});
|
||||
|
||||
return lights;
|
||||
}
|
||||
|
||||
bool isLight(const Light& light) noexcept {
|
||||
// width / length 比例
|
||||
const float ratio = light.width / light.length;
|
||||
|
||||
if (ratio <= light_params_.min_ratio || ratio >= light_params_.max_ratio)
|
||||
return false;
|
||||
|
||||
if (light.tilt_angle >= light_params_.max_angle)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<ArmorObject>
|
||||
matchLights(const std::vector<Light>& lights, int detect_color) noexcept {
|
||||
const int n = static_cast<int>(lights.size());
|
||||
std::vector<ArmorObject> armors;
|
||||
armors.reserve(n);
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const Light& l1 = lights[i];
|
||||
if (l1.color != detect_color)
|
||||
continue;
|
||||
|
||||
const float max_dx = l1.length * armor_params_.max_large_center_distance;
|
||||
|
||||
for (int j = i + 1; j < n; ++j) {
|
||||
const Light& l2 = lights[j];
|
||||
|
||||
if (l2.color != detect_color)
|
||||
continue;
|
||||
|
||||
const float dx = l2.center.x - l1.center.x;
|
||||
if (dx > max_dx)
|
||||
break;
|
||||
|
||||
ArmorType type = isArmor(l1, l2);
|
||||
if (type == ArmorType::INVALID)
|
||||
continue;
|
||||
|
||||
if (containLight(i, j, lights))
|
||||
continue;
|
||||
|
||||
ArmorObject armor(l1, l2);
|
||||
armor.type = type;
|
||||
armor.color = (detect_color == 0) ? ArmorColor::RED : ArmorColor::BLUE;
|
||||
|
||||
armors.emplace_back(std::move(armor));
|
||||
}
|
||||
}
|
||||
|
||||
return armors;
|
||||
}
|
||||
|
||||
// Check if there is another light in the boundingRect formed by the 2 lights
|
||||
bool containLight(const int i, const int j, const std::vector<Light>& lights) noexcept {
|
||||
const Light& l1 = lights[i];
|
||||
const Light& l2 = lights[j];
|
||||
|
||||
float min_x = std::min({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x });
|
||||
float max_x = std::max({ l1.top.x, l1.bottom.x, l2.top.x, l2.bottom.x });
|
||||
float min_y = std::min({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y });
|
||||
float max_y = std::max({ l1.top.y, l1.bottom.y, l2.top.y, l2.bottom.y });
|
||||
|
||||
const float avg_len = 0.5f * (l1.length + l2.length);
|
||||
const float avg_wid = 0.5f * (l1.width + l2.width);
|
||||
|
||||
for (int k = i + 1; k < j; ++k) {
|
||||
const Light& t = lights[k];
|
||||
|
||||
if (t.width > 2.0f * avg_wid)
|
||||
continue;
|
||||
if (t.length < 0.5f * avg_len)
|
||||
continue;
|
||||
|
||||
const cv::Point2f& c = t.center;
|
||||
if (c.x >= min_x && c.x <= max_x && c.y >= min_y && c.y <= max_y) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
ArmorType isArmor(const Light& l1, const Light& l2) noexcept {
|
||||
const float len1 = l1.length;
|
||||
const float len2 = l2.length;
|
||||
if (len1 <= 1e-3f || len2 <= 1e-3f)
|
||||
return ArmorType::INVALID;
|
||||
|
||||
const float min_len = (len1 < len2) ? len1 : len2;
|
||||
const float max_len = (len1 < len2) ? len2 : len1;
|
||||
if (min_len / max_len <= armor_params_.min_light_ratio)
|
||||
return ArmorType::INVALID;
|
||||
const cv::Point2f d = l1.center - l2.center;
|
||||
const float dist2 = d.dot(d);
|
||||
|
||||
const float avg_len = 0.5f * (len1 + len2);
|
||||
|
||||
const float min_small = armor_params_.min_small_center_distance * avg_len;
|
||||
const float max_small = armor_params_.max_small_center_distance * avg_len;
|
||||
const float min_large = armor_params_.min_large_center_distance * avg_len;
|
||||
const float max_large = armor_params_.max_large_center_distance * avg_len;
|
||||
|
||||
const float min_small2 = min_small * min_small;
|
||||
const float max_small2 = max_small * max_small;
|
||||
const float min_large2 = min_large * min_large;
|
||||
const float max_large2 = max_large * max_large;
|
||||
|
||||
const bool small_ok = (dist2 >= min_small2 && dist2 < max_small2);
|
||||
const bool large_ok = (dist2 >= min_large2 && dist2 < max_large2);
|
||||
|
||||
if (!(small_ok || large_ok))
|
||||
return ArmorType::INVALID;
|
||||
|
||||
static const float tan_max_angle = std::tan(armor_params_.max_angle * CV_PI / 180.0f);
|
||||
|
||||
if (std::abs(d.y) >= std::abs(d.x) * tan_max_angle)
|
||||
return ArmorType::INVALID;
|
||||
|
||||
float delta_angle = std::fabs(l1.angle - l2.angle);
|
||||
if (delta_angle > 90.0f)
|
||||
delta_angle = 180.0f - delta_angle;
|
||||
|
||||
if (delta_angle >= light_params_.max_angle_diff)
|
||||
return ArmorType::INVALID;
|
||||
|
||||
return large_ok ? ArmorType::LARGE : ArmorType::SMALL;
|
||||
}
|
||||
|
||||
cv::Mat extractNumber(const cv::Mat& src, const ArmorObject& armor) const noexcept {
|
||||
constexpr int light_length = 12;
|
||||
constexpr int warp_height = 28;
|
||||
constexpr int small_armor_width = 32;
|
||||
constexpr int large_armor_width = 54;
|
||||
const cv::Size roi_size(20, 28);
|
||||
|
||||
cv::Point2f src_pts[4] = { armor.lights[0].bottom,
|
||||
armor.lights[0].top,
|
||||
armor.lights[1].top,
|
||||
armor.lights[1].bottom };
|
||||
|
||||
const int warp_width =
|
||||
(armor.type == ArmorType::SMALL) ? small_armor_width : large_armor_width;
|
||||
|
||||
const int top_y = (warp_height - light_length) / 2 - 1;
|
||||
const int bottom_y = top_y + light_length;
|
||||
|
||||
cv::Point2f dst_pts[4] = {
|
||||
{ 0.f, static_cast<float>(bottom_y) },
|
||||
{ 0.f, static_cast<float>(top_y) },
|
||||
{ static_cast<float>(warp_width - 1), static_cast<float>(top_y) },
|
||||
{ static_cast<float>(warp_width - 1), static_cast<float>(bottom_y) }
|
||||
};
|
||||
|
||||
cv::Mat warp_mat = cv::getPerspectiveTransform(src_pts, dst_pts);
|
||||
|
||||
cv::Mat warped;
|
||||
cv::warpPerspective(
|
||||
src,
|
||||
warped,
|
||||
warp_mat,
|
||||
cv::Size(warp_width, warp_height),
|
||||
cv::INTER_LINEAR,
|
||||
cv::BORDER_CONSTANT,
|
||||
0
|
||||
);
|
||||
|
||||
const int roi_x = (warp_width - roi_size.width) >> 1;
|
||||
if (roi_x < 0 || roi_x + roi_size.width > warp_width)
|
||||
return cv::Mat();
|
||||
|
||||
cv::Mat number = warped(cv::Rect(roi_x, 0, roi_size.width, roi_size.height));
|
||||
|
||||
cv::threshold(number, number, 0, 255, cv::THRESH_BINARY | cv::THRESH_OTSU);
|
||||
|
||||
return number;
|
||||
}
|
||||
|
||||
void setCallback(DetectorCallback callback) {
|
||||
this->infer_callback_ = callback;
|
||||
}
|
||||
void toPts(ArmorObject& armor) {
|
||||
if (armor.lights.size() != 2) {
|
||||
armor.is_ok = false;
|
||||
return;
|
||||
}
|
||||
armor.pts[0] = armor.lights[0].top;
|
||||
armor.pts[1] = armor.lights[0].bottom;
|
||||
armor.pts[2] = armor.lights[1].bottom;
|
||||
armor.pts[3] = armor.lights[1].top;
|
||||
}
|
||||
|
||||
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
|
||||
frame.id = current_id_++;
|
||||
std::vector<ArmorObject> objs_result;
|
||||
auto roi = frame.img_frame.src_img(frame.expanded);
|
||||
|
||||
objs_result = detect(roi, frame.detect_color, target_number);
|
||||
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(objs_result, frame);
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
LightParams light_params_;
|
||||
ArmorParams armor_params_;
|
||||
double classifier_threshold_ = 0.5;
|
||||
std::unique_ptr<NumberClassifier> number_classifier_;
|
||||
DetectorCallback infer_callback_;
|
||||
int current_id_ = 0;
|
||||
};
|
||||
ArmorDetectorOpenCV::ArmorDetectorOpenCV(const YAML::Node& config) {
|
||||
_impl = std::make_unique<Impl>(config);
|
||||
}
|
||||
ArmorDetectorOpenCV::~ArmorDetectorOpenCV() {
|
||||
_impl.reset();
|
||||
}
|
||||
void ArmorDetectorOpenCV::setCallback(DetectorCallback callback) {
|
||||
_impl->setCallback(callback);
|
||||
}
|
||||
void ArmorDetectorOpenCV::pushInput(
|
||||
CommonFrame& frame,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) {
|
||||
_impl->pushInput(frame, target_number);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright Chen Jun 2023. Licensed under the MIT License.
|
||||
//
|
||||
// Additional modifications and features by Chengfu Zou, Labor. Licensed under
|
||||
// Apache License 2.0.
|
||||
//
|
||||
// Copyright (C) FYT Vision Group. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
|
||||
#include "tasks/type_common.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorDetectorOpenCV: public ArmorDetectorBase {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorDetectorOpenCV>;
|
||||
explicit ArmorDetectorOpenCV(const YAML::Node& config);
|
||||
static Ptr create(const YAML::Node& config) {
|
||||
return std::make_unique<ArmorDetectorOpenCV>(config);
|
||||
}
|
||||
~ArmorDetectorOpenCV();
|
||||
void
|
||||
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
|
||||
void setCallback(DetectorCallback callback) override;
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,206 @@
|
||||
// Copyright 2025 Zikang Xie
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#ifdef USE_OPENVINO
|
||||
#include "tasks/auto_aim/armor_detect/openvino/armor_detector_openvino.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
|
||||
#include "tasks/utils/utils.hpp"
|
||||
#include "wust_vl/ml_net/openvino/openvino_net.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
struct ArmorDetectorOpenVino::Impl {
|
||||
public:
|
||||
Impl(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
if (use_armor_detect_common) {
|
||||
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
|
||||
}
|
||||
std::string model_type = config["openvino"]["model_type"].as<std::string>();
|
||||
auto model = armor_infer::modeFromString(model_type);
|
||||
float conf_threshold = config["openvino"]["conf_threshold"].as<float>();
|
||||
int top_k = config["openvino"]["top_k"].as<int>();
|
||||
float nms_threshold = config["openvino"]["nms_threshold"].as<float>();
|
||||
bool use_throughputmode = config["openvino"]["use_throughputmode"].as<bool>();
|
||||
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
|
||||
model,
|
||||
conf_threshold,
|
||||
nms_threshold,
|
||||
top_k
|
||||
);
|
||||
std::string model_path =
|
||||
utils::expandEnv(config["openvino"]["model_path"].as<std::string>());
|
||||
|
||||
auto device_name = config["openvino"]["device_name"].as<std::string>();
|
||||
ov_params_.model_path = model_path;
|
||||
ov_params_.device_name = device_name;
|
||||
ov_params_.mode = use_throughputmode ? ov::hint::PerformanceMode::THROUGHPUT
|
||||
: ov::hint::PerformanceMode::LATENCY;
|
||||
initOpenVINO();
|
||||
}
|
||||
void
|
||||
initOpenVINO(wust_vl::video::PixelFormat pixel_format = wust_vl::video::PixelFormat::BGR) {
|
||||
openvino_net_.reset();
|
||||
openvino_net_ = std::make_unique<wust_vl::ml_net::OpenvinoNet>();
|
||||
const auto ppp_init_fun = [this, pixel_format](ov::preprocess::PrePostProcessor& ppp) {
|
||||
if (pixel_format == wust_vl::video::PixelFormat::RGB) {
|
||||
ppp.input()
|
||||
.tensor()
|
||||
.set_element_type(ov::element::u8)
|
||||
.set_layout("NHWC")
|
||||
.set_color_format(ov::preprocess::ColorFormat::RGB);
|
||||
} else {
|
||||
ppp.input()
|
||||
.tensor()
|
||||
.set_element_type(ov::element::u8)
|
||||
.set_layout("NHWC")
|
||||
.set_color_format(ov::preprocess::ColorFormat::BGR);
|
||||
}
|
||||
pixel_format_ = pixel_format;
|
||||
const bool RGB = armor_infer_->inputRGB();
|
||||
const float scale = armor_infer_->useNorm() ? 255.0f : 1.0f;
|
||||
if (RGB) {
|
||||
ppp.input()
|
||||
.preprocess()
|
||||
.convert_element_type(ov::element::f32)
|
||||
.convert_color(ov::preprocess::ColorFormat::RGB)
|
||||
.scale(scale);
|
||||
} else {
|
||||
ppp.input()
|
||||
.preprocess()
|
||||
.convert_element_type(ov::element::f32)
|
||||
.convert_color(ov::preprocess::ColorFormat::BGR)
|
||||
.scale(scale);
|
||||
}
|
||||
|
||||
ppp.input().model().set_layout("NCHW");
|
||||
|
||||
ppp.output().tensor().set_element_type(ov::element::f32);
|
||||
};
|
||||
|
||||
openvino_net_->init(ov_params_, ppp_init_fun);
|
||||
}
|
||||
|
||||
~Impl() {
|
||||
openvino_net_.reset();
|
||||
armor_detect_common_.reset();
|
||||
}
|
||||
|
||||
void setCallback(DetectorCallback callback) {
|
||||
infer_callback_ = callback;
|
||||
}
|
||||
bool processCallback(
|
||||
const CommonFrame& frame,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) const {
|
||||
const auto start = std::chrono::steady_clock::now();
|
||||
Eigen::Matrix3f transform_matrix;
|
||||
const auto roi = frame.img_frame.src_img(frame.expanded);
|
||||
cv::Mat resized_img = utils::letterbox(
|
||||
roi,
|
||||
transform_matrix,
|
||||
armor_infer_->inputW(),
|
||||
armor_infer_->inputH()
|
||||
);
|
||||
const auto input_info = openvino_net_->getInputInfo();
|
||||
const auto input_tensor =
|
||||
ov::Tensor(input_info.first, input_info.second, resized_img.data);
|
||||
const auto output = openvino_net_->infer_thread_local(input_tensor);
|
||||
|
||||
// Process output data
|
||||
const auto output_shape = output.get_shape();
|
||||
|
||||
const float* ptr = output.data<const float>();
|
||||
cv::Mat
|
||||
output_buffer(output_shape[1], output_shape[2], CV_32F, const_cast<float*>(ptr));
|
||||
|
||||
// Parsed variable
|
||||
auto objs_result = armor_infer_->postProcess(output_buffer);
|
||||
|
||||
std::vector<ArmorObject> armors;
|
||||
if (armor_detect_common_) {
|
||||
armors = armor_detect_common_->detectNet(
|
||||
resized_img,
|
||||
objs_result,
|
||||
transform_matrix,
|
||||
frame.detect_color,
|
||||
target_number
|
||||
);
|
||||
// Call callback function
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
for (auto obj: objs_result) {
|
||||
auto detect_color = frame.detect_color;
|
||||
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
|
||||
continue;
|
||||
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
|
||||
continue;
|
||||
}
|
||||
obj.transform(transform_matrix);
|
||||
armors.push_back(obj);
|
||||
}
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
|
||||
if (resetting_) {
|
||||
return;
|
||||
}
|
||||
frame.id = current_id_++;
|
||||
if (frame.img_frame.pixel_format != pixel_format_) {
|
||||
resetting_ = true;
|
||||
initOpenVINO(frame.img_frame.pixel_format);
|
||||
resetting_ = false;
|
||||
}
|
||||
processCallback(frame, target_number);
|
||||
}
|
||||
|
||||
private:
|
||||
wust_vl::video::PixelFormat pixel_format_ = wust_vl::video::PixelFormat::BGR;
|
||||
std::unique_ptr<wust_vl::ml_net::OpenvinoNet> openvino_net_;
|
||||
DetectorCallback infer_callback_;
|
||||
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
|
||||
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
|
||||
int current_id_ = 0;
|
||||
wust_vl::ml_net::OpenvinoNet::Params ov_params_;
|
||||
bool resetting_ = false;
|
||||
};
|
||||
ArmorDetectorOpenVino::ArmorDetectorOpenVino(
|
||||
const YAML::Node& config,
|
||||
bool use_armor_detect_common
|
||||
) {
|
||||
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
|
||||
}
|
||||
ArmorDetectorOpenVino::~ArmorDetectorOpenVino() {
|
||||
_impl.reset();
|
||||
}
|
||||
void ArmorDetectorOpenVino::setCallback(DetectorCallback callback) {
|
||||
_impl->setCallback(callback);
|
||||
}
|
||||
void ArmorDetectorOpenVino::pushInput(
|
||||
CommonFrame& frame,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) {
|
||||
_impl->pushInput(frame, target_number);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
#endif
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright 2023 Yunlong Feng
|
||||
// Copyright 2025 Lihan Chen
|
||||
// Copyright 2025 XiaoJian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorDetectorOpenVino: public ArmorDetectorBase {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorDetectorOpenVino>;
|
||||
explicit ArmorDetectorOpenVino(const YAML::Node& config, bool use_armor_detect_common);
|
||||
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
return std::make_unique<ArmorDetectorOpenVino>(config, use_armor_detect_common);
|
||||
}
|
||||
~ArmorDetectorOpenVino();
|
||||
void
|
||||
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
|
||||
|
||||
void setCallback(DetectorCallback callback) override;
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,320 @@
|
||||
// Copyright 2025 Zikang Xie
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#ifdef USE_TRT
|
||||
#include "tasks/auto_aim/armor_detect/tensorrt/armor_detector_tensorrt.hpp"
|
||||
#include "cuda_infer/armor_infer.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_common.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_infer.hpp"
|
||||
#include "tasks/utils/utils.hpp"
|
||||
#include "wust_vl/common/concurrency/adaptive_resource_pool.hpp"
|
||||
#include "wust_vl/common/utils/logger.hpp"
|
||||
#include "wust_vl/common/utils/timer.hpp"
|
||||
#include "wust_vl/ml_net/tensorrt/tensorrt_net.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
static constexpr int MAX_SRC_IMG_W = 1920;
|
||||
static constexpr int MAX_SRC_IMG_H = 1440;
|
||||
struct ArmorDetectorTrt::Impl {
|
||||
public:
|
||||
struct Infer {
|
||||
std::unique_ptr<nvinfer1::IExecutionContext> context;
|
||||
std::unique_ptr<armor_cuda_infer::CudaInfer> cuda_infer;
|
||||
};
|
||||
|
||||
Impl(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
if (use_armor_detect_common) {
|
||||
armor_detect_common_ = std::make_unique<ArmorDetectorCommon>(config);
|
||||
}
|
||||
const double conf_threshold = config["tensorrt"]["conf_threshold"].as<float>();
|
||||
const double nms_threshold = config["tensorrt"]["nms_threshold"].as<float>();
|
||||
const int top_k = config["tensorrt"]["top_k"].as<int>();
|
||||
const int max_infer_running = config["tensorrt"]["max_infer_running"].as<int>();
|
||||
const double min_free_mem_ratio = config["tensorrt"]["min_free_mem_ratio"].as<double>();
|
||||
use_cuda_pre_ = config["tensorrt"]["use_cuda_pre"].as<bool>();
|
||||
log_time_ = config["tensorrt"]["log_time"].as<bool>();
|
||||
const std::string model_type = config["tensorrt"]["model_type"].as<std::string>();
|
||||
const std::string model_path =
|
||||
utils::expandEnv(config["tensorrt"]["model_path"].as<std::string>());
|
||||
const int device_id = config["tensorrt"]["device_id"].as<int>();
|
||||
cudaSetDevice(device_id);
|
||||
const auto model = armor_infer::modeFromString(model_type);
|
||||
armor_infer_ = std::make_unique<armor_infer::ArmorInfer>(
|
||||
model,
|
||||
conf_threshold,
|
||||
nms_threshold,
|
||||
top_k
|
||||
);
|
||||
trt_net_ = std::make_unique<wust_vl::ml_net::TensorRTNet>();
|
||||
wust_vl::ml_net::TensorRTNet::Params trt_params;
|
||||
trt_params.model_path = model_path;
|
||||
trt_params.input_dims =
|
||||
nvinfer1::Dims4 { 1, 3, armor_infer_->inputH(), armor_infer_->inputW() };
|
||||
trt_net_->init(trt_params);
|
||||
const auto input_output_dims = trt_net_->getInputOutputDims();
|
||||
input_dims_ = std::get<0>(input_output_dims);
|
||||
output_dims_ = std::get<1>(input_output_dims);
|
||||
|
||||
wust_vl::common::concurrency::AdaptiveResourcePool<Infer>::Params pool_params;
|
||||
pool_params.resource_initializer = [&]() {
|
||||
std::vector<std::unique_ptr<Infer>> infers;
|
||||
for (int i = 0; i < max_infer_running; ++i) {
|
||||
auto infer = std::make_unique<Infer>();
|
||||
auto ctx = trt_net_->getAContext();
|
||||
infer->context = std::unique_ptr<nvinfer1::IExecutionContext>(ctx);
|
||||
if (use_cuda_pre_) {
|
||||
infer->cuda_infer = std::make_unique<armor_cuda_infer::CudaInfer>();
|
||||
infer->cuda_infer->init(
|
||||
MAX_SRC_IMG_W,
|
||||
MAX_SRC_IMG_H,
|
||||
armor_infer_->inputW(),
|
||||
armor_infer_->inputH()
|
||||
);
|
||||
}
|
||||
if (!infer->context) {
|
||||
WUST_ERROR("TRT") << "create infer failed, missing context"
|
||||
<< " index:" << i;
|
||||
continue;
|
||||
}
|
||||
if (use_cuda_pre_ && !infer->cuda_infer) {
|
||||
WUST_ERROR("TRT") << "create infer failed, missing cuda_infer"
|
||||
<< " index:" << i;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t free_mem, total_mem;
|
||||
cudaMemGetInfo(&free_mem, &total_mem);
|
||||
WUST_DEBUG("TRT") << "Free GPU memory:" << free_mem / 1024.0 / 1024.0 << "MB"
|
||||
<< "Total GPU memory:" << total_mem / 1024.0 / 1024.0 << "MB";
|
||||
double free_mem_ratio =
|
||||
static_cast<double>(free_mem) / static_cast<double>(total_mem);
|
||||
if (free_mem_ratio < min_free_mem_ratio && i > 0) {
|
||||
WUST_WARN("TRT") << "GPU memory is not enough!"
|
||||
<< "Free GPU memory:" << free_mem_ratio * 100 << "%";
|
||||
WUST_INFO("TRT") << "Cut remaining infer";
|
||||
break;
|
||||
}
|
||||
infers.emplace_back(std::move(infer));
|
||||
WUST_INFO("TRT") << "create execution context success"
|
||||
<< "index:" << i;
|
||||
}
|
||||
return infers;
|
||||
};
|
||||
auto release_func = [&](std::unique_ptr<Infer>& resource) {
|
||||
if (resource) {
|
||||
if (resource->cuda_infer) {
|
||||
resource->cuda_infer.reset();
|
||||
}
|
||||
}
|
||||
};
|
||||
auto restore_func = [&](size_t idx) -> std::unique_ptr<Infer> {
|
||||
auto infer = std::make_unique<Infer>();
|
||||
auto ctx = trt_net_->getAContext();
|
||||
infer->context = std::unique_ptr<nvinfer1::IExecutionContext>(ctx);
|
||||
if (use_cuda_pre_) {
|
||||
infer->cuda_infer = std::make_unique<armor_cuda_infer::CudaInfer>();
|
||||
infer->cuda_infer->init(
|
||||
MAX_SRC_IMG_W,
|
||||
MAX_SRC_IMG_H,
|
||||
armor_infer_->inputW(),
|
||||
armor_infer_->inputH()
|
||||
);
|
||||
}
|
||||
if (!infer->context) {
|
||||
WUST_ERROR("TRT") << "create infer failed, missing context";
|
||||
return nullptr;
|
||||
}
|
||||
if ((use_cuda_pre_) && !infer->cuda_infer) {
|
||||
WUST_ERROR("TRT") << "create infer failed, missing cuda_infer";
|
||||
return nullptr;
|
||||
}
|
||||
return infer;
|
||||
};
|
||||
pool_params.restore_func = restore_func;
|
||||
|
||||
pool_params.release_func = release_func;
|
||||
|
||||
pool_params.can_restore = [&](size_t active_count) { return false; };
|
||||
|
||||
pool_params.should_release = [&](size_t active_count) { return false; };
|
||||
|
||||
pool_params.logger = [](const std::string& msg) {
|
||||
WUST_INFO("ArmorDetectorTrt:infer pool") << msg;
|
||||
};
|
||||
infer_pool_ =
|
||||
std::make_unique<wust_vl::common::concurrency::AdaptiveResourcePool<Infer>>(
|
||||
pool_params
|
||||
);
|
||||
}
|
||||
|
||||
~Impl() {
|
||||
if (infer_pool_) {
|
||||
infer_pool_.reset();
|
||||
}
|
||||
trt_net_.reset();
|
||||
armor_detect_common_.reset();
|
||||
}
|
||||
|
||||
void setCallback(DetectorCallback callback) {
|
||||
infer_callback_ = callback;
|
||||
}
|
||||
struct Tag {};
|
||||
void processCallback(
|
||||
const CommonFrame& frame,
|
||||
Infer* infer,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) const {
|
||||
std::vector<ArmorObject> armors;
|
||||
const auto t0 = wust_vl::common::utils::time_utils::now();
|
||||
Eigen::Matrix3f transform_matrix;
|
||||
std::vector<ArmorObject> objs_result;
|
||||
void* input_tensor_ptr;
|
||||
const cv::Mat roi = frame.img_frame.src_img(frame.expanded);
|
||||
|
||||
cv::Mat resized_img;
|
||||
const float scale = armor_infer_->useNorm() ? 1.0f / 255.0f : 1.0f;
|
||||
const bool swap_rb = armor_infer_->inputRGB()
|
||||
!= (frame.img_frame.pixel_format == wust_vl::video::PixelFormat::RGB);
|
||||
if (infer->cuda_infer && use_cuda_pre_) {
|
||||
input_tensor_ptr =
|
||||
infer->cuda_infer
|
||||
->preprocess_pitched( //支持不连续内存,无需拷贝后输入可直接传roi的指针
|
||||
roi.data,
|
||||
roi.cols,
|
||||
roi.rows,
|
||||
roi.step,
|
||||
scale,
|
||||
swap_rb,
|
||||
transform_matrix,
|
||||
trt_net_->getStream()
|
||||
);
|
||||
resized_img = infer->cuda_infer->tensorToMat( //nchw_float_to_hwc_uchar
|
||||
static_cast<float*>(input_tensor_ptr),
|
||||
armor_infer_->inputW(),
|
||||
armor_infer_->inputH(),
|
||||
scale,
|
||||
trt_net_->getStream()
|
||||
);
|
||||
} else {
|
||||
resized_img = utils::letterbox(
|
||||
roi,
|
||||
transform_matrix,
|
||||
armor_infer_->inputW(),
|
||||
armor_infer_->inputH()
|
||||
);
|
||||
const cv::Mat blob = cv::dnn::blobFromImage(
|
||||
resized_img,
|
||||
scale,
|
||||
cv::Size(armor_infer_->inputW(), armor_infer_->inputH()),
|
||||
cv::Scalar(0, 0, 0),
|
||||
swap_rb
|
||||
);
|
||||
trt_net_->input2Device(blob.ptr<float>());
|
||||
input_tensor_ptr = trt_net_->getInputTensorPtr();
|
||||
}
|
||||
const auto t1 = wust_vl::common::utils::time_utils::now();
|
||||
if (infer->context && input_tensor_ptr) {
|
||||
trt_net_->infer(input_tensor_ptr, infer->context.get());
|
||||
}
|
||||
const auto t2 = wust_vl::common::utils::time_utils::now();
|
||||
const cv::Mat
|
||||
output_mat(output_dims_.d[1], output_dims_.d[2], CV_32F, trt_net_->output2Host());
|
||||
cudaStreamSynchronize(trt_net_->getStream());
|
||||
objs_result = armor_infer_->postProcess(output_mat);
|
||||
const auto t3 = wust_vl::common::utils::time_utils::now();
|
||||
if (log_time_) {
|
||||
WUST_INFO("TRT") << std::fixed << std::setprecision(3) << "pre "
|
||||
<< wust_vl::common::utils::time_utils::durationMs(t0, t1) << " "
|
||||
<< "infer "
|
||||
<< wust_vl::common::utils::time_utils::durationMs(t1, t2) << " "
|
||||
<< "post "
|
||||
<< wust_vl::common::utils::time_utils::durationMs(t2, t3) << " "
|
||||
<< "total "
|
||||
<< wust_vl::common::utils::time_utils::durationMs(t0, t3);
|
||||
}
|
||||
infer_pool_->release(infer);
|
||||
|
||||
if (armor_detect_common_) {
|
||||
armors = armor_detect_common_->detectNet(
|
||||
resized_img,
|
||||
objs_result,
|
||||
transform_matrix,
|
||||
frame.detect_color,
|
||||
target_number
|
||||
);
|
||||
// Call callback function
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
for (auto obj: objs_result) {
|
||||
auto detect_color = frame.detect_color;
|
||||
if (detect_color == 0 && obj.color == ArmorColor::BLUE) {
|
||||
continue;
|
||||
} else if (detect_color == 1 && obj.color == ArmorColor::RED) {
|
||||
continue;
|
||||
}
|
||||
obj.transform(transform_matrix);
|
||||
armors.push_back(obj);
|
||||
}
|
||||
if (this->infer_callback_) {
|
||||
this->infer_callback_(armors, frame);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) {
|
||||
if (infer_pool_) {
|
||||
auto infer_ptr = infer_pool_->acquire();
|
||||
if (infer_ptr != nullptr) {
|
||||
frame.id = current_id_++;
|
||||
this->processCallback(frame, infer_ptr, target_number);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_cuda_pre_;
|
||||
bool log_time_;
|
||||
nvinfer1::Dims input_dims_;
|
||||
nvinfer1::Dims output_dims_;
|
||||
DetectorCallback infer_callback_;
|
||||
std::unique_ptr<ArmorDetectorCommon> armor_detect_common_;
|
||||
std::unique_ptr<wust_vl::common::concurrency::AdaptiveResourcePool<Infer>> infer_pool_;
|
||||
std::unique_ptr<armor_infer::ArmorInfer> armor_infer_;
|
||||
int current_id_ = 0;
|
||||
std::unique_ptr<wust_vl::ml_net::TensorRTNet> trt_net_;
|
||||
};
|
||||
ArmorDetectorTrt::ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
_impl = std::make_unique<Impl>(config, use_armor_detect_common);
|
||||
}
|
||||
ArmorDetectorTrt::~ArmorDetectorTrt() {
|
||||
_impl.reset();
|
||||
}
|
||||
void ArmorDetectorTrt::setCallback(DetectorCallback callback) {
|
||||
_impl->setCallback(callback);
|
||||
}
|
||||
void ArmorDetectorTrt::pushInput(
|
||||
CommonFrame& frame,
|
||||
const std::optional<ArmorNumber>& target_number
|
||||
) {
|
||||
_impl->pushInput(frame, target_number);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
#endif
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright 2025 Zikang Xie
|
||||
// Copyright 2025 XiaoJian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_base.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorDetectorTrt: public ArmorDetectorBase {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorDetectorTrt>;
|
||||
|
||||
explicit ArmorDetectorTrt(const YAML::Node& config, bool use_armor_detect_common);
|
||||
static Ptr create(const YAML::Node& config, bool use_armor_detect_common) {
|
||||
return std::make_unique<ArmorDetectorTrt>(config, use_armor_detect_common);
|
||||
}
|
||||
~ArmorDetectorTrt();
|
||||
void
|
||||
pushInput(CommonFrame& frame, const std::optional<ArmorNumber>& target_number) override;
|
||||
|
||||
void setCallback(DetectorCallback callback) override;
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
442
wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.cpp
Normal file
442
wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.cpp
Normal file
@@ -0,0 +1,442 @@
|
||||
#include "armor_omni.hpp"
|
||||
#include "3rdparty/angles.h"
|
||||
#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp"
|
||||
#include "tasks/auto_aim/armor_tracker/target.hpp"
|
||||
#include "tasks/auto_aim/auto_aim_fsm.hpp"
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
#include "wust_vl/common/concurrency/ThreadPool.h"
|
||||
#include "wust_vl/common/utils/timer.hpp"
|
||||
#include "wust_vl/video/camera.hpp"
|
||||
// clang-format off
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp"
|
||||
// clang-format on
|
||||
#include "tasks/auto_aim/armor_tracker/trackerv3.hpp"
|
||||
#include "tasks/auto_aim/armor_where/armor_where.hpp"
|
||||
namespace wust_vision::auto_aim {
|
||||
|
||||
struct ArmorOmni::Impl {
|
||||
struct One {
|
||||
using Ptr = std::shared_ptr<One>;
|
||||
|
||||
One(int id) {
|
||||
self_id = id;
|
||||
total_score = 0;
|
||||
}
|
||||
|
||||
static Ptr create(int id) {
|
||||
return std::make_shared<One>(id);
|
||||
}
|
||||
|
||||
void load(
|
||||
const YAML::Node& config,
|
||||
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter
|
||||
) noexcept {
|
||||
auto yaw_in_big_yaw_deg = config["yaw_in_big_yaw_deg"].as<double>();
|
||||
yaw_in_big_yaw = yaw_in_big_yaw_deg / 180.0 * M_PI;
|
||||
camera = std::make_shared<wust_vl::video::Camera>();
|
||||
camera->init(config);
|
||||
std::string camera_info_path =
|
||||
utils::expandEnv(config["camera_info_path"].as<std::string>());
|
||||
YAML::Node config_camera_info = YAML::LoadFile(camera_info_path);
|
||||
std::vector<double> camera_k =
|
||||
config_camera_info["camera_matrix"]["data"].as<std::vector<double>>();
|
||||
std::vector<double> camera_d =
|
||||
config_camera_info["distortion_coefficients"]["data"].as<std::vector<double>>();
|
||||
|
||||
assert(camera_k.size() == 9);
|
||||
assert(camera_d.size() == 5);
|
||||
|
||||
cv::Mat K(3, 3, CV_64F);
|
||||
std::memcpy(K.data, camera_k.data(), 9 * sizeof(double));
|
||||
|
||||
cv::Mat D(1, 5, CV_64F);
|
||||
std::memcpy(D.data, camera_d.data(), 5 * sizeof(double));
|
||||
|
||||
camera_info = std::make_pair(K.clone(), D.clone());
|
||||
auto gobal_config = auto_aim_config_parameter->getConfig();
|
||||
armor_where = ArmorWhere::create(gobal_config["armor_where"], camera_info);
|
||||
tracker = Tracker::create(auto_aim_config_parameter);
|
||||
}
|
||||
|
||||
void start() noexcept {
|
||||
if (camera)
|
||||
camera->start();
|
||||
}
|
||||
|
||||
int self_id;
|
||||
double total_score;
|
||||
std::shared_ptr<wust_vl::video::Camera> camera;
|
||||
ArmorWhere::Ptr armor_where;
|
||||
Tracker::Ptr tracker;
|
||||
std::pair<cv::Mat, cv::Mat> camera_info;
|
||||
double yaw_in_big_yaw;
|
||||
Target target;
|
||||
};
|
||||
|
||||
struct Obj {
|
||||
ArmorObject armor;
|
||||
double score = 0;
|
||||
One::Ptr one;
|
||||
std::chrono::steady_clock::time_point timestamp;
|
||||
|
||||
Obj(const ArmorObject& a,
|
||||
double s,
|
||||
const One::Ptr& o,
|
||||
std::chrono::steady_clock::time_point ts):
|
||||
armor(a),
|
||||
score(s),
|
||||
one(o),
|
||||
timestamp(ts) {
|
||||
if (one)
|
||||
one->total_score += score;
|
||||
}
|
||||
|
||||
~Obj() {
|
||||
if (one)
|
||||
one->total_score -= score;
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr const char* _ML_CONFIG = "config/omni/detect_ml.yaml";
|
||||
static constexpr const char* _OPENCV_CONFIG = "config/omni/detect_opencv.yaml";
|
||||
|
||||
~Impl() {
|
||||
run_flag_ = false;
|
||||
}
|
||||
|
||||
Impl(bool detect_color_init, const Ctx& ctx) {
|
||||
ctx_ = ctx;
|
||||
detect_color_ = detect_color_init;
|
||||
|
||||
config_ = YAML::LoadFile(OMNI_CONFIG);
|
||||
auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create();
|
||||
auto_aim_config_parameter_->loadFromFile(OMNI_CONFIG);
|
||||
auto cameras = config_["cameras"].as<std::vector<std::string>>();
|
||||
|
||||
for (size_t i = 0; i < cameras.size(); ++i) {
|
||||
auto real_path = utils::expandEnv(cameras[i]);
|
||||
One::Ptr one = One::create(i);
|
||||
one->load(YAML::LoadFile(real_path), auto_aim_config_parameter_);
|
||||
ones_.emplace_back(one);
|
||||
}
|
||||
auto_aim_config_parameter_->reloadFromOldPath();
|
||||
fps_ = config_["fps"].as<int>(30);
|
||||
active_time_ = config_["active_time"].as<double>(0.5);
|
||||
max_infer_running_ = config_["max_infer_running"].as<int>(0);
|
||||
min_score_ = config_["min_score"].as<double>();
|
||||
const std::string armor_detect_backend =
|
||||
config_["armor_detect_backend"].as<std::string>("");
|
||||
|
||||
armor_detector_ = DetectorFactory::createArmorDetector(
|
||||
armor_detect_backend,
|
||||
false,
|
||||
_OPENCV_CONFIG,
|
||||
_ML_CONFIG
|
||||
);
|
||||
|
||||
armor_detector_->setCallback(std::bind(
|
||||
&ArmorOmni::Impl::ArmorDetectCallback,
|
||||
this,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2
|
||||
));
|
||||
|
||||
thread_pool_ =
|
||||
std::make_unique<wust_vl::common::concurrency::ThreadPool>(max_infer_running_);
|
||||
|
||||
timer_ = std::make_unique<wust_vl::common::utils::Timer>("omni");
|
||||
latency_averager_ = std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
|
||||
}
|
||||
|
||||
void start() noexcept {
|
||||
run_flag_ = true;
|
||||
|
||||
for (auto& one: ones_) {
|
||||
one->start();
|
||||
}
|
||||
|
||||
if (timer_) {
|
||||
const auto timercallback =
|
||||
std::bind(&ArmorOmni::Impl::timerCallback, this, std::placeholders::_1);
|
||||
|
||||
const double rate_hz = fps_;
|
||||
|
||||
timer_->start(rate_hz, timercallback);
|
||||
}
|
||||
}
|
||||
|
||||
int getOneId() const {
|
||||
static int one_id = 0;
|
||||
|
||||
int id = one_id;
|
||||
|
||||
one_id = (one_id + 1) % ones_.size();
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
void timerCallback(double dt_ms) noexcept {
|
||||
if (!run_flag_ || main_tracking_)
|
||||
return;
|
||||
|
||||
int one_id = getOneId();
|
||||
|
||||
auto& one = ones_[one_id];
|
||||
|
||||
auto frame = one->camera->readImage();
|
||||
|
||||
if (frame.src_img.empty())
|
||||
return;
|
||||
|
||||
CommonFrame common_frame;
|
||||
|
||||
common_frame.img_frame = frame;
|
||||
common_frame.id = one_id;
|
||||
common_frame.detect_color = detect_color_;
|
||||
common_frame.expanded = cv::Rect(0, 0, frame.src_img.cols, frame.src_img.rows);
|
||||
common_frame.offset = cv::Point2f(0, 0);
|
||||
common_frame.any_ctx = one;
|
||||
|
||||
detect(common_frame);
|
||||
}
|
||||
|
||||
void detect(CommonFrame& common_frame) {
|
||||
if (infer_running_count_ >= max_infer_running_ || !thread_pool_ || !run_flag_) {
|
||||
return;
|
||||
}
|
||||
|
||||
infer_running_count_++;
|
||||
|
||||
if (common_frame.img_frame.src_img.empty()) {
|
||||
infer_running_count_--;
|
||||
return;
|
||||
}
|
||||
|
||||
if (armor_detector_) {
|
||||
armor_detector_->pushInput(common_frame, std::nullopt);
|
||||
}
|
||||
|
||||
infer_running_count_--;
|
||||
}
|
||||
|
||||
void setDetectColor(bool flag) noexcept {
|
||||
detect_color_ = flag;
|
||||
}
|
||||
|
||||
void updateMainTracking(bool flag) noexcept {
|
||||
main_tracking_ = flag;
|
||||
}
|
||||
|
||||
int getBestTarget() noexcept {
|
||||
update();
|
||||
return best_target_;
|
||||
}
|
||||
|
||||
void
|
||||
ArmorDetectCallback(const std::vector<ArmorObject>& objs, const CommonFrame& frame) noexcept {
|
||||
auto one = std::any_cast<One::Ptr>(frame.any_ctx);
|
||||
std::vector<ArmorObject> sorted_objs;
|
||||
|
||||
for (const auto& obj: objs) {
|
||||
if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) {
|
||||
continue;
|
||||
}
|
||||
sorted_objs.push_back(obj);
|
||||
std::lock_guard<std::mutex> lock(active_results_mutex_);
|
||||
active_results_.emplace_back(obj, obj.confidence, one, frame.img_frame.timestamp);
|
||||
}
|
||||
update();
|
||||
Armors armors;
|
||||
armors.timestamp = frame.img_frame.timestamp;
|
||||
Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity();
|
||||
auto& car_b = ctx_.car_motion_buffer;
|
||||
auto& big_yaw_b = ctx_.big_yaw_motion_buffer;
|
||||
|
||||
if (car_b && big_yaw_b) {
|
||||
const auto t_query = armors.timestamp;
|
||||
|
||||
auto apply_motion = [&](const auto& att, const auto& att2) {
|
||||
R_gimbal2odom =
|
||||
Eigen::AngleAxisd(
|
||||
angles::normalize_angle(att2.data.big_yaw + one->yaw_in_big_yaw),
|
||||
Eigen::Vector3d::UnitZ()
|
||||
)
|
||||
* Eigen::AngleAxisd(0.0, Eigen::Vector3d::UnitY())
|
||||
* Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX());
|
||||
};
|
||||
|
||||
auto car_past_att = car_b->get_interpolated(t_query);
|
||||
auto big_yaw_past_att = big_yaw_b->get_interpolated(t_query);
|
||||
|
||||
if (car_past_att && big_yaw_past_att) {
|
||||
apply_motion(*car_past_att, *big_yaw_past_att);
|
||||
} else {
|
||||
auto last_att = car_b->get_last();
|
||||
auto last_big_yaw = big_yaw_b->get_last();
|
||||
|
||||
if (last_att && last_big_yaw) {
|
||||
apply_motion(*last_att, *last_big_yaw);
|
||||
}
|
||||
}
|
||||
}
|
||||
Eigen::Matrix3d R_camera2gimbal;
|
||||
R_camera2gimbal << 0.0, 0.0, 1.0, -1.0, -0.0, 0.0, 0.0, -1.0, 0.0;
|
||||
Eigen::Matrix4d T_camera_to_odom = utils::computeCameraToOdomTransform(
|
||||
R_gimbal2odom,
|
||||
R_camera2gimbal,
|
||||
Eigen::Vector3d::Zero()
|
||||
);
|
||||
armors.armors = one->armor_where->where(sorted_objs, T_camera_to_odom);
|
||||
for (auto& armor: armors.armors) {
|
||||
armor.timestamp = armors.timestamp;
|
||||
}
|
||||
auto& target = one->target;
|
||||
target = one->tracker->track(armors);
|
||||
|
||||
const auto now = std::chrono::steady_clock::now();
|
||||
|
||||
const auto latency_ms =
|
||||
wust_vl::common::utils::time_utils::durationMs(frame.img_frame.timestamp, now);
|
||||
latency_averager_->add(latency_ms);
|
||||
latency_ms_ = latency_averager_->average();
|
||||
detect_count_++;
|
||||
|
||||
printStats();
|
||||
}
|
||||
|
||||
void update() noexcept {
|
||||
std::lock_guard<std::mutex> lock(active_results_mutex_);
|
||||
|
||||
while (!active_results_.empty()) {
|
||||
auto& obj = active_results_.front();
|
||||
|
||||
if (std::abs(wust_vl::common::utils::time_utils::durationSec(
|
||||
obj.timestamp,
|
||||
wust_vl::common::utils::time_utils::now()
|
||||
))
|
||||
> active_time_)
|
||||
{
|
||||
active_results_.pop_front();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (active_results_.empty()) {
|
||||
best_target_ = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
double max_score = min_score_;
|
||||
best_target_ = -1;
|
||||
for (size_t i = 0; i < ones_.size(); ++i) {
|
||||
if (ones_[i]->total_score > max_score) {
|
||||
max_score = ones_[i]->total_score;
|
||||
best_target_ = ones_[i]->self_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
GimbalCmd solve(double bullet_speed) {
|
||||
GimbalCmd gimbal_cmd;
|
||||
std::optional<Target> target;
|
||||
int best_target = getBestTarget();
|
||||
if (best_target < 0) {
|
||||
target = std::nullopt;
|
||||
} else {
|
||||
target = ones_[best_target]->target;
|
||||
}
|
||||
auto& very_aimer = ctx_.very_aimer;
|
||||
if (!very_aimer) {
|
||||
return gimbal_cmd;
|
||||
}
|
||||
if (target.has_value() && target->checkTargetAppear()) {
|
||||
try {
|
||||
gimbal_cmd = very_aimer->veryAim(
|
||||
target.value(),
|
||||
bullet_speed,
|
||||
AutoAimFsm::AIM_WHOLE_CAR_CENTER
|
||||
);
|
||||
gimbal_cmd.enable_pitch_diff = 0.0;
|
||||
gimbal_cmd.enable_yaw_diff = 0.0;
|
||||
gimbal_cmd.fire_advice = false;
|
||||
} catch (...) {
|
||||
WUST_ERROR("omni") << "VeryAim error";
|
||||
}
|
||||
} else {
|
||||
gimbal_cmd.appear = false;
|
||||
}
|
||||
return gimbal_cmd;
|
||||
}
|
||||
void printStats() {
|
||||
utils::XSecOnce(
|
||||
[&] {
|
||||
WUST_INFO("armor_omni") << "det: " << detect_count_ << " best: " << best_target_
|
||||
<< " lat: " << latency_ms_;
|
||||
|
||||
detect_count_ = 0;
|
||||
},
|
||||
1.0
|
||||
);
|
||||
}
|
||||
|
||||
int fps_;
|
||||
int max_infer_running_ = 0;
|
||||
std::atomic<int> infer_running_count_ { 0 };
|
||||
|
||||
bool detect_color_;
|
||||
bool main_tracking_ = false;
|
||||
bool run_flag_ = false;
|
||||
|
||||
double active_time_ = 0;
|
||||
|
||||
std::deque<Obj> active_results_;
|
||||
|
||||
mutable std::mutex active_results_mutex_;
|
||||
|
||||
std::vector<One::Ptr> ones_;
|
||||
|
||||
YAML::Node config_;
|
||||
|
||||
std::unique_ptr<wust_vl::common::concurrency::ThreadPool> thread_pool_;
|
||||
|
||||
std::unique_ptr<wust_vl::common::utils::Timer> timer_;
|
||||
|
||||
ArmorDetectorBase::Ptr armor_detector_;
|
||||
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_;
|
||||
|
||||
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
|
||||
int best_target_ = -1;
|
||||
|
||||
int detect_count_ = 0;
|
||||
double latency_ms_;
|
||||
double min_score_ = 10.0;
|
||||
Ctx ctx_;
|
||||
};
|
||||
|
||||
ArmorOmni::ArmorOmni(bool detect_color_init, const Ctx& ctx):
|
||||
_impl(std::make_unique<Impl>(detect_color_init, ctx)) {}
|
||||
|
||||
ArmorOmni::~ArmorOmni() {
|
||||
_impl.reset();
|
||||
}
|
||||
|
||||
void ArmorOmni::start() {
|
||||
_impl->start();
|
||||
}
|
||||
|
||||
void ArmorOmni::setDetectColor(bool flag) {
|
||||
_impl->setDetectColor(flag);
|
||||
}
|
||||
|
||||
void ArmorOmni::updateMainTracking(bool flag) {
|
||||
_impl->updateMainTracking(flag);
|
||||
}
|
||||
|
||||
int ArmorOmni::getBestTarget() {
|
||||
return _impl->getBestTarget();
|
||||
}
|
||||
GimbalCmd ArmorOmni::solve(double bullet_speed) {
|
||||
return _impl->solve(bullet_speed);
|
||||
}
|
||||
|
||||
} // namespace wust_vision::auto_aim
|
||||
36
wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.hpp
Normal file
36
wust_vision-main/tasks/auto_aim/armor_omni/armor_omni.hpp
Normal file
@@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
|
||||
#include "tasks/type_common.hpp"
|
||||
#include <memory>
|
||||
#include <wust_vl/common/utils/motion_buffer.hpp>
|
||||
#include <yaml-cpp/node/node.h>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorOmni {
|
||||
public:
|
||||
struct Ctx {
|
||||
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<CarMotion, 1024>>
|
||||
car_motion_buffer;
|
||||
std::shared_ptr<wust_vl::common::utils::MotionBufferGeneric<BigYaw, 1024>>
|
||||
big_yaw_motion_buffer;
|
||||
VeryAimer::Ptr very_aimer;
|
||||
};
|
||||
static constexpr const char* OMNI_CONFIG = "config/omni/omni.yaml";
|
||||
using Ptr = std::unique_ptr<ArmorOmni>;
|
||||
ArmorOmni(bool detect_color_init, const Ctx& ctx);
|
||||
static Ptr create(bool detect_color_init, const Ctx& ctx) {
|
||||
return std::make_unique<ArmorOmni>(detect_color_init, ctx);
|
||||
}
|
||||
~ArmorOmni();
|
||||
void start();
|
||||
void setDetectColor(bool flag);
|
||||
void updateMainTracking(bool flag);
|
||||
int getBestTarget();
|
||||
GimbalCmd solve(double bullet_speed);
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
@@ -0,0 +1,286 @@
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
// ceres
|
||||
#include <ceres/ceres.h>
|
||||
#include <opencv2/calib3d.hpp>
|
||||
#include <opencv2/core/mat.hpp>
|
||||
// project
|
||||
#include "KalmanHyLib/kalman_hybird_lib.hpp"
|
||||
|
||||
namespace crazy {
|
||||
enum class MotionModel {
|
||||
CONSTANT_VELOCITY = 0, // Constant velocity
|
||||
CONSTANT_ROTATION = 1, // Constant rotation velocity
|
||||
CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity
|
||||
};
|
||||
|
||||
constexpr int X_N = 11, Z_N = 8;
|
||||
using VecZ = Eigen::Matrix<double, Z_N, 1>;
|
||||
using VecX = Eigen::Matrix<double, X_N, 1>;
|
||||
enum class Mean : uint8_t {
|
||||
PLBX = 0,
|
||||
PLBY = 1,
|
||||
PLTX = 2,
|
||||
PLTY = 3,
|
||||
PRTX = 4,
|
||||
PRTY = 5,
|
||||
PRBX = 6,
|
||||
PRBY = 7,
|
||||
// IMUY = 4,
|
||||
// IMUP = 5,
|
||||
// IMUR = 6,
|
||||
Z_N = 8
|
||||
};
|
||||
enum class State : uint8_t {
|
||||
CX = 0,
|
||||
VCX = 1,
|
||||
CY = 2,
|
||||
VCY = 3,
|
||||
CZ = 4,
|
||||
VCZ = 5,
|
||||
YAW = 6,
|
||||
VYAW = 7,
|
||||
R = 8,
|
||||
L = 9,
|
||||
H = 10,
|
||||
outpost01DZ = 9,
|
||||
outpost02DZ = 10,
|
||||
X_N = 11
|
||||
};
|
||||
struct Predict {
|
||||
Predict() = default;
|
||||
explicit Predict(
|
||||
double dt,
|
||||
MotionModel model = MotionModel::CONSTANT_VEL_ROT,
|
||||
double vrx = 0.0,
|
||||
double vry = 0.0,
|
||||
double vrz = 0.0
|
||||
):
|
||||
dt(dt),
|
||||
model(model),
|
||||
vrx(vrx),
|
||||
vry(vry),
|
||||
vrz(vrz) {}
|
||||
|
||||
template<typename T>
|
||||
void operator()(const T x0[X_N], T x1[X_N]) const {
|
||||
for (int i = 0; i < X_N; i++) {
|
||||
x1[i] = x0[i];
|
||||
}
|
||||
|
||||
// v_xyz
|
||||
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) {
|
||||
// linear velocity
|
||||
x1[(int)State::CX] += x0[(int)State::VCX] * T(dt);
|
||||
x1[(int)State::CY] += x0[(int)State::VCY] * T(dt);
|
||||
x1[(int)State::CZ] += x0[(int)State::VCZ] * T(dt);
|
||||
} else {
|
||||
// no velocity
|
||||
x1[(int)State::VCX] *= T(0.);
|
||||
x1[(int)State::VCY] *= T(0.);
|
||||
x1[(int)State::VCZ] *= T(0.);
|
||||
}
|
||||
|
||||
x1[(int)State::CX] -= T(vrx) * T(dt);
|
||||
x1[(int)State::CY] -= T(vry) * T(dt);
|
||||
x1[(int)State::CZ] -= T(vrz) * T(dt);
|
||||
|
||||
x1[(int)State::VCZ] *= T(0.);
|
||||
// v_yaw
|
||||
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) {
|
||||
// angular velocity
|
||||
x1[(int)State::YAW] += x0[(int)State::VYAW] * T(dt);
|
||||
} else {
|
||||
// no rotation
|
||||
x1[(int)State::VYAW] *= T(0.);
|
||||
}
|
||||
}
|
||||
|
||||
double dt;
|
||||
MotionModel model;
|
||||
double vrx, vry, vrz;
|
||||
};
|
||||
constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135
|
||||
constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
|
||||
constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0;
|
||||
constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
|
||||
|
||||
constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180;
|
||||
|
||||
template<typename T>
|
||||
T normalize_angle_t(T angle) {
|
||||
T two_pi = T(2.0 * M_PI);
|
||||
return angle - two_pi * floor((angle + T(M_PI)) / two_pi);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Eigen::Quaternion<T>
|
||||
eulerToQuat(const Eigen::Vector<T, 3>& euler, int axis0, int axis1, int axis2, bool extrinsic) {
|
||||
T rz = euler[0];
|
||||
T ry = euler[1];
|
||||
T rx = euler[2];
|
||||
|
||||
Eigen::Quaternion<T> qx(Eigen::AngleAxis<T>(rx, Eigen::Vector3<T>::UnitX()));
|
||||
Eigen::Quaternion<T> qy(Eigen::AngleAxis<T>(ry, Eigen::Vector3<T>::UnitY()));
|
||||
Eigen::Quaternion<T> qz(Eigen::AngleAxis<T>(rz, Eigen::Vector3<T>::UnitZ()));
|
||||
|
||||
if (!extrinsic)
|
||||
std::swap(axis0, axis2);
|
||||
|
||||
Eigen::Quaternion<T> q;
|
||||
|
||||
if (axis0 == 0 && axis1 == 1 && axis2 == 2)
|
||||
q = qx * qy * qz;
|
||||
else if (axis0 == 0 && axis1 == 2 && axis2 == 1)
|
||||
q = qx * qz * qy;
|
||||
else if (axis0 == 1 && axis1 == 0 && axis2 == 2)
|
||||
q = qy * qx * qz;
|
||||
else if (axis0 == 1 && axis1 == 2 && axis2 == 0)
|
||||
q = qy * qz * qx;
|
||||
else if (axis0 == 2 && axis1 == 0 && axis2 == 1)
|
||||
q = qz * qx * qy;
|
||||
else if (axis0 == 2 && axis1 == 1 && axis2 == 0)
|
||||
q = qz * qy * qx;
|
||||
else
|
||||
throw std::invalid_argument("Unsupported axis order");
|
||||
|
||||
return q;
|
||||
}
|
||||
|
||||
template<typename PointType, typename U>
|
||||
std::vector<PointType> buildObjectPoints(const U& w, const U& h) noexcept {
|
||||
using T = U;
|
||||
T half2 = T(2.0);
|
||||
return { PointType(T(0), w / half2, -h / half2),
|
||||
PointType(T(0), w / half2, h / half2),
|
||||
PointType(T(0), -w / half2, h / half2),
|
||||
PointType(T(0), -w / half2, -h / half2) };
|
||||
}
|
||||
|
||||
struct Measure {
|
||||
struct MeasureCtx {
|
||||
int armor_num = 4;
|
||||
int id = 0;
|
||||
Eigen::Matrix4d T_odom_to_camera_d;
|
||||
cv::Mat camera_intrinsic;
|
||||
cv::Mat camera_distortion;
|
||||
bool is_big;
|
||||
} ctx;
|
||||
|
||||
Measure() = default;
|
||||
explicit Measure(const MeasureCtx& c): ctx(c) {}
|
||||
|
||||
template<typename T>
|
||||
void operator()(const T x[X_N], T z[Z_N]) const {
|
||||
T id_t = T(ctx.id);
|
||||
T num_t = T(ctx.armor_num);
|
||||
T two = T(2.0);
|
||||
T angle = normalize_angle_t(x[(int)State::YAW] + id_t * two * T(M_PI) / num_t);
|
||||
|
||||
bool outpost = (ctx.armor_num == 3);
|
||||
bool use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3);
|
||||
|
||||
T r = use_l_h ? x[(int)State::R] + x[(int)State::L] : x[(int)State::R];
|
||||
|
||||
T armor_x = x[(int)State::CX] - ceres::cos(angle) * r;
|
||||
T armor_y = x[(int)State::CY] - ceres::sin(angle) * r;
|
||||
T armor_z = outpost ? getoutpost_armor_z(x)
|
||||
: use_l_h ? x[(int)State::CZ] + x[(int)State::H]
|
||||
: x[(int)State::CZ];
|
||||
|
||||
Eigen::Vector3<T> euler_odom;
|
||||
euler_odom[0] = angle; //yaw
|
||||
euler_odom[1] = outpost ? T(-FIFTTEN_DEGREE_RAD) : T(FIFTTEN_DEGREE_RAD); //pitch
|
||||
euler_odom[2] = T(M_PI / 2.0); //roll
|
||||
|
||||
Eigen::Quaternion<T> q_odom = eulerToQuat(euler_odom, 2, 1, 0, true);
|
||||
|
||||
Eigen::Matrix4<T> T_odom_to_camera = ctx.T_odom_to_camera_d.cast<T>();
|
||||
|
||||
Eigen::Vector4<T> pos_odom4(armor_x, armor_y, armor_z, T(1.0));
|
||||
Eigen::Vector4<T> pos_camera4 = T_odom_to_camera * pos_odom4;
|
||||
Eigen::Vector3<T> pos_camera = pos_camera4.template head<3>();
|
||||
|
||||
Eigen::Matrix3<T> R_odom_to_camera = T_odom_to_camera.block(0, 0, 3, 3).template cast<T>();
|
||||
Eigen::Matrix3<T> R_ori_odom = q_odom.normalized().toRotationMatrix();
|
||||
Eigen::Matrix3<T> R_camera = R_odom_to_camera * R_ori_odom;
|
||||
Eigen::Quaternion<T> q_camera(R_camera);
|
||||
q_camera.normalize();
|
||||
|
||||
T w3 = ctx.is_big ? T(LARGE_ARMOR_WIDTH) : T(SMALL_ARMOR_WIDTH);
|
||||
T h3 = ctx.is_big ? T(LARGE_ARMOR_HEIGHT) : T(SMALL_ARMOR_HEIGHT);
|
||||
|
||||
auto objPts = buildObjectPoints<Eigen::Matrix<T, 3, 1>>(w3, h3);
|
||||
|
||||
Eigen::Matrix3<T> R = q_camera.toRotationMatrix();
|
||||
Eigen::Matrix<T, 3, 1> t = pos_camera;
|
||||
|
||||
std::vector<Eigen::Matrix<T, 3, 1>> Pc;
|
||||
Pc.reserve(objPts.size());
|
||||
for (const auto& p: objPts) {
|
||||
Eigen::Matrix<T, 3, 1> v = p;
|
||||
Pc.push_back(R * v + t);
|
||||
}
|
||||
|
||||
const cv::Mat& K = ctx.camera_intrinsic;
|
||||
T fx = T(K.at<double>(0, 0));
|
||||
T fy = T(K.at<double>(1, 1));
|
||||
T cx = T(K.at<double>(0, 2));
|
||||
T cy = T(K.at<double>(1, 2));
|
||||
|
||||
std::array<T, 4> u, v;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
T Xc = Pc[i][0];
|
||||
T Yc = Pc[i][1];
|
||||
T Zc = Pc[i][2];
|
||||
|
||||
u[i] = fx * (Xc / Zc) + cx;
|
||||
v[i] = fy * (Yc / Zc) + cy;
|
||||
}
|
||||
|
||||
z[0] = u[0];
|
||||
z[1] = v[0];
|
||||
z[2] = u[1];
|
||||
z[3] = v[1];
|
||||
z[4] = u[2];
|
||||
z[5] = v[2];
|
||||
z[6] = u[3];
|
||||
z[7] = v[3];
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T getoutpost_armor_z(const T x[X_N]) const {
|
||||
if (ctx.id == 0)
|
||||
return x[(int)State::CZ];
|
||||
if (ctx.id == 1)
|
||||
return x[(int)State::CZ] + x[(int)State::outpost01DZ];
|
||||
if (ctx.id == 2)
|
||||
return x[(int)State::CZ] + x[(int)State::outpost02DZ];
|
||||
return x[(int)State::CZ];
|
||||
}
|
||||
|
||||
using VecX = Eigen::Matrix<double, X_N, 1>;
|
||||
using VecZ = Eigen::Matrix<double, Z_N, 1>;
|
||||
|
||||
void h(const VecX& x, VecZ& z) const {
|
||||
operator()(x.data(), z.data());
|
||||
}
|
||||
};
|
||||
|
||||
using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
|
||||
using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
|
||||
|
||||
} // namespace crazy
|
||||
@@ -0,0 +1,226 @@
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#pragma once
|
||||
|
||||
// ceres
|
||||
#include <algorithm>
|
||||
#include <ceres/ceres.h>
|
||||
// project
|
||||
#include "KalmanHyLib/kalman_hybird_lib.hpp"
|
||||
namespace ypdv2armor_motion_model {
|
||||
|
||||
enum class MotionModel {
|
||||
CONSTANT_VELOCITY = 0, // Constant velocity
|
||||
CONSTANT_ROTATION = 1, // Constant rotation velocity
|
||||
CONSTANT_VEL_ROT = 2 // Constant velocity and rotation velocity
|
||||
};
|
||||
|
||||
// X_N: state dimension, Z_N: measurement dimension
|
||||
constexpr int X_N = 11, Z_N = 4;
|
||||
using VecZ = Eigen::Matrix<double, Z_N, 1>;
|
||||
using VecX = Eigen::Matrix<double, X_N, 1>;
|
||||
enum class MeasureID : uint8_t { YPD_Y = 0, YPD_P = 1, YPD_D = 2, ORI_YAW = 3, Z_N = 4 };
|
||||
enum class StateID : uint8_t {
|
||||
CX = 0,
|
||||
VCX = 1,
|
||||
CY = 2,
|
||||
VCY = 3,
|
||||
CZ = 4,
|
||||
VCZ = 5,
|
||||
YAW = 6,
|
||||
VYAW = 7,
|
||||
R = 8,
|
||||
L = 9,
|
||||
H = 10,
|
||||
outpost01DZ = 9,
|
||||
outpost02DZ = 10,
|
||||
X_N = 11
|
||||
};
|
||||
struct State {
|
||||
VecX x;
|
||||
[[nodiscard]] inline double cx() const noexcept {
|
||||
return x((int)StateID::CX);
|
||||
}
|
||||
[[nodiscard]] inline double cy() const noexcept {
|
||||
return x((int)StateID::CY);
|
||||
}
|
||||
[[nodiscard]] inline double cz() const noexcept {
|
||||
return x((int)StateID::CZ);
|
||||
}
|
||||
[[nodiscard]] inline Eigen::Vector3d pos() const noexcept {
|
||||
return Eigen::Vector3d(cx(), cy(), cz());
|
||||
}
|
||||
[[nodiscard]] inline double vcx() const noexcept {
|
||||
return x((int)StateID::VCX);
|
||||
}
|
||||
[[nodiscard]] inline double vcy() const noexcept {
|
||||
return x((int)StateID::VCY);
|
||||
}
|
||||
[[nodiscard]] inline double vcz() const noexcept {
|
||||
return x((int)StateID::VCZ);
|
||||
}
|
||||
[[nodiscard]] inline Eigen::Vector3d vel() const noexcept {
|
||||
return Eigen::Vector3d(vcx(), vcy(), vcz());
|
||||
}
|
||||
[[nodiscard]] inline double vyaw() const noexcept {
|
||||
return x((int)StateID::VYAW);
|
||||
}
|
||||
[[nodiscard]] inline double yaw() const noexcept {
|
||||
return x((int)StateID::YAW);
|
||||
}
|
||||
[[nodiscard]] inline double r() const noexcept {
|
||||
return x((int)StateID::R);
|
||||
}
|
||||
[[nodiscard]] inline double l() const noexcept {
|
||||
return x((int)StateID::L);
|
||||
}
|
||||
[[nodiscard]] inline double h() const noexcept {
|
||||
return x((int)StateID::H);
|
||||
}
|
||||
[[nodiscard]] inline double outpost01DZ() const noexcept {
|
||||
return x((int)StateID::outpost01DZ);
|
||||
}
|
||||
[[nodiscard]] inline double outpost02DZ() const noexcept {
|
||||
return x((int)StateID::outpost02DZ);
|
||||
}
|
||||
};
|
||||
struct Predict {
|
||||
Predict() = default;
|
||||
explicit Predict(double dt, MotionModel model = MotionModel::CONSTANT_VEL_ROT):
|
||||
dt(dt),
|
||||
model(model) {}
|
||||
|
||||
template<typename T>
|
||||
void operator()(const T x0[X_N], T x1[X_N]) const {
|
||||
for (int i = 0; i < X_N; i++) {
|
||||
x1[i] = x0[i];
|
||||
}
|
||||
|
||||
// v_xyz
|
||||
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_VELOCITY) {
|
||||
// linear velocity
|
||||
x1[(int)StateID::CX] += x0[(int)StateID::VCX] * T(dt);
|
||||
x1[(int)StateID::CY] += x0[(int)StateID::VCY] * T(dt);
|
||||
x1[(int)StateID::CZ] += x0[(int)StateID::VCZ] * T(dt);
|
||||
} else {
|
||||
// no velocity
|
||||
x1[(int)StateID::VCX] *= T(0.);
|
||||
x1[(int)StateID::VCY] *= T(0.);
|
||||
x1[(int)StateID::VCZ] *= T(0.);
|
||||
}
|
||||
|
||||
// v_yaw
|
||||
if (model == MotionModel::CONSTANT_VEL_ROT || model == MotionModel::CONSTANT_ROTATION) {
|
||||
// angular velocity
|
||||
x1[(int)StateID::YAW] += x0[(int)StateID::VYAW] * T(dt);
|
||||
} else {
|
||||
// no rotation
|
||||
x1[(int)StateID::VYAW] *= T(0.);
|
||||
}
|
||||
clampState(x1);
|
||||
}
|
||||
template<typename T>
|
||||
void clampState(T x1[X_N]) const {
|
||||
auto& r = x1[(int)StateID::R];
|
||||
auto& l = x1[(int)StateID::L];
|
||||
r = std::clamp(r, T(0.1), T(0.5));
|
||||
if (r < T(0.1) || r > T(0.5)) {
|
||||
r = T(0.25);
|
||||
l = T(0);
|
||||
}
|
||||
T sum = r + l;
|
||||
if (sum < T(0.1) || sum > T(0.5)) {
|
||||
r = T(0.25);
|
||||
l = T(0);
|
||||
}
|
||||
auto& h = x1[(int)StateID::H];
|
||||
h = std::clamp(h, T(-0.5), T(0.5));
|
||||
}
|
||||
void f(const VecX& x0, VecX& x1) const {
|
||||
assert(x0.size() == X_N);
|
||||
assert(x1.size() == X_N);
|
||||
operator()(x0.data(), x1.data());
|
||||
}
|
||||
|
||||
double dt;
|
||||
MotionModel model;
|
||||
};
|
||||
template<typename T>
|
||||
T normalize_angle_t(T angle) {
|
||||
T two_pi = T(2.0 * M_PI);
|
||||
return angle - two_pi * floor((angle + T(M_PI)) / two_pi);
|
||||
}
|
||||
|
||||
struct Measure {
|
||||
struct MeasureCtx {
|
||||
MeasureCtx() = default;
|
||||
MeasureCtx(int id, int armor_num): armor_num(armor_num), id(id) {}
|
||||
int armor_num = 4;
|
||||
int id = 0;
|
||||
} ctx;
|
||||
Measure() = default;
|
||||
explicit Measure(MeasureCtx c): ctx(c) {}
|
||||
template<typename T>
|
||||
void operator()(const T x[X_N], T z[Z_N]) const {
|
||||
// Compute armor position
|
||||
auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x);
|
||||
T xy_dist = ceres::sqrt(armor_x * armor_x + armor_y * armor_y);
|
||||
T dist = ceres::sqrt(xy_dist * xy_dist + armor_z * armor_z);
|
||||
// Observation model
|
||||
z[(int)MeasureID::YPD_Y] = ceres::atan2(armor_y, armor_x); // yaw
|
||||
z[(int)MeasureID::YPD_P] = ceres::atan2(armor_z, xy_dist); // pitch
|
||||
z[(int)MeasureID::YPD_D] = dist; // distance
|
||||
z[(int)MeasureID::ORI_YAW] = angle; // orientation_yaw
|
||||
}
|
||||
template<typename T>
|
||||
T get_angle(const T x[X_N]) const {
|
||||
return normalize_angle_t(x[(int)StateID::YAW] + ctx.id * 2 * M_PI / ctx.armor_num);
|
||||
}
|
||||
template<typename T>
|
||||
std::tuple<T, T, T, T> h_armor_xyza(const T x[X_N]) const {
|
||||
T angle = get_angle(x);
|
||||
auto outpost = ctx.armor_num == 3;
|
||||
auto use_l_h = (ctx.armor_num == 4) && (ctx.id == 1 || ctx.id == 3);
|
||||
T r = (use_l_h) ? x[(int)StateID::R] + x[(int)StateID::L] : x[(int)StateID::R];
|
||||
T armor_x = x[(int)StateID::CX] - ceres::cos(angle) * r;
|
||||
T armor_y = x[(int)StateID::CY] - ceres::sin(angle) * r;
|
||||
T armor_z = (outpost) ? getoutpost_armor_z(x)
|
||||
: (use_l_h) ? x[(int)StateID::CZ] + x[(int)StateID::H]
|
||||
: x[(int)StateID::CZ];
|
||||
return { armor_x, armor_y, armor_z, angle };
|
||||
}
|
||||
Eigen::Vector4d h_armor_xyza(const VecX& x) const {
|
||||
assert(x.size() == X_N);
|
||||
auto [armor_x, armor_y, armor_z, angle] = h_armor_xyza(x.data());
|
||||
|
||||
return { armor_x, armor_y, armor_z, angle };
|
||||
}
|
||||
template<typename T>
|
||||
T getoutpost_armor_z(const T x[X_N]) const {
|
||||
return (ctx.id == 0) ? x[(int)StateID::CZ]
|
||||
: (ctx.id == 1) ? x[(int)StateID::CZ] + x[(int)StateID::outpost01DZ]
|
||||
: (ctx.id == 2) ? x[(int)StateID::CZ] + x[(int)StateID::outpost02DZ]
|
||||
: x[(int)StateID::CZ];
|
||||
}
|
||||
void h(const VecX& x, VecZ& z) const {
|
||||
assert(x.size() == X_N);
|
||||
assert(z.size() == Z_N);
|
||||
operator()(x.data(), z.data());
|
||||
}
|
||||
};
|
||||
|
||||
using RobotStateEKF = kalman_hybird_lib::ExtendedKalmanFilter<X_N, Z_N, Predict, Measure>;
|
||||
using RobotStateESEKF = kalman_hybird_lib::ErrorStateEKF<X_N, Z_N, Predict, Measure>;
|
||||
|
||||
} // namespace ypdv2armor_motion_model
|
||||
379
wust_vision-main/tasks/auto_aim/armor_tracker/target.cpp
Normal file
379
wust_vision-main/tasks/auto_aim/armor_tracker/target.cpp
Normal file
@@ -0,0 +1,379 @@
|
||||
#include "target.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
Target::Target() {
|
||||
target_state_.x = Eigen::VectorXd::Zero(MModel::X_N);
|
||||
}
|
||||
Target::Target(const Armor& a, TargetConfig::Ptr target_config) {
|
||||
Eigen::DiagonalMatrix<double, ypdv2armor_motion_model::X_N> p0;
|
||||
if (a.number == ArmorNumber::OUTPOST) {
|
||||
p0.diagonal() << 1, 64, 1, 64, 1, 81, 0.4, 100, 1e-4, 0.1, 0.1;
|
||||
armor_num_ = 3;
|
||||
radius_pre_ = 0.2765;
|
||||
} else if (a.number == ArmorNumber::BASE) {
|
||||
p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1e-4, 0, 0;
|
||||
armor_num_ = 3;
|
||||
radius_pre_ = 0.3205;
|
||||
} else {
|
||||
p0.diagonal() << 1, 64, 1, 64, 1, 64, 0.4, 100, 1, 1, 1;
|
||||
armor_num_ = 4;
|
||||
radius_pre_ = 0.2;
|
||||
}
|
||||
target_config_ = target_config;
|
||||
const auto yfv2 = MModel::Predict(0.005);
|
||||
ctx_.armor_num = armor_num_;
|
||||
ctx_.id = 0;
|
||||
const auto yhv2 = MModel::Measure(ctx_);
|
||||
const auto yu_qv2 = [this]() {
|
||||
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
|
||||
return q;
|
||||
};
|
||||
|
||||
const auto yu_rv2 = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
|
||||
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
|
||||
return r;
|
||||
};
|
||||
esekf_ypd_ = MModel::RobotStateESEKF(yfv2, yhv2, yu_qv2, yu_rv2, p0);
|
||||
|
||||
esekf_ypd_.setResidualFunc([this](
|
||||
const Eigen::Matrix<double, MModel::Z_N, 1>& z_pred,
|
||||
const Eigen::Matrix<double, MModel::Z_N, 1>& z
|
||||
) {
|
||||
Eigen::Matrix<double, MModel::Z_N, 1> r = z - z_pred;
|
||||
r[0] = angles::shortest_angular_distance(
|
||||
z_pred[(int)MModel::MeasureID::YPD_Y],
|
||||
z[(int)MModel::MeasureID::YPD_Y]
|
||||
); // yaw
|
||||
r[3] = angles::shortest_angular_distance(
|
||||
z_pred[(int)MModel::MeasureID::ORI_YAW],
|
||||
z[(int)MModel::MeasureID::ORI_YAW]
|
||||
); // ori_yaw
|
||||
return r;
|
||||
});
|
||||
esekf_ypd_.setIterationNum(target_config_->esekf_iter_num_param.get());
|
||||
esekf_ypd_.setInjectFunc([this](
|
||||
const Eigen::Matrix<double, MModel::X_N, 1>& delta,
|
||||
Eigen::Matrix<double, MModel::X_N, 1>& nominal
|
||||
) {
|
||||
for (int i = 0; i < MModel::X_N; i++) {
|
||||
if (i == (int)MModel::StateID::YAW)
|
||||
continue;
|
||||
nominal[i] += delta[i];
|
||||
}
|
||||
nominal[(int)MModel::StateID::YAW] = angles::normalize_angle(
|
||||
nominal[(int)MModel::StateID::YAW] + delta[(int)MModel::StateID::YAW]
|
||||
);
|
||||
});
|
||||
|
||||
const double xa = a.target_pos.x();
|
||||
const double ya = a.target_pos.y();
|
||||
const double za = a.target_pos.z();
|
||||
const double yaw = utils::orientationToYaw<Target>(a.target_ori);
|
||||
|
||||
target_state_.x = Eigen::VectorXd::Zero(MModel::X_N);
|
||||
const double r = radius_pre_;
|
||||
const double xc = xa + r * cos(yaw);
|
||||
const double yc = ya + r * sin(yaw);
|
||||
const double zc = za;
|
||||
target_state_.x << xc, 0, yc, 0, zc, 0, yaw, 0, r, 0, 0;
|
||||
esekf_ypd_.setState(target_state_.x);
|
||||
tracked_id_ = a.number;
|
||||
type_ = a.type;
|
||||
last_t_ = a.timestamp;
|
||||
timestamp_ = a.timestamp;
|
||||
is_inited = true;
|
||||
}
|
||||
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
|
||||
Target::computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z
|
||||
) const noexcept {
|
||||
Eigen::Matrix<double, MModel::Z_N, MModel::Z_N> r;
|
||||
const double delta_angle = angles::normalize_angle(z[3] - z[0]);
|
||||
const double abs_delta = std::abs(delta_angle);
|
||||
|
||||
// sin插值函数,小值慢、大值快
|
||||
const auto sinInterp = [](double x, double x0, double x1, double y0, double y1) -> double {
|
||||
double t = (x - x0) / (x1 - x0);
|
||||
if (t < 0)
|
||||
t = 0;
|
||||
if (t > 1)
|
||||
t = 1;
|
||||
double s = std::sin(t * M_PI / 2.0);
|
||||
return y0 + s * (y1 - y0);
|
||||
};
|
||||
// clang-format off
|
||||
r <<target_config_->yp_r_param.get(), 0, 0, 0,
|
||||
0, target_config_->yp_r_param.get() , 0, 0,
|
||||
0, 0, sinInterp(abs_delta, 0.0, M_PI/2.0, target_config_->dis_r_front_param.get(), target_config_->dis_r_side_param.get())+z[2]*z[2]*target_config_->dis2_r_ratio_param.get(), 0,
|
||||
0, 0, 0,log(std::abs(z[2]) + 1) *target_config_->yaw_r_log_ratio_param.get() + sinInterp(M_PI/2.0-abs_delta, 0.0, M_PI/2.0, target_config_->yaw_r_base_side_param.get(), target_config_->yaw_r_base_front_param.get());
|
||||
// clang-format on
|
||||
return r;
|
||||
}
|
||||
Eigen::Matrix<double, MModel::X_N, MModel::X_N> Target::computeProcessNoise(double dt
|
||||
) const noexcept {
|
||||
Eigen::Matrix<double, MModel::X_N, MModel::X_N> q;
|
||||
Eigen::Vector3d q_xyz;
|
||||
double q_yaw;
|
||||
double q_l, q_h;
|
||||
if (tracked_id_ == ArmorNumber::OUTPOST) {
|
||||
q_xyz = target_config_->qxyz_output; // 前哨站加速度方差
|
||||
q_yaw = target_config_->qyaw_output_param.get(); // 前哨站角加速度方差
|
||||
q_l = target_config_->q_outpost_dz_param.get();
|
||||
q_h = target_config_->q_outpost_dz_param.get();
|
||||
} else {
|
||||
q_xyz = target_config_->qxyz_common; // 加速度方差
|
||||
q_yaw = target_config_->qyaw_common_param.get(); // 角加速度方差
|
||||
q_l = target_config_->q_l_param.get();
|
||||
q_h = target_config_->q_h_param.get();
|
||||
}
|
||||
const double t = dt;
|
||||
const double q_x_x = pow(t, 4) / 4 * q_xyz.x(), q_x_vx = pow(t, 3) / 2 * q_xyz.x(),
|
||||
q_vx_vx = pow(t, 2) * q_xyz.x();
|
||||
const double q_y_y = pow(t, 4) / 4 * q_xyz.y(), q_y_vy = pow(t, 3) / 2 * q_xyz.y(),
|
||||
q_vy_vy = pow(t, 2) * q_xyz.y();
|
||||
const double q_z_z = pow(t, 4) / 4 * q_xyz.z(), q_z_vz = pow(t, 3) / 2 * q_xyz.z(),
|
||||
q_vz_vz = pow(t, 2) * q_xyz.z();
|
||||
const double q_yaw_yaw = pow(t, 4) / 4 * q_yaw, q_yaw_vyaw = pow(t, 3) / 2 * q_yaw,
|
||||
q_vyaw_vyaw = pow(t, 2) * q_yaw;
|
||||
const double q_r = target_config_->q_r_param.get();
|
||||
|
||||
// clang-format off
|
||||
// xc v_xc yc v_yc zc v_zc yaw v_yaw r l h
|
||||
q << q_x_x, q_x_vx, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
q_x_vx, q_vx_vx,0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, q_y_y, q_y_vy, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, q_y_vy, q_vy_vy,0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, q_z_z, q_z_vz, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, q_z_vz, q_vz_vz,0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, q_yaw_yaw, q_yaw_vyaw, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, q_yaw_vyaw, q_vyaw_vyaw,0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, q_r, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, q_l, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, q_h;
|
||||
// clang-format on
|
||||
return q;
|
||||
}
|
||||
MModel::Predict Target::getPredictFunc(double dt) const noexcept {
|
||||
MModel::Predict predict_func;
|
||||
if (tracked_id_ == ArmorNumber::OUTPOST) {
|
||||
predict_func = MModel::Predict {
|
||||
dt,
|
||||
MModel::MotionModel::CONSTANT_ROTATION,
|
||||
};
|
||||
} else {
|
||||
predict_func = MModel::Predict {
|
||||
dt,
|
||||
MModel::MotionModel::CONSTANT_VEL_ROT,
|
||||
};
|
||||
}
|
||||
return predict_func;
|
||||
}
|
||||
void Target::predict(std::chrono::steady_clock::time_point t) noexcept {
|
||||
const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
|
||||
|
||||
predict(dt);
|
||||
|
||||
last_t_ = t;
|
||||
}
|
||||
void Target::predict(double dt) noexcept {
|
||||
MModel::Predict predict_func = getPredictFunc(dt);
|
||||
|
||||
esekf_ypd_.setPredictFunc(predict_func);
|
||||
const auto yu_qv2 = [dt, this]() { return computeProcessNoise(dt); };
|
||||
|
||||
esekf_ypd_.setUpdateQ(yu_qv2);
|
||||
|
||||
target_state_.x = esekf_ypd_.predict();
|
||||
if (target_state_.pos().norm() < 0.5) {
|
||||
is_tracking = false;
|
||||
}
|
||||
}
|
||||
void Target::predictSimple(std::chrono::steady_clock::time_point t) noexcept {
|
||||
const double dt = wust_vl::common::utils::time_utils::durationSec(last_t_, t);
|
||||
|
||||
predictSimple(dt);
|
||||
|
||||
last_t_ = t;
|
||||
}
|
||||
void Target::predictSimple(double dt) noexcept {
|
||||
MModel::Predict predict_func = getPredictFunc(dt);
|
||||
predict_func.f(target_state_.x, target_state_.x);
|
||||
if (target_state_.pos().norm() < 0.5) {
|
||||
is_tracking = false;
|
||||
}
|
||||
}
|
||||
bool Target::update(const std::pair<int, Armor>& a) noexcept {
|
||||
const auto armor = a.second;
|
||||
const auto id = a.first;
|
||||
const auto yu_rv2 = [this](const Eigen::Matrix<double, MModel::Z_N, 1>& z) {
|
||||
return this->computeMeasurementCovariance(z);
|
||||
};
|
||||
esekf_ypd_.setUpdateR(yu_rv2);
|
||||
measurement_ = getMeasure(armor);
|
||||
|
||||
if (id != 0)
|
||||
jumped = true;
|
||||
|
||||
ctx_.id = id;
|
||||
esekf_ypd_.setMeasureFunc(MModel::Measure { ctx_ });
|
||||
|
||||
target_state_.x = esekf_ypd_.update(measurement_);
|
||||
timestamp_ = armor.timestamp;
|
||||
last_t_ = timestamp_;
|
||||
return true;
|
||||
}
|
||||
cv::Rect Target::expanded(
|
||||
Eigen::Matrix4d T_camera_to_odom,
|
||||
const cv::Mat& camera_intrinsic,
|
||||
const cv::Mat& camera_distortion,
|
||||
const cv::Size& image_size
|
||||
) const noexcept {
|
||||
const double dt = wust_vl::common::utils::time_utils::durationSec(
|
||||
timestamp_,
|
||||
wust_vl::common::utils::time_utils::now()
|
||||
);
|
||||
if (!is_inited || dt > target_config_->lost_time_thres_param.get()) {
|
||||
return cv::Rect(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
const float car_box_half =
|
||||
std::max(target_state_.r(), target_state_.r() + target_state_.l()) + 0.15;
|
||||
|
||||
static std::vector<cv::Point3f> CAR_BOX;
|
||||
CAR_BOX = { { 0, car_box_half, -car_box_half },
|
||||
{ 0, -car_box_half, -car_box_half },
|
||||
{ 0, -car_box_half, car_box_half },
|
||||
{ 0, car_box_half, car_box_half } };
|
||||
|
||||
const Eigen::Matrix4d T_odom_to_camera = T_camera_to_odom.inverse();
|
||||
const Eigen::Vector4d
|
||||
pos_odom(target_state_.cx(), target_state_.cy(), target_state_.cz(), 1.0);
|
||||
const Eigen::Vector4d pos_cam = T_odom_to_camera * pos_odom;
|
||||
|
||||
if (pos_cam.z() <= 0.2) {
|
||||
return cv::Rect(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << pos_cam.x(), pos_cam.y(), pos_cam.z());
|
||||
|
||||
Eigen::Vector3d euler;
|
||||
euler.z() = M_PI / 2.0;
|
||||
euler.y() = 0;
|
||||
euler.x() = std::atan2(pos_odom.y(), pos_odom.x());
|
||||
|
||||
const Eigen::Quaterniond ori = utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
|
||||
const auto target_ori = utils::transformOrientation(ori, T_odom_to_camera);
|
||||
const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix();
|
||||
|
||||
const cv::Mat rot_mat =
|
||||
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
|
||||
tf_rot(0, 1),
|
||||
tf_rot(0, 2),
|
||||
tf_rot(1, 0),
|
||||
tf_rot(1, 1),
|
||||
tf_rot(1, 2),
|
||||
tf_rot(2, 0),
|
||||
tf_rot(2, 1),
|
||||
tf_rot(2, 2));
|
||||
|
||||
cv::Mat rvec;
|
||||
cv::Rodrigues(rot_mat, rvec);
|
||||
|
||||
std::vector<cv::Point2f> pts_2d;
|
||||
cv::projectPoints(CAR_BOX, rvec, tvec, camera_intrinsic, camera_distortion, pts_2d);
|
||||
|
||||
const cv::Rect rect = cv::boundingRect(pts_2d);
|
||||
|
||||
const cv::Rect img_rect(0, 0, image_size.width, image_size.height);
|
||||
if ((rect & img_rect).area() <= 0) {
|
||||
return cv::Rect(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
const int base_side = std::max(rect.width, rect.height);
|
||||
const int max_side = std::max(image_size.width, image_size.height);
|
||||
|
||||
const double lost_dt = target_config_->lost_time_thres_param.get();
|
||||
const double dt_clamped = std::max(0.0, std::min(dt, lost_dt));
|
||||
|
||||
int side = static_cast<int>(base_side + (max_side - base_side) * (dt_clamped / lost_dt));
|
||||
|
||||
if (dt >= lost_dt) {
|
||||
side = max_side;
|
||||
}
|
||||
|
||||
const int cx = rect.x + rect.width / 2;
|
||||
const int cy = rect.y + rect.height / 2;
|
||||
cv::Rect square(cx - side / 2, cy - side / 2, side, side);
|
||||
|
||||
square &= img_rect;
|
||||
|
||||
return square;
|
||||
}
|
||||
|
||||
std::vector<std::pair<int, Armor>> Target::match(const std::vector<Armor>& armors) noexcept {
|
||||
std::vector<std::pair<int, Armor>> result;
|
||||
const int n_obs = static_cast<int>(armors.size());
|
||||
const int armors_num = armor_num_;
|
||||
const double GATE = target_config_->match_gate_param.get();
|
||||
const double max_cost = 1e9;
|
||||
std::vector<std::vector<double>> cost(n_obs, std::vector<double>(armors_num, max_cost + 1));
|
||||
std::vector<MModel::VecZ> meas_list(n_obs);
|
||||
for (int j = 0; j < n_obs; ++j) {
|
||||
meas_list[j] = getMeasure(armors[j]);
|
||||
}
|
||||
for (int j = 0; j < n_obs; ++j) {
|
||||
for (int id = 0; id < armors_num; ++id) {
|
||||
MModel::Measure::MeasureCtx tmp_ctx(id, armors_num);
|
||||
MModel::Measure measure(tmp_ctx);
|
||||
MModel::VecZ z_pred;
|
||||
measure.h(target_state_.x, z_pred);
|
||||
|
||||
MModel::VecZ nu = meas_list[j] - z_pred;
|
||||
nu[(int)MModel::MeasureID::YPD_Y] =
|
||||
angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_Y]);
|
||||
nu[(int)MModel::MeasureID::YPD_P] =
|
||||
angles::normalize_angle(nu[(int)MModel::MeasureID::YPD_P]);
|
||||
nu[(int)MModel::MeasureID::ORI_YAW] =
|
||||
angles::normalize_angle(nu[(int)MModel::MeasureID::ORI_YAW]);
|
||||
auto R = computeMeasurementCovariance(z_pred);
|
||||
double d2 = nu.transpose() * R.ldlt().solve(nu);
|
||||
|
||||
// 门控
|
||||
if (std::isfinite(d2) && d2 < GATE) {
|
||||
cost[j][id] = d2;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<bool> used_obs(n_obs, false);
|
||||
std::vector<bool> used_id(armors_num, false);
|
||||
|
||||
while (true) {
|
||||
double best = max_cost;
|
||||
int best_j = -1;
|
||||
int best_id = -1;
|
||||
|
||||
for (int j = 0; j < n_obs; ++j) {
|
||||
if (used_obs[j])
|
||||
continue;
|
||||
for (int id = 0; id < armors_num; ++id) {
|
||||
if (used_id[id])
|
||||
continue;
|
||||
if (cost[j][id] < best) {
|
||||
best = cost[j][id];
|
||||
best_j = j;
|
||||
best_id = id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (best_j < 0 || best_id < 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
used_obs[best_j] = true;
|
||||
used_id[best_id] = true;
|
||||
result.push_back(std::make_pair(best_id, armors[best_j]));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
185
wust_vision-main/tasks/auto_aim/armor_tracker/target.hpp
Normal file
185
wust_vision-main/tasks/auto_aim/armor_tracker/target.hpp
Normal file
@@ -0,0 +1,185 @@
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/armor_tracker/motion_models/motion_modelypdv2.hpp"
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
#include "wust_vl/common/utils/parameter.hpp"
|
||||
#include <wust_vl/common/utils/timer.hpp>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
namespace MModel = ypdv2armor_motion_model;
|
||||
|
||||
struct TargetConfig: wust_vl::common::utils::ParamGroup {
|
||||
static constexpr const char* kKey = "armor_tracker";
|
||||
const char* key() const override {
|
||||
return kKey;
|
||||
}
|
||||
GEN_PARAM(int, esekf_iter_num);
|
||||
GEN_PARAM(double, lost_time_thres);
|
||||
GEN_PARAM(int, tracking_thres);
|
||||
GEN_PARAM(double, max_yaw_diff_deg);
|
||||
GEN_PARAM(double, max_dis_diff);
|
||||
GEN_PARAM(double, match_gate);
|
||||
GEN_PARAM(double, qyaw_common);
|
||||
GEN_PARAM(double, qyaw_output);
|
||||
GEN_PARAM(double, q_r);
|
||||
GEN_PARAM(double, q_l);
|
||||
GEN_PARAM(double, q_h);
|
||||
GEN_PARAM(double, q_outpost_dz);
|
||||
GEN_PARAM(double, yp_r);
|
||||
GEN_PARAM(double, dis_r_front);
|
||||
GEN_PARAM(double, dis_r_side);
|
||||
GEN_PARAM(double, dis2_r_ratio);
|
||||
GEN_PARAM(double, yaw_r_base_front);
|
||||
GEN_PARAM(double, yaw_r_base_side);
|
||||
GEN_PARAM(double, yaw_r_log_ratio);
|
||||
GEN_PARAM(std::vector<double>, qxyz_common);
|
||||
GEN_PARAM(std::vector<double>, qxyz_output);
|
||||
Eigen::Vector3d qxyz_common = { 100, 100, 100 };
|
||||
Eigen::Vector3d qxyz_output = { 10, 10, 10 };
|
||||
using Ptr = std::shared_ptr<TargetConfig>;
|
||||
TargetConfig() {
|
||||
qxyz_output_param.onChange([this](auto o, auto n) {
|
||||
qxyz_common = Eigen::Vector3d(n[0], n[1], n[2]);
|
||||
});
|
||||
qxyz_output_param.onChange([this](auto o, auto n) {
|
||||
qxyz_output = Eigen::Vector3d(n[0], n[1], n[2]);
|
||||
});
|
||||
}
|
||||
static Ptr create() {
|
||||
return std::make_shared<TargetConfig>();
|
||||
}
|
||||
void loadSelf(const YAML::Node& node) override {
|
||||
esekf_iter_num_param.load(node);
|
||||
lost_time_thres_param.load(node);
|
||||
tracking_thres_param.load(node);
|
||||
max_yaw_diff_deg_param.load(node);
|
||||
max_dis_diff_param.load(node);
|
||||
match_gate_param.load(node);
|
||||
qyaw_common_param.load(node);
|
||||
qyaw_output_param.load(node);
|
||||
qxyz_common_param.load(node);
|
||||
qxyz_output_param.load(node);
|
||||
q_r_param.load(node);
|
||||
q_l_param.load(node);
|
||||
q_h_param.load(node);
|
||||
q_outpost_dz_param.load(node);
|
||||
yp_r_param.load(node);
|
||||
dis_r_front_param.load(node);
|
||||
dis_r_side_param.load(node);
|
||||
yaw_r_base_front_param.load(node);
|
||||
yaw_r_base_side_param.load(node);
|
||||
yaw_r_log_ratio_param.load(node);
|
||||
}
|
||||
};
|
||||
class Target {
|
||||
public:
|
||||
Target();
|
||||
Target(const Armor& armor, TargetConfig::Ptr target_config);
|
||||
MModel::Measure::MeasureCtx ctx_;
|
||||
ArmorNumber tracked_id_;
|
||||
std::string type_;
|
||||
MModel::VecZ measurement_ = Eigen::Matrix<double, MModel::Z_N, 1>::Zero();
|
||||
MModel::State target_state_ = MModel::State();
|
||||
double radius_pre_;
|
||||
|
||||
int armor_num_ = 4;
|
||||
bool jumped = false;
|
||||
bool is_inited = false;
|
||||
bool is_tracking = false;
|
||||
std::chrono::steady_clock::time_point last_t_;
|
||||
std::chrono::steady_clock::time_point timestamp_;
|
||||
MModel::RobotStateESEKF esekf_ypd_;
|
||||
TargetConfig::Ptr target_config_;
|
||||
[[nodiscard]] cv::Rect expanded(
|
||||
Eigen::Matrix4d T_camera_to_odom,
|
||||
const cv::Mat& camera_intrinsic,
|
||||
const cv::Mat& camera_distortion,
|
||||
const cv::Size& image_size
|
||||
) const noexcept;
|
||||
void predict(std::chrono::steady_clock::time_point t) noexcept;
|
||||
void predict(double dt) noexcept;
|
||||
void predictSimple(std::chrono::steady_clock::time_point t) noexcept;
|
||||
void predictSimple(double dt) noexcept;
|
||||
[[nodiscard]] MModel::Predict getPredictFunc(double dt) const noexcept;
|
||||
bool update(const std::pair<int, Armor>& armor) noexcept;
|
||||
[[nodiscard]] Eigen::Matrix<double, MModel::Z_N, MModel::Z_N>
|
||||
computeMeasurementCovariance(const Eigen::Matrix<double, MModel::Z_N, 1>& z) const noexcept;
|
||||
[[nodiscard]] Eigen::Matrix<double, MModel::X_N, MModel::X_N> computeProcessNoise(double dt
|
||||
) const noexcept;
|
||||
[[nodiscard]] std::optional<ArmorNumber> getArmorNumber() const noexcept {
|
||||
if (!checkTargetAppear()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return tracked_id_;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<double> getArmorYaws() const noexcept {
|
||||
std::vector<double> yaw_list;
|
||||
yaw_list.reserve(armor_num_);
|
||||
for (int i = 0; i < armor_num_; i++) {
|
||||
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
|
||||
MModel::Measure measure(_ctx);
|
||||
yaw_list.push_back(measure.get_angle(target_state_.x.data()));
|
||||
}
|
||||
return yaw_list;
|
||||
}
|
||||
[[nodiscard]] std::vector<Eigen::Vector3d> getArmorPositions() const noexcept {
|
||||
std::vector<Eigen::Vector3d> armor_positions;
|
||||
armor_positions.reserve(armor_num_);
|
||||
for (int i = 0; i < armor_num_; i++) {
|
||||
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
|
||||
MModel::Measure measure(_ctx);
|
||||
const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x);
|
||||
armor_positions.push_back(xyza.head<3>());
|
||||
}
|
||||
return armor_positions;
|
||||
}
|
||||
[[nodiscard]] std::vector<Eigen::Vector4d> getArmorPosAndYaw() const noexcept {
|
||||
std::vector<Eigen::Vector4d> pos_yaw;
|
||||
pos_yaw.reserve(armor_num_);
|
||||
for (int i = 0; i < armor_num_; ++i) {
|
||||
MModel::Measure::MeasureCtx _ctx(i, armor_num_);
|
||||
MModel::Measure measure(_ctx);
|
||||
const Eigen::Vector4d xyza = measure.h_armor_xyza(target_state_.x);
|
||||
pos_yaw.push_back(xyza);
|
||||
}
|
||||
return pos_yaw;
|
||||
}
|
||||
[[nodiscard]] double getMeanZ() const noexcept {
|
||||
double mean = 0;
|
||||
for (const auto& p: getArmorPositions()) {
|
||||
mean += p.z();
|
||||
}
|
||||
return mean / armor_num_;
|
||||
}
|
||||
[[nodiscard]] double getArmor2CenterXYDis(int id) const noexcept {
|
||||
const auto use_l_h = (armor_num_ == 4) && (id == 1 || id == 3);
|
||||
const auto r = (use_l_h) ? target_state_.r() + target_state_.l() : target_state_.r();
|
||||
return r;
|
||||
}
|
||||
[[nodiscard]] std::vector<std::pair<int, Armor>> match(const std::vector<Armor>& armors
|
||||
) noexcept;
|
||||
|
||||
[[nodiscard]] inline bool checkTargetAppear() const noexcept {
|
||||
const bool appear = is_tracking
|
||||
&& wust_vl::common::utils::time_utils::durationSec(
|
||||
timestamp_,
|
||||
wust_vl::common::utils::time_utils::now()
|
||||
) < target_config_->lost_time_thres_param.get();
|
||||
return appear;
|
||||
}
|
||||
|
||||
[[nodiscard]] Eigen::Matrix<double, MModel::Z_N, 1> getMeasure(const Armor& a) noexcept {
|
||||
const auto p = a.target_pos;
|
||||
const double measured_yaw = utils::orientationToYaw<Target>(a.target_ori);
|
||||
double ypd_y = std::atan2(p.y(), p.x());
|
||||
static double last_ypd_y = 0;
|
||||
ypd_y = last_ypd_y + angles::shortest_angular_distance(last_ypd_y, ypd_y);
|
||||
last_ypd_y = ypd_y;
|
||||
const double ypd_p = std::atan2(p.z(), std::sqrt(p.x() * p.x() + p.y() * p.y()));
|
||||
const double ypd_d = std::sqrt(p.x() * p.x() + p.y() * p.y() + p.z() * p.z());
|
||||
|
||||
return Eigen::Vector4d(ypd_y, ypd_p, ypd_d, measured_yaw);
|
||||
}
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
209
wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.cpp
Normal file
209
wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include "trackerv3.hpp"
|
||||
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
struct Tracker::Impl {
|
||||
public:
|
||||
enum State {
|
||||
LOST,
|
||||
DETECTING,
|
||||
TRACKING,
|
||||
TEMP_LOST,
|
||||
} tracker_state = LOST;
|
||||
Impl(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
|
||||
tracker_state = LOST;
|
||||
target_config_ = TargetConfig::create();
|
||||
auto_aim_config_parameter->registerGroup(*target_config_);
|
||||
auto_aim_config_parameter->reloadFromOldPath();
|
||||
target_ = Target();
|
||||
}
|
||||
|
||||
Target track(const Armors& armors_msg) noexcept {
|
||||
const double dt =
|
||||
std::chrono::duration<double>(armors_msg.timestamp - last_time_).count();
|
||||
last_time_ = armors_msg.timestamp;
|
||||
lost_thres_ =
|
||||
std::abs(static_cast<int>(target_config_->lost_time_thres_param.get() / dt));
|
||||
Armors armors;
|
||||
armors = armors_msg;
|
||||
std::erase_if(armors.armors, [this](const Armor& a) {
|
||||
double center_yaw =
|
||||
std::atan2(target_.target_state_.cy(), target_.target_state_.cx());
|
||||
bool state_check = tracker_state == TRACKING;
|
||||
bool outpost_check = target_.tracked_id_ == ArmorNumber::OUTPOST && !a.is_ok;
|
||||
bool pose_check =
|
||||
(std::abs(angles::normalize_angle(
|
||||
orientationToYaw(a.target_ori, center_yaw) - center_yaw
|
||||
)) > (target_config_->max_yaw_diff_deg_param.get() * M_PI / 180.0)
|
||||
|| std::abs((a.target_pos - target_.target_state_.pos()).norm())
|
||||
> target_config_->max_dis_diff_param.get())
|
||||
&& target_.is_inited
|
||||
&& std::abs(wust_vl::common::utils::time_utils::durationMs(
|
||||
target_.timestamp_,
|
||||
wust_vl::common::utils::time_utils::now()
|
||||
))
|
||||
< 1000.0;
|
||||
|
||||
return state_check && outpost_check || pose_check;
|
||||
});
|
||||
|
||||
std::sort(
|
||||
armors.armors.begin(),
|
||||
armors.armors.end(),
|
||||
[](const Armor& a, const Armor& b) {
|
||||
return a.distance_to_image_center < b.distance_to_image_center;
|
||||
}
|
||||
);
|
||||
bool found;
|
||||
if (tracker_state == LOST) {
|
||||
found = initTarget(armors);
|
||||
} else {
|
||||
found = updateTarget(armors);
|
||||
}
|
||||
updateFsm(found);
|
||||
|
||||
return target_;
|
||||
}
|
||||
void updateFsm(bool found) noexcept {
|
||||
switch (tracker_state) {
|
||||
case DETECTING:
|
||||
if (found) {
|
||||
if (++detect_count_ > target_config_->tracking_thres_param.get()) {
|
||||
detect_count_ = 0;
|
||||
tracker_state = TRACKING;
|
||||
}
|
||||
} else {
|
||||
detect_count_ = 0;
|
||||
tracker_state = LOST;
|
||||
}
|
||||
break;
|
||||
|
||||
case TRACKING:
|
||||
if (!found) {
|
||||
tracker_state = TEMP_LOST;
|
||||
lost_count_ = 1;
|
||||
}
|
||||
break;
|
||||
|
||||
case TEMP_LOST:
|
||||
if (!found) {
|
||||
if (++lost_count_ > lost_thres_) {
|
||||
lost_count_ = 0;
|
||||
tracker_state = LOST;
|
||||
}
|
||||
} else {
|
||||
lost_count_ = 0;
|
||||
tracker_state = TRACKING;
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
target_.is_tracking = (tracker_state == TRACKING || tracker_state == TEMP_LOST);
|
||||
|
||||
if (found)
|
||||
++found_count_;
|
||||
}
|
||||
|
||||
bool initTarget(const Armors& armors) noexcept {
|
||||
if (armors.armors.empty()) {
|
||||
return false;
|
||||
}
|
||||
bool found = false;
|
||||
Armor init_target;
|
||||
Armors others = armors;
|
||||
others.armors.clear();
|
||||
for (auto& a: armors.armors) {
|
||||
if (!a.is_none_purple && !found) {
|
||||
init_target = a;
|
||||
found = true;
|
||||
continue;
|
||||
}
|
||||
others.armors.push_back(a);
|
||||
}
|
||||
if (!found) {
|
||||
return false;
|
||||
}
|
||||
target_ = Target(init_target, target_config_);
|
||||
// updateTarget(others);
|
||||
tracker_state = DETECTING;
|
||||
return true;
|
||||
}
|
||||
bool updateTarget(const Armors& armors) noexcept {
|
||||
if (armors.armors.empty())
|
||||
return false;
|
||||
|
||||
target_.predict(armors.timestamp);
|
||||
std::vector<Armor> candidates;
|
||||
candidates.reserve(armors.armors.size());
|
||||
|
||||
for (const auto& a: armors.armors) {
|
||||
if (isSameTarget(a.number, target_.tracked_id_) && !a.is_none_purple) {
|
||||
candidates.emplace_back(a);
|
||||
}
|
||||
}
|
||||
|
||||
if (candidates.empty())
|
||||
return false;
|
||||
|
||||
int updated = 0;
|
||||
const auto matches = target_.match(candidates);
|
||||
|
||||
for (const auto& m: matches) {
|
||||
if (m.second.is_none_purple) {
|
||||
if (++is_none_purple_count_ > 100)
|
||||
continue;
|
||||
} else {
|
||||
is_none_purple_count_ = 0;
|
||||
}
|
||||
|
||||
if (target_.update(m))
|
||||
++updated;
|
||||
}
|
||||
|
||||
return updated > 0;
|
||||
}
|
||||
|
||||
int lost_thres_;
|
||||
int detect_count_ = 0;
|
||||
int lost_count_ = 0;
|
||||
int is_none_purple_count_ = 0;
|
||||
int found_count_ = 0;
|
||||
Target target_;
|
||||
std::chrono::steady_clock::time_point last_time_;
|
||||
TargetConfig::Ptr target_config_;
|
||||
|
||||
double orientationToYaw(const Eigen::Quaterniond& q, double from) noexcept {
|
||||
double roll, pitch, yaw;
|
||||
Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false);
|
||||
yaw = euler[0];
|
||||
yaw = from + angles::shortest_angular_distance(from, yaw);
|
||||
return yaw;
|
||||
}
|
||||
};
|
||||
Tracker::Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
|
||||
_impl = std::make_unique<Impl>(auto_aim_config_parameter);
|
||||
}
|
||||
Tracker::~Tracker() {
|
||||
_impl.reset();
|
||||
}
|
||||
Target Tracker::track(const Armors& armors) noexcept {
|
||||
return _impl->track(armors);
|
||||
}
|
||||
int Tracker::getFoundCount() const noexcept {
|
||||
return _impl->found_count_;
|
||||
}
|
||||
void Tracker::setFoundCount(int count) noexcept {
|
||||
_impl->found_count_ = count;
|
||||
}
|
||||
std::chrono::steady_clock::time_point Tracker::getLastTime() const noexcept {
|
||||
return _impl->last_time_;
|
||||
}
|
||||
void Tracker::setLastTime(std::chrono::steady_clock::time_point t) noexcept {
|
||||
_impl->last_time_ = t;
|
||||
}
|
||||
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
23
wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.hpp
Normal file
23
wust_vision-main/tasks/auto_aim/armor_tracker/trackerv3.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "target.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class Tracker {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<Tracker>;
|
||||
Tracker(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter);
|
||||
static Ptr create(wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter) {
|
||||
return std::make_unique<Tracker>(auto_aim_config_parameter);
|
||||
}
|
||||
~Tracker();
|
||||
[[nodiscard]] Target track(const Armors& armors) noexcept;
|
||||
int getFoundCount() const noexcept;
|
||||
void setFoundCount(int count) noexcept;
|
||||
void setLastTime(std::chrono::steady_clock::time_point t) noexcept;
|
||||
std::chrono::steady_clock::time_point getLastTime() const noexcept;
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
307
wust_vision-main/tasks/auto_aim/armor_where/armor_where.cpp
Normal file
307
wust_vision-main/tasks/auto_aim/armor_where/armor_where.cpp
Normal file
@@ -0,0 +1,307 @@
|
||||
// Created by Labor 2023.8.25
|
||||
// Maintained by Chengfu Zou, Labor
|
||||
// Copyright (C) FYT Vision Group. All rights reserved.
|
||||
// Copyright 2025 Xiaojian Wu
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "armor_where.hpp"
|
||||
#include "wust_vl/algorithm/pnp_solver.hpp"
|
||||
#include <opencv2/core/eigen.hpp>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
|
||||
struct ArmorWhere::Impl {
|
||||
public:
|
||||
Impl(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
|
||||
camera_info_ = camera_info;
|
||||
params_.load(config);
|
||||
pnp_solver_ = std::make_unique<wust_vl::algorithm::PnPSolver>(cv::SOLVEPNP_IPPE);
|
||||
pnp_solver_->setObjectPoints(
|
||||
"small",
|
||||
ArmorObject::buildObjectPoints<cv::Point3f>(SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT)
|
||||
);
|
||||
pnp_solver_->setObjectPoints(
|
||||
"large",
|
||||
ArmorObject::buildObjectPoints<cv::Point3f>(LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT)
|
||||
);
|
||||
}
|
||||
struct Params {
|
||||
enum class OptMode : int { GOLDEN = 0, CERES = 1, NONE = 2 } opt_mode;
|
||||
OptMode fromString(const std::string& mode) {
|
||||
if (mode == "golden" || mode == "GOLDEN") {
|
||||
return OptMode::GOLDEN;
|
||||
} else if (mode == "none" || mode == "NONE") {
|
||||
return OptMode::NONE;
|
||||
} else {
|
||||
return OptMode::NONE;
|
||||
}
|
||||
}
|
||||
int golden_search_side_deg = 60;
|
||||
double distance_fix_a2 = 0;
|
||||
void load(const YAML::Node& node) {
|
||||
opt_mode = fromString(node["yaw_opt"]["mode"].as<std::string>());
|
||||
golden_search_side_deg = node["yaw_opt"]["golden_search_side_deg"].as<int>();
|
||||
distance_fix_a2 = node["distance_fix_a2"].as<double>();
|
||||
}
|
||||
} params_;
|
||||
|
||||
std::vector<Armor> where(
|
||||
const std::vector<ArmorObject>& armors,
|
||||
Eigen::Matrix4d T_camera_to_odom
|
||||
) const noexcept {
|
||||
std::vector<Armor> armors_msg;
|
||||
|
||||
const Eigen::Matrix3d R_imu_cam = T_camera_to_odom.block<3, 3>(0, 0);
|
||||
auto makeArmor =
|
||||
[&](const ArmorObject& obj, const Eigen::Vector3d& t, const Eigen::Matrix3d& R) {
|
||||
Armor msg;
|
||||
msg.type = (obj.number == ArmorNumber::NO1 || obj.number == ArmorNumber::BASE)
|
||||
? "large"
|
||||
: "small";
|
||||
msg.number = obj.number;
|
||||
Eigen::Quaterniond q(R);
|
||||
Eigen::Quaterniond add_roll {
|
||||
Eigen::AngleAxisd(M_PI / 2, Eigen::Vector3d::UnitX())
|
||||
};
|
||||
Eigen::Quaterniond new_q = q * add_roll;
|
||||
|
||||
auto [yaw, pitch, dist] = utils::xyz2ypd_rad(t.x(), t.y(), t.z());
|
||||
dist += params_.distance_fix_a2 * dist * dist;
|
||||
auto [x, y, z] = utils::ypd2xyz_rad(yaw, pitch, dist);
|
||||
|
||||
msg.pos = { x, y, z };
|
||||
msg.ori = new_q;
|
||||
auto_aim::transformArmorData(msg, T_camera_to_odom);
|
||||
msg.distance_to_image_center =
|
||||
pnp_solver_->calculateDistanceToCenter(obj.center, camera_info_.first);
|
||||
msg.is_ok = obj.is_ok;
|
||||
if (obj.color == ArmorColor::NONE || obj.color == ArmorColor::PURPLE) {
|
||||
msg.is_none_purple = true;
|
||||
} else {
|
||||
msg.is_none_purple = false;
|
||||
}
|
||||
return msg;
|
||||
};
|
||||
|
||||
for (auto const& a: armors) {
|
||||
cv::Mat rvec, tvec;
|
||||
std::string type = (a.number == ArmorNumber::NO1 || a.number == ArmorNumber::BASE)
|
||||
? "large"
|
||||
: "small";
|
||||
|
||||
if (!pnp_solver_->solvePnP(
|
||||
a.landmarks(),
|
||||
rvec,
|
||||
tvec,
|
||||
type,
|
||||
camera_info_.first,
|
||||
camera_info_.second
|
||||
))
|
||||
{
|
||||
WUST_WARN("PNP") << "PNP failed";
|
||||
continue;
|
||||
}
|
||||
cv::Mat R_cv;
|
||||
cv::Rodrigues(rvec, R_cv);
|
||||
Eigen::Matrix3d R = utils::cvToEigen(R_cv);
|
||||
Eigen::Vector3d t = utils::cvToEigen(tvec);
|
||||
if (params_.opt_mode != Params::OptMode::NONE) {
|
||||
Eigen::Matrix3d R0 = R;
|
||||
R = solveBa_R(a, t, R0, R_imu_cam, type);
|
||||
}
|
||||
|
||||
armors_msg.push_back(makeArmor(a, t, R));
|
||||
}
|
||||
|
||||
return armors_msg;
|
||||
}
|
||||
std::vector<Eigen::Vector2d> reprojectionArmor(
|
||||
double yaw,
|
||||
const std::vector<cv::Point3f>& object_points,
|
||||
const std::vector<cv::Point2f>& landmarks,
|
||||
const Eigen::Matrix3d& Rci,
|
||||
double pitch,
|
||||
double roll,
|
||||
const Eigen::Vector3d& t
|
||||
) const noexcept {
|
||||
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
|
||||
const Eigen::AngleAxisd ap(pitch, Eigen::Vector3d::UnitY());
|
||||
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
|
||||
const Eigen::Matrix3d R = Rci * (ay * ap * ar).toRotationMatrix();
|
||||
|
||||
cv::Mat rvec, R_cv;
|
||||
cv::eigen2cv(R, R_cv);
|
||||
cv::Rodrigues(R_cv, rvec);
|
||||
|
||||
const cv::Mat tvec = (cv::Mat_<double>(3, 1) << t.x(), t.y(), t.z());
|
||||
|
||||
std::vector<cv::Point2f> pts_2d;
|
||||
pts_2d.reserve(object_points.size());
|
||||
cv::projectPoints(
|
||||
object_points,
|
||||
rvec,
|
||||
tvec,
|
||||
camera_info_.first,
|
||||
camera_info_.second,
|
||||
pts_2d
|
||||
);
|
||||
|
||||
std::vector<Eigen::Vector2d> image_points;
|
||||
image_points.reserve(pts_2d.size());
|
||||
|
||||
for (const auto& p: pts_2d) {
|
||||
image_points.emplace_back(p.x, p.y);
|
||||
}
|
||||
|
||||
return image_points;
|
||||
}
|
||||
|
||||
double reprojectionErrorYaw(
|
||||
double yaw,
|
||||
const std::vector<cv::Point3f>& object_points,
|
||||
const std::vector<cv::Point2f>& landmarks,
|
||||
const std::vector<std::pair<int, int>>& sym_pairs,
|
||||
const Eigen::Matrix3d& Rci,
|
||||
double pitch,
|
||||
double roll,
|
||||
const Eigen::Vector3d& t
|
||||
) const noexcept {
|
||||
const auto image_points =
|
||||
reprojectionArmor(yaw, object_points, landmarks, Rci, pitch, roll, t);
|
||||
double cost = 0.0;
|
||||
|
||||
// for (size_t i = 0; i < image_points.size(); ++i) {
|
||||
// Eigen::Vector2d obs(landmarks[i].x, landmarks[i].y);
|
||||
// cost += (image_points[i] - obs).squaredNorm();
|
||||
// }
|
||||
|
||||
for (auto& p: sym_pairs) {
|
||||
const Eigen::Vector2d mid = 0.5 * (image_points[p.first] + image_points[p.second]);
|
||||
|
||||
const Eigen::Vector2d meas = 0.5
|
||||
* (Eigen::Vector2d(landmarks[p.first].x, landmarks[p.first].y)
|
||||
+ Eigen::Vector2d(landmarks[p.second].x, landmarks[p.second].y));
|
||||
|
||||
cost += (mid - meas).squaredNorm();
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
|
||||
double goldenYaw(
|
||||
double init,
|
||||
const std::vector<cv::Point3f>& obj,
|
||||
const std::vector<cv::Point2f>& lm,
|
||||
const std::vector<std::pair<int, int>>& sym_pairs,
|
||||
const Eigen::Matrix3d& Rci,
|
||||
double pitch,
|
||||
double roll,
|
||||
const Eigen::Vector3d& t
|
||||
) const noexcept {
|
||||
constexpr double phi = 1.618033988749894848; //(1.0 + std::sqrt(5.0)) * 0.5;
|
||||
double l = init - params_.golden_search_side_deg * M_PI / 180.0;
|
||||
double r = init + params_.golden_search_side_deg * M_PI / 180.0;
|
||||
|
||||
double y1 = r - (r - l) / phi;
|
||||
double y2 = l + (r - l) / phi;
|
||||
|
||||
double f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t);
|
||||
double f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t);
|
||||
|
||||
while (r - l > 0.0001) {
|
||||
if (f1 < f2) {
|
||||
r = y2;
|
||||
y2 = y1;
|
||||
f2 = f1;
|
||||
y1 = r - (r - l) / phi;
|
||||
f1 = reprojectionErrorYaw(y1, obj, lm, sym_pairs, Rci, pitch, roll, t);
|
||||
} else {
|
||||
l = y1;
|
||||
y1 = y2;
|
||||
f1 = f2;
|
||||
y2 = l + (r - l) / phi;
|
||||
f2 = reprojectionErrorYaw(y2, obj, lm, sym_pairs, Rci, pitch, roll, t);
|
||||
}
|
||||
}
|
||||
|
||||
return 0.5 * (l + r);
|
||||
}
|
||||
|
||||
Eigen::Matrix3d solveBa_R(
|
||||
const ArmorObject& armor,
|
||||
const Eigen::Vector3d& t_camera_armor,
|
||||
const Eigen::Matrix3d& R_camera_armor,
|
||||
const Eigen::Matrix3d& R_imu_camera,
|
||||
const std::string& type
|
||||
) const noexcept {
|
||||
const Eigen::Matrix3d R_imu_armor = R_imu_camera * R_camera_armor;
|
||||
const Eigen::Matrix3d R_camera_imu = R_imu_camera.transpose();
|
||||
//double roll = std::atan2(R_imu_armor(2, 1), R_imu_armor(2, 2));
|
||||
const double roll = 0;
|
||||
// initial yaw
|
||||
const double yaw_init = std::atan2(-R_imu_armor(0, 1), R_imu_armor(1, 1));
|
||||
|
||||
const double armor_pitch =
|
||||
(armor.number == ArmorNumber::OUTPOST) ? -FIFTTEN_DEGREE_RAD : FIFTTEN_DEGREE_RAD;
|
||||
|
||||
const Eigen::Vector2d armor_size = (type == "large")
|
||||
? Eigen::Vector2d { LARGE_ARMOR_WIDTH, LARGE_ARMOR_HEIGHT }
|
||||
: Eigen::Vector2d { SMALL_ARMOR_WIDTH, SMALL_ARMOR_HEIGHT };
|
||||
|
||||
const auto objPts =
|
||||
ArmorObject::buildObjectPoints<cv::Point3f>(armor_size.x(), armor_size.y());
|
||||
const auto& lm = armor.landmarks();
|
||||
const auto& sym_pairs = ArmorObject::buildSymPairs<int>();
|
||||
double yaw = yaw_init;
|
||||
if (params_.opt_mode == Params::OptMode::GOLDEN) {
|
||||
yaw = goldenYaw(
|
||||
yaw_init,
|
||||
objPts,
|
||||
lm,
|
||||
sym_pairs,
|
||||
R_camera_imu,
|
||||
armor_pitch,
|
||||
roll,
|
||||
t_camera_armor
|
||||
);
|
||||
}
|
||||
|
||||
const Eigen::AngleAxisd ay(yaw, Eigen::Vector3d::UnitZ());
|
||||
const Eigen::AngleAxisd ap(armor_pitch, Eigen::Vector3d::UnitY());
|
||||
const Eigen::AngleAxisd ar(roll, Eigen::Vector3d::UnitX());
|
||||
|
||||
const Eigen::Matrix3d R_result = R_camera_imu * (ay * ap * ar).toRotationMatrix();
|
||||
return R_result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<cv::Mat, cv::Mat> camera_info_;
|
||||
std::unique_ptr<wust_vl::algorithm::PnPSolver> pnp_solver_;
|
||||
};
|
||||
ArmorWhere::ArmorWhere(
|
||||
const YAML::Node& config,
|
||||
const std::pair<cv::Mat, cv::Mat>& camera_info
|
||||
) {
|
||||
_impl = std::make_unique<Impl>(config, camera_info);
|
||||
}
|
||||
ArmorWhere::~ArmorWhere() {
|
||||
_impl.reset();
|
||||
}
|
||||
std::vector<Armor> ArmorWhere::where(
|
||||
const std::vector<ArmorObject>& armors,
|
||||
Eigen::Matrix4d T_camera_to_odom
|
||||
) const noexcept {
|
||||
return _impl->where(armors, T_camera_to_odom);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
26
wust_vision-main/tasks/auto_aim/armor_where/armor_where.hpp
Normal file
26
wust_vision-main/tasks/auto_aim/armor_where/armor_where.hpp
Normal file
@@ -0,0 +1,26 @@
|
||||
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
class ArmorWhere {
|
||||
public:
|
||||
using Ptr = std::unique_ptr<ArmorWhere>;
|
||||
ArmorWhere(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info);
|
||||
static Ptr
|
||||
create(const YAML::Node& config, const std::pair<cv::Mat, cv::Mat>& camera_info) {
|
||||
return std::make_unique<ArmorWhere>(config, camera_info);
|
||||
}
|
||||
~ArmorWhere();
|
||||
std::vector<Armor> where(
|
||||
const std::vector<ArmorObject>& armors,
|
||||
Eigen::Matrix4d T_camera_to_odom
|
||||
) const noexcept;
|
||||
|
||||
private:
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
404
wust_vision-main/tasks/auto_aim/auto_aim.cpp
Normal file
404
wust_vision-main/tasks/auto_aim/auto_aim.cpp
Normal file
@@ -0,0 +1,404 @@
|
||||
#include "auto_aim.hpp"
|
||||
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
|
||||
#include "tasks/auto_aim/armor_detect/armor_detector_factory.hpp"
|
||||
#include "tasks/auto_aim/armor_tracker/target.hpp"
|
||||
#include "tasks/auto_aim/armor_tracker/trackerv3.hpp"
|
||||
#include "tasks/auto_aim/armor_where/armor_where.hpp"
|
||||
#include "tasks/auto_aim/debug.hpp"
|
||||
#include "tasks/type_common.hpp"
|
||||
#include "tasks/utils/config.hpp"
|
||||
#include "wust_vl/common/concurrency/queues.hpp"
|
||||
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
|
||||
struct AutoAim::Impl {
|
||||
~Impl() {
|
||||
run_flag_ = false;
|
||||
if (armor_detector_) {
|
||||
armor_detector_.reset();
|
||||
}
|
||||
armor_queue_->stop();
|
||||
if (processing_thread_) {
|
||||
processing_thread_->stop();
|
||||
wust_vl::common::concurrency::ThreadManager::instance().unregisterThread(
|
||||
processing_thread_->getName()
|
||||
);
|
||||
}
|
||||
}
|
||||
Impl(
|
||||
const std::string& config_path,
|
||||
TFConfig::Ptr tf_config,
|
||||
const std::pair<cv::Mat, cv::Mat>& camera_info,
|
||||
bool debug
|
||||
) {
|
||||
tf_config_ = tf_config;
|
||||
camera_info_ = camera_info;
|
||||
auto_aim_config_parameter_ = wust_vl::common::utils::Parameter::create();
|
||||
auto config = YAML::LoadFile(config_path);
|
||||
auto_aim_config_parameter_->loadFromFile(config_path);
|
||||
auto_exposure_cfg_ = AutoExposureCfg::create();
|
||||
very_aimer_ = VeryAimer::create(auto_aim_config_parameter_);
|
||||
auto_aim_config_parameter_->registerGroup(*auto_exposure_cfg_);
|
||||
auto_aim_config_parameter_->registerGroup(*auto_aim_fsm_cl_.config_);
|
||||
tracker_ = Tracker::create(auto_aim_config_parameter_);
|
||||
auto_aim_config_parameter_->reloadFromOldPath();
|
||||
|
||||
wust_vl::common::utils::ParameterManager::instance().registerParameter(
|
||||
*auto_aim_config_parameter_.get()
|
||||
);
|
||||
max_detect_armors_ = config["max_detect_armors"].as<int>(10);
|
||||
armor_where_ = ArmorWhere::create(config["armor_where"], camera_info_);
|
||||
const std::string armor_detect_backend =
|
||||
config["armor_detect_backend"].as<std::string>("");
|
||||
armor_detector_ = DetectorFactory::createArmorDetector(armor_detect_backend, true);
|
||||
armor_detector_->setCallback(std::bind(
|
||||
&AutoAim::Impl::ArmorDetectCallback,
|
||||
this,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2
|
||||
));
|
||||
WUST_MAIN(logger_) << "Using Armor Detector: " << armor_detect_backend;
|
||||
armor_queue_ =
|
||||
std::make_unique<wust_vl::common::concurrency::OrderedQueue<Armors>>(50, 500);
|
||||
latency_averager_ =
|
||||
std::make_unique<wust_vl::common::concurrency::Averager<double>>(100);
|
||||
}
|
||||
void start() {
|
||||
if (run_flag_) {
|
||||
return;
|
||||
}
|
||||
run_flag_ = true;
|
||||
processing_thread_ = wust_vl::common::concurrency::MonitoredThread::create(
|
||||
"AutoAimProcessingThread",
|
||||
[this](wust_vl::common::concurrency::MonitoredThread::Ptr self) {
|
||||
this->processingLoop(self);
|
||||
}
|
||||
);
|
||||
wust_vl::common::concurrency::ThreadManager::instance().registerThread(
|
||||
processing_thread_
|
||||
);
|
||||
}
|
||||
void pushInput(CommonFrame& frame) {
|
||||
img_recv_count_++;
|
||||
|
||||
auto bbox = target_.expanded(
|
||||
T_camera_to_odom_,
|
||||
camera_info_.first,
|
||||
camera_info_.second,
|
||||
frame.img_frame.src_img.size()
|
||||
);
|
||||
if (bbox.area() > 100) {
|
||||
frame.expanded = bbox;
|
||||
frame.offset = cv::Point2f(bbox.x, bbox.y);
|
||||
}
|
||||
expanded_ = frame.expanded;
|
||||
const std::optional<ArmorNumber> target_number = target_.getArmorNumber();
|
||||
if (armor_detector_) {
|
||||
armor_detector_->pushInput(frame, target_number);
|
||||
}
|
||||
}
|
||||
|
||||
void ArmorDetectCallback(const std::vector<ArmorObject>& objs, const CommonFrame& frame) {
|
||||
std::vector<ArmorObject> sorted_objs = objs;
|
||||
|
||||
if (sorted_objs.size() > max_detect_armors_) [[unlikely]] {
|
||||
WUST_WARN(logger_) << "Detected " << sorted_objs.size()
|
||||
<< " objects, too many, keeping top " << max_detect_armors_;
|
||||
|
||||
std::partial_sort(
|
||||
sorted_objs.begin(),
|
||||
sorted_objs.begin() + max_detect_armors_,
|
||||
sorted_objs.end(),
|
||||
[](const ArmorObject& a, const ArmorObject& b) {
|
||||
return a.confidence > b.confidence;
|
||||
}
|
||||
);
|
||||
|
||||
sorted_objs.resize(max_detect_armors_);
|
||||
}
|
||||
for (auto& obj: sorted_objs) {
|
||||
obj.addOffset(frame.offset);
|
||||
}
|
||||
|
||||
Armors armors;
|
||||
armors.timestamp = frame.img_frame.timestamp;
|
||||
armors.id = frame.id;
|
||||
Eigen::Vector3d v = Eigen::Vector3d::Zero();
|
||||
Eigen::Matrix3d R_gimbal2odom = Eigen::Matrix3d::Identity();
|
||||
auto ctx = std::any_cast<VisionCtx>(frame.any_ctx);
|
||||
std::pair<double, double> gimbal_py;
|
||||
if (ctx.motion_buffer) {
|
||||
const auto delay =
|
||||
std::chrono::microseconds(static_cast<int64_t>(ctx.communication_delay_μs));
|
||||
const auto t_query = armors.timestamp + delay;
|
||||
auto apply_motion = [&](const auto& att) {
|
||||
v << att.data.vx, att.data.vy, att.data.vz;
|
||||
R_gimbal2odom = Eigen::AngleAxisd(att.data.yaw, Eigen::Vector3d::UnitZ())
|
||||
* Eigen::AngleAxisd(-att.data.pitch, Eigen::Vector3d::UnitY())
|
||||
* Eigen::AngleAxisd(att.data.roll, Eigen::Vector3d::UnitX());
|
||||
gimbal_py = std::make_pair(att.data.pitch, att.data.yaw);
|
||||
};
|
||||
if (auto past_att = ctx.motion_buffer->get_interpolated(t_query)) {
|
||||
apply_motion(*past_att);
|
||||
} else if (auto last_att = ctx.motion_buffer->get_last()) {
|
||||
apply_motion(*last_att);
|
||||
}
|
||||
}
|
||||
autoExposureControl(frame.img_frame.src_img, ctx.camera);
|
||||
T_camera_to_odom_ = utils::computeCameraToOdomTransform(
|
||||
R_gimbal2odom,
|
||||
tf_config_->R_camera2gimbal,
|
||||
tf_config_->t_camera2gimbal
|
||||
);
|
||||
armors.armors = armor_where_->where(sorted_objs, T_camera_to_odom_);
|
||||
armors.v = v;
|
||||
for (auto& armor: armors.armors) {
|
||||
armor.timestamp = armors.timestamp;
|
||||
}
|
||||
|
||||
armor_queue_->enqueue(armors);
|
||||
++detect_finish_count_;
|
||||
if (debug_mode_) {
|
||||
std::lock_guard<std::mutex> lock(dbg_mutex_);
|
||||
auto& dbg = auto_aim_debug_;
|
||||
|
||||
dbg.img_frame = frame.img_frame;
|
||||
dbg.armors = armors;
|
||||
dbg.T_camera_to_odom = T_camera_to_odom_;
|
||||
dbg.detect_color = frame.detect_color;
|
||||
dbg.armor_objs = sorted_objs;
|
||||
dbg.expanded = frame.expanded;
|
||||
dbg.gimbal_py = gimbal_py;
|
||||
}
|
||||
}
|
||||
void armorsCallback(const Armors& armors) {
|
||||
if (armors.timestamp <= tracker_->getLastTime()) {
|
||||
WUST_WARN(logger_) << "Received out-of-order armor data, discarded.";
|
||||
return;
|
||||
}
|
||||
Target target = tracker_->track(armors);
|
||||
auto_aim_fsm_cl_.update(std::abs(target.target_state_.vyaw()), target.jumped);
|
||||
const auto now = std::chrono::steady_clock::now();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(target_mutex_);
|
||||
target_ = target;
|
||||
}
|
||||
const auto latency_ms =
|
||||
wust_vl::common::utils::time_utils::durationMs(armors.timestamp, now);
|
||||
latency_averager_->add(latency_ms);
|
||||
auto& dbg = auto_aim_debug_;
|
||||
dbg.latency_ms = latency_averager_->average();
|
||||
if (debug_mode_) {
|
||||
std::lock_guard<std::mutex> lock(dbg_mutex_);
|
||||
dbg.target = target;
|
||||
dbg.fsm = auto_aim_fsm_cl_.fsm_state_;
|
||||
}
|
||||
}
|
||||
Target getTarget() {
|
||||
Target target;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(target_mutex_);
|
||||
target = target_;
|
||||
}
|
||||
return target;
|
||||
}
|
||||
GimbalCmd solve(double bullet_speed) {
|
||||
GimbalCmd gimbal_cmd;
|
||||
Target target;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(target_mutex_);
|
||||
target = target_;
|
||||
}
|
||||
AimTarget aim_target;
|
||||
const bool appear = target.checkTargetAppear();
|
||||
if (appear && target.target_state_.pos().norm() > 0.1) {
|
||||
try {
|
||||
gimbal_cmd =
|
||||
very_aimer_->veryAim(target, bullet_speed, auto_aim_fsm_cl_.fsm_state_);
|
||||
aim_target = gimbal_cmd.aim_target;
|
||||
} catch (...) {
|
||||
WUST_ERROR(logger_) << "VeryAim error";
|
||||
}
|
||||
}
|
||||
if (gimbal_cmd.fire_advice) {
|
||||
fire_count_++;
|
||||
}
|
||||
if (debug_mode_) {
|
||||
std::lock_guard<std::mutex> lock(dbg_mutex_);
|
||||
auto_aim_debug_.gimbal_cmd = gimbal_cmd;
|
||||
auto_aim_debug_.aim_target = aim_target;
|
||||
}
|
||||
timer_cout_++;
|
||||
return gimbal_cmd;
|
||||
}
|
||||
void processingLoop(wust_vl::common::concurrency::MonitoredThread::Ptr self) {
|
||||
while (!self->isAlive()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
||||
}
|
||||
while (self->isAlive() && run_flag_) {
|
||||
if (!self->waitPoint())
|
||||
break;
|
||||
self->heartbeat();
|
||||
printStats();
|
||||
Armors armors;
|
||||
// bool skip;
|
||||
// if (armor_queue_->dequeue_wait(armors, skip)) {
|
||||
// armorsCallback(armors);
|
||||
// tracker_finish_count_++;
|
||||
// if (skip) {
|
||||
// WUST_DEBUG(logger_) << "OrderQueue skip";
|
||||
// }
|
||||
// }
|
||||
if (armor_queue_->try_dequeue(armors)) {
|
||||
armorsCallback(armors);
|
||||
tracker_finish_count_++;
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
void doDebug() {
|
||||
debug_mode_ = true;
|
||||
AutoAimDebug dbg;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(dbg_mutex_);
|
||||
dbg = auto_aim_debug_;
|
||||
}
|
||||
drawDebugOverlayShm(dbg, camera_info_, false);
|
||||
debuglog(dbg);
|
||||
}
|
||||
|
||||
void printStats() {
|
||||
utils::XSecOnce(
|
||||
[&] {
|
||||
double found_ratio = 0.0;
|
||||
if (img_recv_count_ > 0) {
|
||||
found_ratio = static_cast<double>(tracker_->getFoundCount())
|
||||
/ static_cast<double>(img_recv_count_);
|
||||
}
|
||||
|
||||
WUST_INFO(logger_)
|
||||
<< "Rec: " << img_recv_count_ << ", Det: " << detect_finish_count_
|
||||
<< ", Fin: " << tracker_finish_count_ << ", Tc: " << timer_cout_
|
||||
<< ", Lat: " << auto_aim_debug_.latency_ms << "ms"
|
||||
<< ", Fire: " << fire_count_ << ", Found: " << tracker_->getFoundCount()
|
||||
<< ", Found_ratio: " << found_ratio;
|
||||
|
||||
img_recv_count_ = 0;
|
||||
detect_finish_count_ = 0;
|
||||
fire_count_ = 0;
|
||||
tracker_finish_count_ = 0;
|
||||
timer_cout_ = 0;
|
||||
tracker_->setFoundCount(0);
|
||||
},
|
||||
1.0
|
||||
);
|
||||
}
|
||||
void
|
||||
autoExposureControl(const cv::Mat& frame, std::shared_ptr<wust_vl::video::Camera> camera) {
|
||||
const double dt = auto_exposure_cfg_->control_interval_ms_param.get() / 1000.0;
|
||||
utils::XSecOnce(
|
||||
[&] {
|
||||
if (!auto_exposure_cfg_->enable_param.get() || frame.empty()) {
|
||||
return;
|
||||
}
|
||||
if (auto* hik = dynamic_cast<wust_vl::video::HikCamera*>(camera->getDevice())) {
|
||||
cv::Mat i_use = frame(expanded_);
|
||||
if (expanded_.area() < 100 || i_use.empty()) {
|
||||
i_use = frame;
|
||||
}
|
||||
const double brightness = utils::computeBrightness(i_use);
|
||||
|
||||
const double diff =
|
||||
brightness - auto_exposure_cfg_->target_brightness_param.get();
|
||||
const double exposure_min = auto_exposure_cfg_->exposure_min_param.get();
|
||||
const double exposure_max = auto_exposure_cfg_->exposure_max_param.get();
|
||||
double exposure_time = hik->getExposureTime();
|
||||
const double last_exposure_time = exposure_time;
|
||||
if (std::fabs(diff) > auto_exposure_cfg_->tolerance_param.get()
|
||||
&& exposure_time > 0.0) {
|
||||
exposure_time -= diff * auto_exposure_cfg_->step_gain_param.get();
|
||||
} else {
|
||||
exposure_time -= auto_exposure_cfg_->decay_step_param.get();
|
||||
}
|
||||
if (exposure_time < exposure_min)
|
||||
exposure_time = exposure_min;
|
||||
if (exposure_time > exposure_max)
|
||||
exposure_time = exposure_max;
|
||||
if (std::abs(exposure_time - last_exposure_time) > 10) {
|
||||
hik->setExposureTime(exposure_time);
|
||||
}
|
||||
}
|
||||
},
|
||||
dt
|
||||
);
|
||||
}
|
||||
wust_vl::common::utils::Parameter::Ptr auto_aim_config_parameter_;
|
||||
Tracker::Ptr tracker_;
|
||||
ArmorDetectorBase::Ptr armor_detector_;
|
||||
std::string logger_ = "auto_aim";
|
||||
std::unique_ptr<wust_vl::common::concurrency::OrderedQueue<Armors>> armor_queue_;
|
||||
wust_vl::common::concurrency::MonitoredThread::Ptr processing_thread_;
|
||||
std::unique_ptr<wust_vl::common::utils::Timer> timer_;
|
||||
VeryAimer::Ptr very_aimer_;
|
||||
ArmorWhere::Ptr armor_where_;
|
||||
AutoAimFsmController auto_aim_fsm_cl_;
|
||||
AutoExposureCfg::Ptr auto_exposure_cfg_;
|
||||
|
||||
cv::Rect expanded_;
|
||||
int max_detect_armors_;
|
||||
bool run_flag_ = false;
|
||||
int detect_finish_count_ = 0;
|
||||
int img_recv_count_ = 0;
|
||||
int tracker_finish_count_ = 0;
|
||||
int fire_count_ = 0;
|
||||
int timer_cout_ = 0;
|
||||
Target target_;
|
||||
bool debug_mode_ = false;
|
||||
AutoAimDebug auto_aim_debug_;
|
||||
std::unique_ptr<wust_vl::common::concurrency::Averager<double>> latency_averager_;
|
||||
TFConfig::Ptr tf_config_;
|
||||
std::pair<cv::Mat, cv::Mat> camera_info_;
|
||||
Eigen::Matrix4d T_camera_to_odom_;
|
||||
std::mutex target_mutex_;
|
||||
std::mutex dbg_mutex_;
|
||||
};
|
||||
AutoAim::AutoAim(
|
||||
const std::string& config_path,
|
||||
TFConfig::Ptr tf_config,
|
||||
const std::pair<cv::Mat, cv::Mat>& camera_info,
|
||||
bool debug
|
||||
):
|
||||
_impl(std::make_unique<Impl>(config_path, tf_config, camera_info, debug)) {}
|
||||
AutoAim::~AutoAim() {
|
||||
_impl.reset();
|
||||
}
|
||||
|
||||
void AutoAim::start() {
|
||||
_impl->start();
|
||||
}
|
||||
void AutoAim::pushInput(CommonFrame& frame) {
|
||||
_impl->pushInput(frame);
|
||||
}
|
||||
|
||||
GimbalCmd AutoAim::solve(double bullet_speed) {
|
||||
return _impl->solve(bullet_speed);
|
||||
}
|
||||
wust_vl::common::concurrency::MonitoredThread::Ptr AutoAim::getThread() {
|
||||
return _impl->processing_thread_;
|
||||
}
|
||||
Target AutoAim::getTarget() {
|
||||
return _impl->getTarget();
|
||||
}
|
||||
void AutoAim::doDebug() {
|
||||
_impl->doDebug();
|
||||
}
|
||||
wust_vl::common::utils::Parameter::Ptr AutoAim::getParameter() {
|
||||
return _impl->auto_aim_config_parameter_;
|
||||
}
|
||||
VeryAimer::Ptr AutoAim::getVeryAimer() {
|
||||
return _impl->very_aimer_;
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
46
wust_vision-main/tasks/auto_aim/auto_aim.hpp
Normal file
46
wust_vision-main/tasks/auto_aim/auto_aim.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#pragma once
|
||||
|
||||
#include "tasks/auto_aim/armor_control/very_aimer.hpp"
|
||||
#include "tasks/auto_aim/armor_tracker/target.hpp"
|
||||
#include "tasks/imodule.hpp"
|
||||
#include "tasks/type_common.hpp"
|
||||
#include "wust_vl/video/camera.hpp"
|
||||
#include <memory>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
|
||||
class AutoAim: public IModule {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<AutoAim>;
|
||||
AutoAim(
|
||||
const std::string& config_path,
|
||||
TFConfig::Ptr tf_config,
|
||||
const std::pair<cv::Mat, cv::Mat>& camera_info,
|
||||
bool debug
|
||||
);
|
||||
static Ptr create(
|
||||
const std::string& config_path,
|
||||
TFConfig::Ptr tf_config,
|
||||
const std::pair<cv::Mat, cv::Mat>& camera_info,
|
||||
bool debug
|
||||
) {
|
||||
return std::make_shared<AutoAim>(config_path, tf_config, camera_info, debug);
|
||||
}
|
||||
~AutoAim();
|
||||
|
||||
void start() override;
|
||||
void doDebug() override;
|
||||
void pushInput(CommonFrame& frame) override;
|
||||
Target getTarget();
|
||||
GimbalCmd solve(double bullet_speed) override;
|
||||
wust_vl::common::concurrency::MonitoredThread::Ptr getThread() override;
|
||||
wust_vl::common::utils::Parameter::Ptr getParameter();
|
||||
VeryAimer::Ptr getVeryAimer();
|
||||
struct Impl;
|
||||
std::unique_ptr<Impl> _impl;
|
||||
};
|
||||
inline AutoAim::Ptr toAutoAim(IModule::Ptr module) {
|
||||
return std::dynamic_pointer_cast<AutoAim>(module);
|
||||
}
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
134
wust_vision-main/tasks/auto_aim/auto_aim_fsm.hpp
Normal file
134
wust_vision-main/tasks/auto_aim/auto_aim_fsm.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
#pragma once
|
||||
|
||||
#include "wust_vl/common/utils/parameter.hpp"
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
enum class AutoAimFsm {
|
||||
AIM_WHOLE_CAR_ARMOR,
|
||||
AIM_WHOLE_CAR_CENTER,
|
||||
AIM_SINGLE_ARMOR,
|
||||
AIM_WHOLE_CAR_PAIR
|
||||
};
|
||||
|
||||
inline std::string auto_aim_fsm_to_string(AutoAimFsm state) {
|
||||
switch (state) {
|
||||
case AutoAimFsm::AIM_WHOLE_CAR_ARMOR:
|
||||
return "AIM_WHOLE_CAR_ARMOR";
|
||||
case AutoAimFsm::AIM_WHOLE_CAR_CENTER:
|
||||
return "AIM_WHOLE_CAR_CENTER";
|
||||
case AutoAimFsm::AIM_SINGLE_ARMOR:
|
||||
return "AIM_SINGLE_ARMOR";
|
||||
case AutoAimFsm::AIM_WHOLE_CAR_PAIR:
|
||||
return "AIM_WHOLE_CAR_PAIR";
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
class AutoAimFsmController {
|
||||
public:
|
||||
AutoAimFsmController() {
|
||||
config_ = std::make_shared<AutoAimFsmConfig>();
|
||||
}
|
||||
AutoAimFsm fsm_state_ { AutoAimFsm::AIM_SINGLE_ARMOR };
|
||||
struct AutoAimFsmConfig: wust_vl::common::utils::ParamGroup {
|
||||
public:
|
||||
static constexpr const char* Logger = "Config: auto_aim::auto_aim_fsm";
|
||||
static constexpr const char* kKey = "auto_aim_fsm";
|
||||
const char* key() const override {
|
||||
return kKey;
|
||||
}
|
||||
using Ptr = std::shared_ptr<AutoAimFsmConfig>;
|
||||
AutoAimFsmConfig() {}
|
||||
GEN_PARAM(int, transfer_thresh);
|
||||
GEN_PARAM(double, single_whole_up);
|
||||
GEN_PARAM(double, single_whole_down);
|
||||
GEN_PARAM(double, whole_pair_up);
|
||||
GEN_PARAM(double, whole_pair_down);
|
||||
GEN_PARAM(double, pair_center_up);
|
||||
GEN_PARAM(double, pair_center_down);
|
||||
void loadSelf(const YAML::Node& node) override {
|
||||
transfer_thresh_param.load(node);
|
||||
single_whole_up_param.load(node);
|
||||
single_whole_down_param.load(node);
|
||||
whole_pair_up_param.load(node);
|
||||
whole_pair_down_param.load(node);
|
||||
pair_center_up_param.load(node);
|
||||
pair_center_down_param.load(node);
|
||||
}
|
||||
};
|
||||
AutoAimFsmConfig::Ptr config_;
|
||||
int overflow_count_ = 0;
|
||||
|
||||
void update(double v_yaw, bool target_jumped) {
|
||||
// 无跳变:直接退回单装甲,并清状态
|
||||
if (!target_jumped) {
|
||||
fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR;
|
||||
overflow_count_ = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const double av = std::abs(v_yaw);
|
||||
|
||||
switch (fsm_state_) {
|
||||
case AutoAimFsm::AIM_SINGLE_ARMOR: {
|
||||
overflow_count_ =
|
||||
(av > config_->single_whole_up_param.get()) ? overflow_count_ + 1 : 0;
|
||||
if (overflow_count_ > config_->transfer_thresh_param.get()) {
|
||||
fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_ARMOR;
|
||||
overflow_count_ = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case AutoAimFsm::AIM_WHOLE_CAR_ARMOR: {
|
||||
if (av > config_->whole_pair_up_param.get())
|
||||
++overflow_count_;
|
||||
else if (av < config_->single_whole_down_param.get())
|
||||
--overflow_count_;
|
||||
else
|
||||
overflow_count_ = 0;
|
||||
|
||||
if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) {
|
||||
fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_PAIR
|
||||
: AutoAimFsm::AIM_SINGLE_ARMOR;
|
||||
overflow_count_ = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case AutoAimFsm::AIM_WHOLE_CAR_PAIR: {
|
||||
if (av > config_->pair_center_up_param.get())
|
||||
++overflow_count_;
|
||||
else if (av < config_->whole_pair_down_param.get())
|
||||
--overflow_count_;
|
||||
else
|
||||
overflow_count_ = 0;
|
||||
|
||||
if (std::abs(overflow_count_) > config_->transfer_thresh_param.get()) {
|
||||
fsm_state_ = (overflow_count_ > 0) ? AutoAimFsm::AIM_WHOLE_CAR_CENTER
|
||||
: AutoAimFsm::AIM_WHOLE_CAR_ARMOR;
|
||||
overflow_count_ = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case AutoAimFsm::AIM_WHOLE_CAR_CENTER: {
|
||||
overflow_count_ =
|
||||
(av < config_->pair_center_down_param.get()) ? overflow_count_ + 1 : 0;
|
||||
if (overflow_count_ > config_->transfer_thresh_param.get()) {
|
||||
fsm_state_ = AutoAimFsm::AIM_WHOLE_CAR_PAIR;
|
||||
overflow_count_ = 0;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
fsm_state_ = AutoAimFsm::AIM_SINGLE_ARMOR;
|
||||
overflow_count_ = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
684
wust_vision-main/tasks/auto_aim/debug.cpp
Normal file
684
wust_vision-main/tasks/auto_aim/debug.cpp
Normal file
@@ -0,0 +1,684 @@
|
||||
#include "debug.hpp"
|
||||
namespace wust_vision::auto_aim {
|
||||
void drawDebugArmorContent(
|
||||
cv::Mat& debug_img,
|
||||
const AutoAimDebug& dbg,
|
||||
std::pair<cv::Mat, cv::Mat> camera_info
|
||||
) {
|
||||
if (debug_img.empty()) {
|
||||
std::cout << "debug_img is empty" << std::endl;
|
||||
return;
|
||||
}
|
||||
const auto now = std::chrono::steady_clock::now();
|
||||
const auto& armors = dbg.armors;
|
||||
const auto& gimbal_cmd = dbg.gimbal_cmd;
|
||||
const auto& target = dbg.target;
|
||||
auto aim_target = dbg.aim_target;
|
||||
const auto& armor_objs = dbg.armor_objs;
|
||||
const cv::Rect img_rect(0, 0, debug_img.cols, debug_img.rows);
|
||||
const cv::Rect roi = dbg.expanded & img_rect;
|
||||
cv::rectangle(debug_img, roi, cv::Scalar(255, 255, 255), 2);
|
||||
|
||||
static const int next_indices[] = { 2, 0, 3, 1 };
|
||||
|
||||
for (size_t i = 0; i < armor_objs.size(); i++) {
|
||||
const auto pts = armor_objs[i].toPts();
|
||||
|
||||
for (size_t j = 0; j < 4; ++j) {
|
||||
const cv::Scalar color =
|
||||
armor_objs[i].is_ok ? cv::Scalar(50, 255, 50) : cv::Scalar(50, 255, 255);
|
||||
cv::line(debug_img, pts[j], pts[next_indices[j]], color, 2);
|
||||
}
|
||||
|
||||
auto armorName = [](auto_aim::ArmorNumber num) {
|
||||
switch (num) {
|
||||
case auto_aim::ArmorNumber::SENTRY:
|
||||
return "SENTRY";
|
||||
case auto_aim::ArmorNumber::BASE:
|
||||
return "BASE";
|
||||
case auto_aim::ArmorNumber::OUTPOST:
|
||||
return "OUTPOST";
|
||||
case auto_aim::ArmorNumber::NO1:
|
||||
return "NO1";
|
||||
case auto_aim::ArmorNumber::NO2:
|
||||
return "NO2";
|
||||
case auto_aim::ArmorNumber::NO3:
|
||||
return "NO3";
|
||||
case auto_aim::ArmorNumber::NO4:
|
||||
return "NO4";
|
||||
case auto_aim::ArmorNumber::NO5:
|
||||
return "NO5";
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
};
|
||||
|
||||
const std::string armor_str = armorName(armor_objs[i].number);
|
||||
cv::putText(
|
||||
debug_img,
|
||||
armor_str,
|
||||
pts[1] + cv::Point2f(0, 50),
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
0.7,
|
||||
cv::Scalar(0, 200, 200),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
const std::string latency_str = fmt::format("Latency: {:.2f}ms", dbg.latency_ms);
|
||||
cv::putText(
|
||||
debug_img,
|
||||
latency_str,
|
||||
cv::Point(10, 30),
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
0.8,
|
||||
cv::Scalar(255, 255, 255),
|
||||
2
|
||||
);
|
||||
|
||||
{
|
||||
// static std::deque<std::pair<Eigen::Vector3d, double>> traj3d;
|
||||
|
||||
// double _now =
|
||||
// std::chrono::duration<double>(std::chrono::steady_clock::now().time_since_epoch()).count();
|
||||
|
||||
// traj3d.emplace_back(aim_target.pos, _now);
|
||||
|
||||
// while (!traj3d.empty() && _now - traj3d.front().second > 1.0)
|
||||
// traj3d.pop_front();
|
||||
aim_target.tf(dbg.T_camera_to_odom.inverse());
|
||||
const auto pts = aim_target.toPts(camera_info.first, camera_info.second);
|
||||
|
||||
if (!pts.empty()) {
|
||||
// if (traj3d.size() > 1) {
|
||||
// std::vector<std::pair<cv::Point, double>> img_pts;
|
||||
|
||||
// for (auto& p: traj3d) {
|
||||
// auto p3d_odom = p.first;
|
||||
|
||||
// Eigen::Vector4d p_odom(p3d_odom.x(), p3d_odom.y(), p3d_odom.z(), 1);
|
||||
|
||||
// Eigen::Vector4d p_camera = dbg.T_camera_to_odom.inverse() * p_odom;
|
||||
|
||||
// std::vector<cv::Point3f> obj;
|
||||
// obj.emplace_back(p_camera.x(), p_camera.y(), p_camera.z());
|
||||
|
||||
// std::vector<cv::Point2f> proj;
|
||||
|
||||
// cv::projectPoints(
|
||||
// obj,
|
||||
// cv::Vec3d(0, 0, 0),
|
||||
// cv::Vec3d(0, 0, 0),
|
||||
// camera_info.first,
|
||||
// camera_info.second,
|
||||
// proj
|
||||
// );
|
||||
|
||||
// if (!proj.empty()) {
|
||||
// const auto& pt = proj[0];
|
||||
|
||||
// if (std::isfinite(pt.x) && std::isfinite(pt.y)) {
|
||||
// img_pts.emplace_back(cv::Point(int(pt.x), int(pt.y)), p.second);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// if (img_pts.size() >= 2) {
|
||||
// double now = std::chrono::duration<double>(
|
||||
// std::chrono::steady_clock::now().time_since_epoch()
|
||||
// )
|
||||
// .count();
|
||||
|
||||
// const double max_age = 1.0;
|
||||
|
||||
// for (size_t i = 1; i < img_pts.size(); ++i) {
|
||||
// double age = now - img_pts[i].second;
|
||||
|
||||
// double t = std::clamp(age / max_age, 0.0, 1.0);
|
||||
|
||||
// int r = int(255 * (1.0 - t));
|
||||
// int b = int(255 * t);
|
||||
|
||||
// cv::Scalar color(b, 0, r);
|
||||
|
||||
// cv::line(
|
||||
// debug_img,
|
||||
// img_pts[i - 1].first,
|
||||
// img_pts[i].first,
|
||||
// color,
|
||||
// 2,
|
||||
// cv::LINE_AA
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
cv::Point2f center(0.f, 0.f);
|
||||
|
||||
for (auto pt: pts)
|
||||
center += pt;
|
||||
|
||||
center *= 1.0f / pts.size();
|
||||
|
||||
cv::Scalar color(255, 255, 255);
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
cv::line(debug_img, pts[i], pts[(i + 1) % 4], color, 2);
|
||||
|
||||
for (int i = 4; i < 8; i++)
|
||||
cv::line(debug_img, pts[i], pts[4 + (i + 1) % 4], color, 2);
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
cv::line(debug_img, pts[i], pts[i + 4], color, 2);
|
||||
|
||||
if (gimbal_cmd.fire_advice) {
|
||||
int cross_len = 60;
|
||||
|
||||
cv::line(
|
||||
debug_img,
|
||||
center + cv::Point2f(-cross_len, -cross_len),
|
||||
center + cv::Point2f(+cross_len, +cross_len),
|
||||
cv::Scalar(0, 0, 255),
|
||||
5
|
||||
);
|
||||
|
||||
cv::line(
|
||||
debug_img,
|
||||
center + cv::Point2f(-cross_len, +cross_len),
|
||||
center + cv::Point2f(+cross_len, -cross_len),
|
||||
cv::Scalar(0, 0, 255),
|
||||
5
|
||||
);
|
||||
}
|
||||
|
||||
const double scale = 10.0;
|
||||
|
||||
const double v_yaw = gimbal_cmd.v_yaw;
|
||||
const double v_pitch = gimbal_cmd.v_pitch;
|
||||
|
||||
const double dx = -scale * v_yaw;
|
||||
const double dy = scale * v_pitch;
|
||||
|
||||
const cv::Point2f start_pt = center;
|
||||
const cv::Point2f end_pt = start_pt + cv::Point2f(dx, dy);
|
||||
|
||||
const cv::Scalar color_x =
|
||||
dbg.detect_color ? cv::Scalar(255, 50, 50) : cv::Scalar(50, 50, 255);
|
||||
|
||||
cv::arrowedLine(debug_img, start_pt, end_pt, color_x, 4, cv::LINE_AA, 0, 0.2);
|
||||
}
|
||||
}
|
||||
std::vector<cv::Point2f> all_corners;
|
||||
|
||||
auto visualizeTargetProjection = [&](auto_aim::Target armor_target) -> auto_aim::Armors {
|
||||
auto_aim::Armors armor_data;
|
||||
armor_data.timestamp = armor_target.timestamp_;
|
||||
|
||||
if (armor_target.is_tracking) {
|
||||
Eigen::Vector3d pos = armor_target.target_state_.pos();
|
||||
if (pos.norm() > 0.5) {
|
||||
armor_data.armors.clear();
|
||||
const size_t a_n = armor_target.armor_num_;
|
||||
armor_data.armors.reserve(a_n);
|
||||
const auto now = wust_vl::common::utils::time_utils::now();
|
||||
armor_target.predictSimple(now);
|
||||
const std::vector<Eigen::Vector4d> armors_posandyaw =
|
||||
armor_target.getArmorPosAndYaw();
|
||||
for (size_t i = 0; i < a_n; ++i) {
|
||||
const Eigen::Vector3d pos = { armors_posandyaw[i][0],
|
||||
armors_posandyaw[i][1],
|
||||
armors_posandyaw[i][2] };
|
||||
Eigen::Vector3d euler;
|
||||
euler.z() = M_PI / 2.0;
|
||||
euler.y() = (armor_target.tracked_id_ == auto_aim::ArmorNumber::OUTPOST)
|
||||
? -0.2618
|
||||
: 0.2618;
|
||||
euler.x() = armors_posandyaw[i][3];
|
||||
const Eigen::Quaterniond ori =
|
||||
utils::eulerToQuat(euler, utils::EulerOrder::ZYX);
|
||||
armor_data.armors.emplace_back(auto_aim::Armor {
|
||||
.type = armor_target.type_,
|
||||
.pos = pos,
|
||||
.ori = ori,
|
||||
.is_ok = true,
|
||||
.id = (int)(i),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
return armor_data;
|
||||
};
|
||||
auto armor_data = visualizeTargetProjection(dbg.target);
|
||||
transformArmorData(armor_data, dbg.T_camera_to_odom.inverse());
|
||||
for (size_t i = 0; i < armor_data.armors.size(); ++i) {
|
||||
const auto& pts = armor_data.armors[i].toPtsDebug(camera_info.first, camera_info.second);
|
||||
const auto& pos = armor_data.armors[i].pos;
|
||||
const auto& ori = armor_data.armors[i].ori;
|
||||
const auto& id = armor_data.armors[i].id;
|
||||
cv::Scalar color;
|
||||
if (dbg.detect_color) {
|
||||
color = cv::Scalar(255, 0, 0);
|
||||
} else {
|
||||
color = cv::Scalar(0, 0, 255);
|
||||
}
|
||||
// 绘制前表面
|
||||
for (size_t j = 0; j < 4; ++j) {
|
||||
cv::line(debug_img, pts[j], pts[(j + 1) % 4], color, 2);
|
||||
}
|
||||
|
||||
// 绘制后表面
|
||||
for (size_t j = 4; j < 8; ++j) {
|
||||
cv::line(debug_img, pts[j], pts[4 + (j + 1) % 4], color, 2);
|
||||
}
|
||||
|
||||
// 绘制侧边
|
||||
for (size_t j = 0; j < 4; ++j) {
|
||||
cv::line(debug_img, pts[j], pts[j + 4], color, 2);
|
||||
}
|
||||
|
||||
all_corners.insert(all_corners.end(), pts.begin(), pts.end());
|
||||
|
||||
const Eigen::Vector3d euler = ori.toRotationMatrix().eulerAngles(2, 1, 0);
|
||||
const double yaw = euler[0];
|
||||
const double distance =
|
||||
std::sqrt(pos.x() * pos.x() + pos.y() * pos.y() + pos.z() * pos.z());
|
||||
|
||||
const std::vector<std::string> info_lines = {
|
||||
fmt::format("Dis: {:.1f}cm", distance * 100),
|
||||
fmt::format("X: {:.2f}", pos.x()),
|
||||
fmt::format("Y: {:.2f}", pos.y()),
|
||||
fmt::format("Z: {:.2f}", pos.z()),
|
||||
fmt::format("Yaw: {:.2f}", yaw * 180.0 / M_PI),
|
||||
fmt::format("ID: {:d}", id)
|
||||
};
|
||||
|
||||
const cv::Point2f text_org = pts[0] + cv::Point2f(0, 200);
|
||||
for (int k = 0; k < info_lines.size(); ++k) {
|
||||
cv::putText(
|
||||
debug_img,
|
||||
info_lines[k],
|
||||
text_org + cv::Point2f(0, -10 - 20 * k),
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
0.6,
|
||||
cv::Scalar(50, 255, 255),
|
||||
1
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (!all_corners.empty()) {
|
||||
cv::Point2f avg(0.f, 0.f);
|
||||
for (const auto& pt: all_corners)
|
||||
avg += pt;
|
||||
avg *= 1.0f / all_corners.size();
|
||||
cv::circle(debug_img, avg, 10, cv::Scalar(50, 255, 50), -1);
|
||||
|
||||
const double scale = 50.0;
|
||||
const double dy = scale * target.target_state_.vyaw();
|
||||
const cv::Point2f start_pt = avg;
|
||||
const cv::Point2f end_pt = start_pt + cv::Point2f(0, dy);
|
||||
cv::arrowedLine(
|
||||
debug_img,
|
||||
start_pt,
|
||||
end_pt,
|
||||
cv::Scalar(50, 255, 50),
|
||||
3,
|
||||
cv::LINE_AA,
|
||||
0,
|
||||
0.1
|
||||
);
|
||||
cv::putText(
|
||||
debug_img,
|
||||
fmt::format("V_yaw: {:.2f}", target.target_state_.vyaw()),
|
||||
avg + cv::Point2f(0, -20),
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
1.0,
|
||||
cv::Scalar(50, 255, 50),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
std::string state_str;
|
||||
state_str = auto_aim_fsm_to_string(dbg.fsm);
|
||||
|
||||
int baseline = 0;
|
||||
cv::Size text_size = cv::getTextSize(state_str, cv::FONT_HERSHEY_SIMPLEX, 2.5, 2, &baseline);
|
||||
|
||||
// 保证在图像内
|
||||
const int x =
|
||||
std::clamp(debug_img.cols - text_size.width - 10, 0, debug_img.cols - text_size.width);
|
||||
const int y = std::clamp(text_size.height + 10, text_size.height, debug_img.rows - 1);
|
||||
|
||||
cv::putText(
|
||||
debug_img,
|
||||
state_str,
|
||||
{ x, y },
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
2.5,
|
||||
cv::Scalar(0, 0, 255),
|
||||
2
|
||||
);
|
||||
|
||||
const std::string id_str =
|
||||
fmt::format("Attack: {}", armorNumberToString(dbg.target.tracked_id_));
|
||||
const cv::Size id_size = cv::getTextSize(id_str, cv::FONT_HERSHEY_SIMPLEX, 1.6, 2, &baseline);
|
||||
|
||||
// 保证在图像内
|
||||
const int id_x = std::clamp(debug_img.cols - 300, 0, debug_img.cols - id_size.width - 10);
|
||||
const int id_y = std::clamp(150, id_size.height, debug_img.rows - 1);
|
||||
|
||||
cv::putText(
|
||||
debug_img,
|
||||
id_str,
|
||||
{ id_x, id_y },
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
1.6,
|
||||
cv::Scalar(255, 0, 255),
|
||||
2
|
||||
);
|
||||
|
||||
if (gimbal_cmd.fire_advice) {
|
||||
std::string fire_str = "Fire!";
|
||||
cv::putText(
|
||||
debug_img,
|
||||
fire_str,
|
||||
{ debug_img.cols / 2 - 100, 200 },
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
2.85,
|
||||
cv::Scalar(0, 0, 255),
|
||||
2
|
||||
);
|
||||
}
|
||||
|
||||
const std::string gimbal_str = fmt::format(
|
||||
"Pitch: {:.2f}, Yaw: {:.2f}, Enable_pitch_diff: {:.2f}, Enable_yaw_diff: {:.2f}, V_yaw: {:.2f}, V_pitch: {:.2f}",
|
||||
gimbal_cmd.pitch,
|
||||
gimbal_cmd.yaw,
|
||||
gimbal_cmd.enable_pitch_diff,
|
||||
gimbal_cmd.enable_yaw_diff,
|
||||
gimbal_cmd.v_yaw,
|
||||
gimbal_cmd.v_pitch
|
||||
);
|
||||
cv::putText(
|
||||
debug_img,
|
||||
gimbal_str,
|
||||
{ 10, debug_img.rows - 30 },
|
||||
cv::FONT_HERSHEY_SIMPLEX,
|
||||
0.8,
|
||||
cv::Scalar(255, 255, 0),
|
||||
2
|
||||
);
|
||||
|
||||
double scale = 100.0;
|
||||
double armor_len = 0.135;
|
||||
|
||||
std::vector<Eigen::Vector2d> pts;
|
||||
pts.reserve(armors.armors.size() + armor_data.armors.size());
|
||||
|
||||
auto collect_xy = [&](auto& list, bool use_target) {
|
||||
for (auto& a: list)
|
||||
pts.emplace_back(
|
||||
use_target ? a.target_pos.x() : a.pos.x(),
|
||||
use_target ? a.target_pos.y() : a.pos.y()
|
||||
);
|
||||
};
|
||||
|
||||
collect_xy(armors.armors, true);
|
||||
collect_xy(armor_data.armors, false);
|
||||
|
||||
double max_abs_x = 1e-6, max_abs_y = 1e-6;
|
||||
for (auto& p: pts) {
|
||||
max_abs_x = std::max(max_abs_x, std::abs(p.x()));
|
||||
max_abs_y = std::max(max_abs_y, std::abs(p.y()));
|
||||
}
|
||||
|
||||
const double margin = 200.0;
|
||||
const double cx = debug_img.cols * 0.5;
|
||||
const double cy = debug_img.rows * 0.5;
|
||||
|
||||
scale = std::min({ (cx - margin) / max_abs_x,
|
||||
(debug_img.cols - cx - margin) / max_abs_x,
|
||||
(cy - margin) / max_abs_y,
|
||||
(debug_img.rows - cy - margin) / max_abs_y,
|
||||
550.0 });
|
||||
|
||||
const cv::Point2d origin(cx, cy);
|
||||
|
||||
auto to_img = [&](const Eigen::Vector3d& p) {
|
||||
return cv::Point2d(origin.x + p.x() * scale, origin.y - p.y() * scale);
|
||||
};
|
||||
|
||||
auto draw2dArmor = [&](const Eigen::Vector3d& pos, double yaw, const cv::Scalar& color) {
|
||||
cv::Point2d C = to_img(pos);
|
||||
cv::circle(debug_img, C, 3, color, -1, cv::LINE_AA);
|
||||
|
||||
double nx = -sin(yaw), ny = cos(yaw);
|
||||
double half_len_px = armor_len * 0.5 * scale;
|
||||
|
||||
cv::Point2d P1(C.x + nx * half_len_px, C.y - ny * half_len_px);
|
||||
cv::Point2d P2(C.x - nx * half_len_px, C.y + ny * half_len_px);
|
||||
cv::line(debug_img, P1, P2, color, 2, cv::LINE_AA);
|
||||
};
|
||||
|
||||
Eigen::Vector3d center(0, 0, 0);
|
||||
if (!armor_data.armors.empty()) {
|
||||
for (auto& a: armor_data.armors)
|
||||
center += a.pos;
|
||||
center /= armor_data.armors.size();
|
||||
}
|
||||
|
||||
const cv::Point2d Cc = to_img(center);
|
||||
if (!armor_data.armors.empty())
|
||||
cv::circle(debug_img, Cc, 5, cv::Scalar(255, 0, 0), -1, cv::LINE_AA);
|
||||
|
||||
for (auto& a: armors.armors)
|
||||
draw2dArmor(a.target_pos, a.yaw, cv::Scalar(0, 255, 255));
|
||||
|
||||
std::vector<cv::Point2d> data_pts;
|
||||
|
||||
for (auto& a: armor_data.armors) {
|
||||
double yaw = a.ori.toRotationMatrix().eulerAngles(2, 1, 0)[0];
|
||||
draw2dArmor(a.pos, yaw, cv::Scalar(255, 255, 255));
|
||||
data_pts.push_back(to_img(a.pos));
|
||||
}
|
||||
|
||||
for (auto& pt: data_pts)
|
||||
cv::line(debug_img, Cc, pt, cv::Scalar(180, 180, 255), 1, cv::LINE_AA);
|
||||
|
||||
for (auto& a: armors.armors)
|
||||
cv::line(debug_img, Cc, to_img(a.target_pos), cv::Scalar(0, 150, 255), 1, cv::LINE_AA);
|
||||
|
||||
cv::circle(
|
||||
debug_img,
|
||||
cv::Point2i(debug_img.cols / 2, debug_img.rows / 2),
|
||||
5,
|
||||
cv::Scalar(255, 255, 255),
|
||||
2
|
||||
);
|
||||
}
|
||||
void writeTargetLogToJson(const auto_aim::Target& armor_target) {
|
||||
nlohmann::json j;
|
||||
|
||||
// -------- armor_target 部分 --------
|
||||
nlohmann::json jt;
|
||||
jt["type"] = armor_target.type_;
|
||||
jt["tracking"] = armor_target.is_tracking;
|
||||
jt["id"] = static_cast<int>(armor_target.tracked_id_);
|
||||
jt["armors_num"] = armor_target.armor_num_;
|
||||
|
||||
const auto now = std::chrono::steady_clock::now();
|
||||
const auto age_ms_t =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(now - armor_target.timestamp_)
|
||||
.count();
|
||||
jt["timestamp_age_ms"] = age_ms_t;
|
||||
|
||||
jt["position"] = { { "x", armor_target.target_state_.cx() },
|
||||
{ "y", armor_target.target_state_.cy() },
|
||||
{ "z", armor_target.target_state_.cz() } };
|
||||
|
||||
jt["velocity"] = { { "x", armor_target.target_state_.vcx() },
|
||||
{ "y", armor_target.target_state_.vcy() },
|
||||
{ "z", armor_target.target_state_.vcz() } };
|
||||
|
||||
jt["r"] = armor_target.target_state_.r();
|
||||
jt["l"] = armor_target.target_state_.l();
|
||||
jt["h"] = armor_target.target_state_.h();
|
||||
jt["yaw"] = armor_target.target_state_.yaw();
|
||||
jt["v_yaw"] = armor_target.target_state_.vyaw();
|
||||
j["armor_target"] = jt;
|
||||
// -------- 写文件 --------
|
||||
std::ofstream file("/dev/shm/target_log.json");
|
||||
if (file.is_open()) {
|
||||
file << j.dump(2);
|
||||
}
|
||||
}
|
||||
|
||||
struct DebugLogs {
|
||||
#define DEBUG_LOG_LIST(X) \
|
||||
X(double, 100, time) \
|
||||
X(double, 100, raw_yaw) \
|
||||
X(double, 100, raw_pitch) \
|
||||
X(double, 100, yaw) \
|
||||
X(double, 100, pitch) \
|
||||
X(double, 100, armor_dis) \
|
||||
X(double, 100, armor_x) \
|
||||
X(double, 100, armor_y) \
|
||||
X(double, 100, armor_z) \
|
||||
X(double, 100, armor_yaw) \
|
||||
X(double, 100, ypd_y) \
|
||||
X(double, 100, ypd_p) \
|
||||
X(double, 100, gimbal_yaw) \
|
||||
X(double, 100, gimbal_pitch) \
|
||||
X(double, 100, target_v_yaw) \
|
||||
X(double, 100, control_v_yaw) \
|
||||
X(double, 100, control_v_pitch) \
|
||||
X(double, 100, yaw_diff) \
|
||||
X(double, 100, fire) \
|
||||
X(double, 100, rune_dis) \
|
||||
X(double, 100, fly_time) \
|
||||
X(double, 100, control_a_yaw) \
|
||||
X(double, 100, control_a_pitch)
|
||||
#define GEN_LOG(TYPE, SIZE, NAME) LogsStream<TYPE, SIZE> NAME##_log { #NAME };
|
||||
|
||||
#define X(TYPE, SIZE, NAME) GEN_LOG(TYPE, SIZE, NAME)
|
||||
DEBUG_LOG_LIST(X)
|
||||
#undef X
|
||||
|
||||
void clear() {
|
||||
#define X(TYPE, SIZE, NAME) NAME##_log.clear();
|
||||
DEBUG_LOG_LIST(X)
|
||||
#undef X
|
||||
}
|
||||
};
|
||||
|
||||
void debuglog(const AutoAimDebug& dbg_armor) {
|
||||
static bool first_log = true;
|
||||
static std::chrono::steady_clock::time_point start_time;
|
||||
|
||||
static auto_aim::Armor last_armor_;
|
||||
static double last_armor_yaw_ = 0.0;
|
||||
static double last_ypd_y_ = 0.0;
|
||||
static double last_ypd_p_ = 0.0;
|
||||
static double last_distance_ = 0.0;
|
||||
static DebugLogs log;
|
||||
static GimbalCmd last_cmd_;
|
||||
static double rune_dis = 0.0;
|
||||
if (first_log) {
|
||||
start_time = std::chrono::steady_clock::now();
|
||||
first_log = false;
|
||||
}
|
||||
const auto now = std::chrono::steady_clock::now();
|
||||
const auto_aim::Armors& armors = dbg_armor.armors;
|
||||
const double t = std::chrono::duration<double>(now - start_time).count();
|
||||
const auto_aim::Target& target = dbg_armor.target;
|
||||
writeTargetLogToJson(target);
|
||||
|
||||
double armor_yaw = 0.0, ypd_y = 0.0, ypd_p = 0.0, armor_distance = 0.0;
|
||||
|
||||
if (!armors.armors.empty()) {
|
||||
std::vector<auto_aim::Armor> ok_armors;
|
||||
for (const auto& armor: armors.armors) {
|
||||
if (armor.number != auto_aim::ArmorNumber::OUTPOST)
|
||||
ok_armors.push_back(armor);
|
||||
}
|
||||
|
||||
if (!ok_armors.empty()) {
|
||||
const auto_aim::Armor& min_armor = *std::min_element(
|
||||
ok_armors.begin(),
|
||||
ok_armors.end(),
|
||||
[](const auto_aim::Armor& a, const auto_aim::Armor& b) {
|
||||
return a.distance_to_image_center < b.distance_to_image_center;
|
||||
}
|
||||
);
|
||||
|
||||
last_armor_ = min_armor;
|
||||
|
||||
armor_distance = std::hypot(
|
||||
min_armor.target_pos.x(),
|
||||
min_armor.target_pos.y(),
|
||||
min_armor.target_pos.z()
|
||||
);
|
||||
auto orientationToYaw = [](const Eigen::Quaterniond& q) noexcept -> double {
|
||||
Eigen::Vector3d euler = utils::quatToEuler(q, utils::EulerOrder::ZYX, false);
|
||||
double yaw = euler[0];
|
||||
yaw = last_armor_yaw_ + angles::shortest_angular_distance(last_armor_yaw_, yaw);
|
||||
last_armor_yaw_ = yaw;
|
||||
return yaw;
|
||||
};
|
||||
|
||||
armor_yaw = orientationToYaw(min_armor.target_ori);
|
||||
|
||||
ypd_y = std::atan2(min_armor.target_pos.y(), min_armor.target_pos.x());
|
||||
ypd_y = last_ypd_y_ + angles::shortest_angular_distance(last_ypd_y_, ypd_y);
|
||||
last_ypd_y_ = ypd_y;
|
||||
|
||||
ypd_p = std::atan2(
|
||||
min_armor.target_pos.z(),
|
||||
std::hypot(min_armor.target_pos.x(), min_armor.target_pos.y())
|
||||
);
|
||||
last_ypd_p_ = ypd_p;
|
||||
|
||||
last_distance_ = armor_distance;
|
||||
}
|
||||
}
|
||||
GimbalCmd i_use;
|
||||
if (dbg_armor.gimbal_cmd.appear) {
|
||||
i_use = dbg_armor.gimbal_cmd;
|
||||
} else {
|
||||
i_use = last_cmd_;
|
||||
}
|
||||
last_cmd_ = i_use;
|
||||
nlohmann::json j;
|
||||
log.time_log.handleOnce(t, j);
|
||||
log.raw_yaw_log.handleOnce(i_use.target_yaw, j);
|
||||
log.raw_pitch_log.handleOnce(i_use.target_pitch, j);
|
||||
log.yaw_log.handleOnce(i_use.yaw, j);
|
||||
log.pitch_log.handleOnce(i_use.pitch, j);
|
||||
log.armor_yaw_log.handleOnce(armor_yaw * 180.0 / M_PI, j);
|
||||
log.armor_x_log.handleOnce(last_armor_.target_pos.x(), j);
|
||||
log.armor_y_log.handleOnce(last_armor_.target_pos.y(), j);
|
||||
log.armor_z_log.handleOnce(last_armor_.target_pos.z(), j);
|
||||
log.ypd_y_log.handleOnce(last_ypd_y_ * 180.0 / M_PI, j);
|
||||
log.ypd_p_log.handleOnce(last_ypd_p_ * 180.0 / M_PI, j);
|
||||
log.armor_dis_log.handleOnce(last_distance_, j);
|
||||
log.gimbal_pitch_log.handleOnce(dbg_armor.gimbal_py.first * 180.0 / M_PI, j);
|
||||
log.gimbal_yaw_log.handleOnce(dbg_armor.gimbal_py.second * 180.0 / M_PI, j);
|
||||
log.target_v_yaw_log.handleOnce(target.target_state_.vyaw(), j);
|
||||
log.control_v_pitch_log.handleOnce(i_use.v_pitch, j);
|
||||
log.control_v_yaw_log.handleOnce(i_use.v_yaw, j);
|
||||
log.fire_log.handleOnce(i_use.fire_advice, j);
|
||||
log.rune_dis_log.handleOnce(rune_dis, j);
|
||||
log.fly_time_log.handleOnce(i_use.fly_time, j);
|
||||
log.control_a_yaw_log.handleOnce(i_use.a_yaw / 180.0 * M_PI, j);
|
||||
log.control_a_pitch_log.handleOnce(i_use.a_pitch / 180.0 * M_PI, j);
|
||||
log.yaw_diff_log.handleOnce(
|
||||
std::abs(dbg_armor.gimbal_py.second * 180.0 / M_PI - dbg_armor.gimbal_cmd.yaw),
|
||||
j
|
||||
);
|
||||
|
||||
std::ofstream file("/dev/shm/cmd_log.json");
|
||||
if (file.is_open()) {
|
||||
file << j.dump();
|
||||
}
|
||||
}
|
||||
} // namespace wust_vision::auto_aim
|
||||
38
wust_vision-main/tasks/auto_aim/debug.hpp
Normal file
38
wust_vision-main/tasks/auto_aim/debug.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
#include "tasks/auto_aim/armor_tracker/target.hpp"
|
||||
#include "tasks/auto_aim/auto_aim_fsm.hpp"
|
||||
#include "tasks/auto_aim/type.hpp"
|
||||
#include "tasks/utils/debug_utils.hpp"
|
||||
namespace wust_vision::auto_aim {
|
||||
|
||||
struct AutoAimDebug {
|
||||
wust_vl::video::ImageFrame img_frame;
|
||||
auto_aim::Armors armors;
|
||||
auto_aim::Target target;
|
||||
GimbalCmd gimbal_cmd;
|
||||
auto_aim::AutoAimFsm fsm;
|
||||
AimTarget aim_target;
|
||||
double latency_ms;
|
||||
Eigen::Matrix4d T_camera_to_odom;
|
||||
std::vector<auto_aim::ArmorObject> armor_objs;
|
||||
int detect_color = 0;
|
||||
cv::Rect expanded;
|
||||
std::pair<double, double> gimbal_py;
|
||||
};
|
||||
void drawDebugArmorContent(
|
||||
cv::Mat& debug_img,
|
||||
const AutoAimDebug& dbg,
|
||||
std::pair<cv::Mat, cv::Mat> camera_info
|
||||
);
|
||||
void writeTargetLogToJson(const auto_aim::Target& armor_target);
|
||||
|
||||
inline void drawDebugOverlayShm(
|
||||
const AutoAimDebug& dbg,
|
||||
std::pair<cv::Mat, cv::Mat> camera_info,
|
||||
bool auto_fps
|
||||
) {
|
||||
static ShmWriter shm { "/debug_frame" };
|
||||
drawDebugOverlayImpl(dbg, camera_info, auto_fps, drawDebugArmorContent, shm);
|
||||
}
|
||||
void debuglog(const AutoAimDebug& dbg_armor);
|
||||
} // namespace wust_vision::auto_aim
|
||||
399
wust_vision-main/tasks/auto_aim/type.cpp
Normal file
399
wust_vision-main/tasks/auto_aim/type.cpp
Normal file
@@ -0,0 +1,399 @@
|
||||
#include "type.hpp"
|
||||
#include "tasks/utils/config.hpp"
|
||||
#include "wust_vl/common/utils/logger.hpp"
|
||||
#include <numeric>
|
||||
namespace wust_vision {
|
||||
namespace auto_aim {
|
||||
|
||||
Light::Light(const std::vector<cv::Point>& contour): cv::RotatedRect(cv::minAreaRect(contour)) {
|
||||
this->center = std::accumulate(
|
||||
contour.begin(),
|
||||
contour.end(),
|
||||
cv::Point2f(0, 0),
|
||||
[n = static_cast<float>(contour.size())](const cv::Point2f& a, const cv::Point& b) {
|
||||
return a + cv::Point2f(b.x, b.y) / n;
|
||||
}
|
||||
);
|
||||
|
||||
cv::Point2f p[4];
|
||||
this->points(p);
|
||||
|
||||
std::sort(p, p + 4, [](const cv::Point2f& a, const cv::Point2f& b) { return a.y < b.y; });
|
||||
|
||||
top = (p[0] + p[1]) / 2;
|
||||
bottom = (p[2] + p[3]) / 2;
|
||||
|
||||
length = cv::norm(top - bottom);
|
||||
width = cv::norm(p[0] - p[1]);
|
||||
|
||||
axis = (top - bottom) / cv::norm(top - bottom);
|
||||
|
||||
tilt_angle =
|
||||
std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f;
|
||||
}
|
||||
|
||||
void Light::addOffset(const cv::Point2f& offset) noexcept {
|
||||
this->center += offset;
|
||||
top += offset;
|
||||
bottom += offset;
|
||||
}
|
||||
void Light::transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept {
|
||||
top = utils::transformPoint2D(transform_matrix, top);
|
||||
bottom = utils::transformPoint2D(transform_matrix, bottom);
|
||||
length = cv::norm(top - bottom);
|
||||
cv::Point2f p[4];
|
||||
this->points(p);
|
||||
|
||||
width = cv::norm(
|
||||
utils::transformPoint2D(transform_matrix, p[0])
|
||||
- utils::transformPoint2D(transform_matrix, p[1])
|
||||
);
|
||||
const cv::Point2f p0 = center;
|
||||
const cv::Point2f p1 = center + axis;
|
||||
|
||||
const cv::Point2f p0_t = utils::transformPoint2D(transform_matrix, p0);
|
||||
|
||||
const cv::Point2f p1_t = utils::transformPoint2D(transform_matrix, p1);
|
||||
|
||||
axis = p1_t - p0_t;
|
||||
axis /= cv::norm(axis);
|
||||
|
||||
tilt_angle =
|
||||
std::atan2(std::abs(top.x - bottom.x), std::abs(top.y - bottom.y)) / CV_PI * 180.0f;
|
||||
center = utils::transformPoint2D(transform_matrix, center);
|
||||
}
|
||||
int formArmorColor(const ArmorColor& color) noexcept {
|
||||
switch (color) {
|
||||
case ArmorColor::RED:
|
||||
return 0;
|
||||
case ArmorColor::BLUE:
|
||||
return 1;
|
||||
case ArmorColor::NONE:
|
||||
return 2;
|
||||
case ArmorColor::PURPLE:
|
||||
return 3;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept {
|
||||
switch (number) {
|
||||
case ArmorNumber::SENTRY:
|
||||
return os << "SENTRY";
|
||||
case ArmorNumber::NO1:
|
||||
return os << "NO1";
|
||||
case ArmorNumber::NO2:
|
||||
return os << "NO2";
|
||||
case ArmorNumber::NO3:
|
||||
return os << "NO3";
|
||||
case ArmorNumber::NO4:
|
||||
return os << "NO4";
|
||||
case ArmorNumber::NO5:
|
||||
return os << "NO5";
|
||||
case ArmorNumber::OUTPOST:
|
||||
return os << "OUTPOST";
|
||||
case ArmorNumber::BASE:
|
||||
return os << "BASE";
|
||||
case ArmorNumber::UNKNOWN:
|
||||
return os << "UNKNOWN";
|
||||
default:
|
||||
return os << "InvalidArmorNumber(" << static_cast<int>(number) << ")";
|
||||
}
|
||||
}
|
||||
|
||||
int formArmorNumber(const ArmorNumber& number) noexcept {
|
||||
switch (number) {
|
||||
case ArmorNumber::SENTRY:
|
||||
return 0;
|
||||
case ArmorNumber::NO1:
|
||||
return 1;
|
||||
case ArmorNumber::NO2:
|
||||
return 2;
|
||||
case ArmorNumber::NO3:
|
||||
return 3;
|
||||
case ArmorNumber::NO4:
|
||||
return 4;
|
||||
case ArmorNumber::NO5:
|
||||
return 5;
|
||||
case ArmorNumber::OUTPOST:
|
||||
return 6;
|
||||
case ArmorNumber::BASE:
|
||||
return 7;
|
||||
case ArmorNumber::UNKNOWN:
|
||||
return 8;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
ArmorNumber armorNumberFromString(const std::string& s) noexcept {
|
||||
if (s == "SENTRY")
|
||||
return ArmorNumber::SENTRY;
|
||||
if (s == "BASE")
|
||||
return ArmorNumber::BASE;
|
||||
if (s == "OUTPOST")
|
||||
return ArmorNumber::OUTPOST;
|
||||
if (s == "NO1")
|
||||
return ArmorNumber::NO1;
|
||||
if (s == "NO2")
|
||||
return ArmorNumber::NO2;
|
||||
if (s == "NO3")
|
||||
return ArmorNumber::NO3;
|
||||
if (s == "NO4")
|
||||
return ArmorNumber::NO4;
|
||||
if (s == "NO5")
|
||||
return ArmorNumber::NO5;
|
||||
return ArmorNumber::UNKNOWN;
|
||||
}
|
||||
|
||||
std::string armorNumberToString(const ArmorNumber& num) noexcept {
|
||||
switch (num) {
|
||||
case ArmorNumber::SENTRY:
|
||||
return "SENTRY";
|
||||
case ArmorNumber::BASE:
|
||||
return "BASE";
|
||||
case ArmorNumber::OUTPOST:
|
||||
return "OUTPOST";
|
||||
case ArmorNumber::NO1:
|
||||
return "NO1";
|
||||
case ArmorNumber::NO2:
|
||||
return "NO2";
|
||||
case ArmorNumber::NO3:
|
||||
return "NO3";
|
||||
case ArmorNumber::NO4:
|
||||
return "NO4";
|
||||
case ArmorNumber::NO5:
|
||||
return "NO5";
|
||||
default:
|
||||
return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
namespace {
|
||||
std::unordered_map<std::string, int> armor_map;
|
||||
std::unordered_map<int, std::vector<ArmorNumber>> tracker_to_armors;
|
||||
bool loaded = false;
|
||||
|
||||
void loadArmorMapOnce() {
|
||||
if (loaded)
|
||||
return;
|
||||
|
||||
try {
|
||||
YAML::Node config = YAML::LoadFile(AUTO_AIM_CONFIG)["armor_map"];
|
||||
|
||||
for (auto it = config.begin(); it != config.end(); ++it) {
|
||||
const std::string key = it->first.as<std::string>();
|
||||
const int tracker_id = it->second.as<int>();
|
||||
|
||||
ArmorNumber armor_num = armorNumberFromString(key);
|
||||
|
||||
armor_map[key] = tracker_id;
|
||||
tracker_to_armors[tracker_id].emplace_back(armor_num);
|
||||
}
|
||||
loaded = true;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "[ArmorMap] Failed to load armor_map.yaml: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int retypetotracker(const ArmorNumber& a) noexcept {
|
||||
loadArmorMapOnce();
|
||||
|
||||
const std::string key = armorNumberToString(a);
|
||||
const auto it = armor_map.find(key);
|
||||
if (it != armor_map.end())
|
||||
return it->second;
|
||||
|
||||
std::cerr << "[retypetotracker] Invalid ArmorNumber: " << static_cast<int>(a) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept {
|
||||
return retypetotracker(a) == retypetotracker(b);
|
||||
}
|
||||
|
||||
std::string armorTypeToString(const ArmorType& type) noexcept {
|
||||
switch (type) {
|
||||
case ArmorType::SMALL:
|
||||
return "small";
|
||||
case ArmorType::LARGE:
|
||||
return "large";
|
||||
default:
|
||||
return "invalid";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<cv::Point2f> ArmorObject::toPts() const noexcept {
|
||||
if (is_ok) {
|
||||
return { lights[0].top, lights[0].bottom, lights[1].bottom, lights[1].top };
|
||||
} else {
|
||||
return { pts[0], pts[1], pts[2], pts[3] };
|
||||
}
|
||||
}
|
||||
bool ArmorObject::checkOkptsRight(double max_error) const noexcept {
|
||||
double error = 0.0;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
error += cv::norm(pts[i] - toPts()[i]);
|
||||
}
|
||||
return error < max_error;
|
||||
}
|
||||
std::array<cv::Point2f, 4> ArmorObject::sortCorners(const std::vector<cv::Point2f>& pts
|
||||
) const noexcept {
|
||||
std::array<cv::Point2f, 4> ordered;
|
||||
|
||||
// 先按 x 坐标分成左右两组
|
||||
std::vector<cv::Point2f> left, right;
|
||||
std::vector<cv::Point2f> sorted = pts;
|
||||
|
||||
std::sort(sorted.begin(), sorted.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
|
||||
return a.x < b.x;
|
||||
});
|
||||
|
||||
left.push_back(sorted[0]);
|
||||
left.push_back(sorted[1]);
|
||||
right.push_back(sorted[2]);
|
||||
right.push_back(sorted[3]);
|
||||
|
||||
// 左边两个点,按 y 分为上/下
|
||||
std::sort(left.begin(), left.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
|
||||
return a.y < b.y;
|
||||
});
|
||||
ordered[1] = left[0]; // 左上
|
||||
ordered[0] = left[1]; // 左下
|
||||
|
||||
// 右边两个点,按 y 分为上/下
|
||||
std::sort(right.begin(), right.end(), [](const cv::Point2f& a, const cv::Point2f& b) {
|
||||
return a.y < b.y;
|
||||
});
|
||||
ordered[2] = right[0]; // 右上
|
||||
ordered[3] = right[1]; // 右下
|
||||
|
||||
return ordered; // 顺序: 左下, 左上, 右上, 右下
|
||||
}
|
||||
std::vector<cv::Point2f> ArmorObject::landmarks() const noexcept {
|
||||
if constexpr (N_LANDMARKS == 4) {
|
||||
if (is_ok) {
|
||||
return { lights[0].bottom, lights[0].top, lights[1].top, lights[1].bottom };
|
||||
} else {
|
||||
const auto ordered = sortCorners(pts);
|
||||
return { ordered[0], ordered[1], ordered[2], ordered[3] };
|
||||
}
|
||||
|
||||
} else {
|
||||
if (is_ok) {
|
||||
return { lights[0].bottom, lights[0].center, lights[0].top,
|
||||
lights[1].top, lights[1].center, lights[1].bottom };
|
||||
} else {
|
||||
const auto ordered = sortCorners(pts);
|
||||
return { ordered[0], (ordered[0] + ordered[1]) / 2.0, ordered[1],
|
||||
ordered[2], (ordered[2] + ordered[3]) / 2.0, ordered[3] };
|
||||
}
|
||||
}
|
||||
}
|
||||
ArmorObject::ArmorObject(const Light& l1, const Light& l2) {
|
||||
pts.resize(4);
|
||||
if (l1.center.x < l2.center.x) {
|
||||
lights.push_back(l1);
|
||||
lights.push_back(l2);
|
||||
pts[0] = l1.top;
|
||||
pts[1] = l1.bottom;
|
||||
pts[2] = l2.bottom;
|
||||
pts[3] = l2.top;
|
||||
} else {
|
||||
lights.push_back(l2);
|
||||
lights.push_back(l1);
|
||||
pts[0] = l2.top;
|
||||
pts[1] = l2.bottom;
|
||||
pts[2] = l1.bottom;
|
||||
pts[3] = l1.top;
|
||||
}
|
||||
is_ok = true;
|
||||
}
|
||||
|
||||
std::vector<cv::Point2f> Armor::toPtsDebug(
|
||||
const cv::Mat& camera_intrinsic,
|
||||
const cv::Mat& camera_distortion
|
||||
) const noexcept {
|
||||
std::vector<cv::Point2f> image_points;
|
||||
const std::vector<cv::Point3f>* model_points;
|
||||
static std::vector<cv::Point3f> SMALL_ARMOR_3D_POINTS_BLOCK = {
|
||||
{ 0, 0.025, -0.066 }, // 左上前
|
||||
{ 0, -0.025, -0.066 }, // 左下前
|
||||
{ 0, -0.025, 0.066 }, // 右下前
|
||||
{ 0, 0.025, 0.066 }, // 右上前
|
||||
{ 0.015, 0.025, -0.066 }, // 左上后
|
||||
{ 0.015, -0.025, -0.066 }, // 左下后
|
||||
{ 0.015, -0.025, 0.066 }, // 右下后
|
||||
{ 0.015, 0.025, 0.066 }, // 右上后
|
||||
};
|
||||
|
||||
static std::vector<cv::Point3f> BIG_ARMOR_3D_POINTS_BLOCK = {
|
||||
{ 0, 0.025, -0.1125 }, { 0, -0.025, -0.1125 }, { 0, -0.025, 0.1125 },
|
||||
{ 0, 0.025, 0.1125 }, { 0.015, 0.025, -0.1125 }, { 0.015, -0.025, -0.1125 },
|
||||
{ 0.015, -0.025, 0.1125 }, { 0.015, 0.025, 0.1125 },
|
||||
};
|
||||
|
||||
if (type == "large") {
|
||||
model_points = &BIG_ARMOR_3D_POINTS_BLOCK;
|
||||
} else if (type == "small") {
|
||||
model_points = &SMALL_ARMOR_3D_POINTS_BLOCK;
|
||||
}
|
||||
const Eigen::Matrix3d tf_rot = target_ori.toRotationMatrix();
|
||||
const cv::Mat rot_mat =
|
||||
(cv::Mat_<double>(3, 3) << tf_rot(0, 0),
|
||||
tf_rot(0, 1),
|
||||
tf_rot(0, 2),
|
||||
tf_rot(1, 0),
|
||||
tf_rot(1, 1),
|
||||
tf_rot(1, 2),
|
||||
tf_rot(2, 0),
|
||||
tf_rot(2, 1),
|
||||
tf_rot(2, 2));
|
||||
|
||||
// 旋转矩阵 -> 旋转向量
|
||||
cv::Mat rvec;
|
||||
cv::Rodrigues(rot_mat, rvec);
|
||||
|
||||
// 平移向量
|
||||
const cv::Mat tvec =
|
||||
(cv::Mat_<double>(3, 1) << target_pos.x(), target_pos.y(), target_pos.z());
|
||||
|
||||
// 反投影
|
||||
cv::projectPoints(
|
||||
*model_points,
|
||||
rvec,
|
||||
tvec,
|
||||
camera_intrinsic,
|
||||
camera_distortion,
|
||||
image_points
|
||||
);
|
||||
return image_points;
|
||||
}
|
||||
|
||||
void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept {
|
||||
for (auto& armor: armors.armors) {
|
||||
transformArmorData(armor, T_camera_to_odom);
|
||||
}
|
||||
}
|
||||
void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept {
|
||||
try {
|
||||
// 位置
|
||||
const Eigen::Vector3d pos_camera = armor.pos;
|
||||
armor.target_pos = utils::transformPosition(pos_camera, T_camera_to_odom);
|
||||
|
||||
// 姿态
|
||||
const Eigen::Quaterniond
|
||||
q_camera(armor.ori.w(), armor.ori.x(), armor.ori.y(), armor.ori.z());
|
||||
const Eigen::Quaterniond q_odom =
|
||||
utils::transformOrientation(q_camera, T_camera_to_odom);
|
||||
armor.target_ori = q_odom;
|
||||
|
||||
// 提取 yaw
|
||||
const Eigen::Vector3d euler = q_odom.toRotationMatrix().eulerAngles(2, 1, 0); // ZYX
|
||||
armor.yaw = euler[0]; // yaw
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
WUST_ERROR("tf") << "Error in camera-to-odom transform: " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
172
wust_vision-main/tasks/auto_aim/type.hpp
Normal file
172
wust_vision-main/tasks/auto_aim/type.hpp
Normal file
@@ -0,0 +1,172 @@
|
||||
#pragma once
|
||||
#include "tasks/type_common.hpp"
|
||||
#include "tasks/utils/utils.hpp"
|
||||
namespace wust_vision {
|
||||
|
||||
namespace auto_aim {
|
||||
constexpr double SMALL_ARMOR_WIDTH = 133.0 / 1000.0; // 135
|
||||
constexpr double SMALL_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
|
||||
constexpr double LARGE_ARMOR_WIDTH = 225.0 / 1000.0;
|
||||
constexpr double LARGE_ARMOR_HEIGHT = 50.0 / 1000.0; // 55
|
||||
|
||||
constexpr double FIFTTEN_DEGREE_RAD = 15 * CV_PI / 180;
|
||||
struct Light: public cv::RotatedRect {
|
||||
Light() = default;
|
||||
|
||||
explicit Light(const std::vector<cv::Point>& contour);
|
||||
|
||||
void addOffset(const cv::Point2f& offset) noexcept;
|
||||
void transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept;
|
||||
|
||||
cv::Point2f top, bottom;
|
||||
int color = 0;
|
||||
cv::Point2f axis;
|
||||
double length = 0;
|
||||
double width = 0;
|
||||
float tilt_angle = 0;
|
||||
};
|
||||
|
||||
enum class ArmorColor : int { BLUE = 0, RED, NONE, PURPLE };
|
||||
|
||||
int formArmorColor(const ArmorColor& color) noexcept;
|
||||
enum class ArmorNumber : int { SENTRY = 0, NO1, NO2, NO3, NO4, NO5, OUTPOST, BASE, UNKNOWN };
|
||||
std::ostream& operator<<(std::ostream& os, const ArmorNumber& number) noexcept;
|
||||
|
||||
int formArmorNumber(const ArmorNumber& number) noexcept;
|
||||
std::string armorNumberToString(const ArmorNumber& num) noexcept;
|
||||
ArmorNumber armorNumberFromString(const std::string& s) noexcept;
|
||||
int retypetotracker(const ArmorNumber& a) noexcept;
|
||||
bool isSameTarget(const ArmorNumber& a, const ArmorNumber& b) noexcept;
|
||||
enum class ArmorsNum { NORMAL_4 = 4, OUTPOST_3 = 3 };
|
||||
|
||||
enum class ArmorType { SMALL, LARGE, INVALID };
|
||||
std::string armorTypeToString(const ArmorType& type) noexcept;
|
||||
|
||||
struct ArmorObject {
|
||||
ArmorColor color;
|
||||
ArmorNumber number;
|
||||
std::vector<cv::Point2f> pts;
|
||||
cv::Rect box;
|
||||
|
||||
cv::Mat number_img;
|
||||
|
||||
double confidence;
|
||||
|
||||
cv::Mat whole_binary_img;
|
||||
cv::Mat whole_rgb_img;
|
||||
cv::Mat whole_gray_img;
|
||||
|
||||
std::vector<Light> lights;
|
||||
cv::Point2f local_offset;
|
||||
cv::Point2f center;
|
||||
bool is_ok = false;
|
||||
ArmorType type;
|
||||
static constexpr const int N_LANDMARKS = 6;
|
||||
static constexpr const int N_LANDMARKS_2 = N_LANDMARKS * 2;
|
||||
|
||||
template<typename PointType>
|
||||
static std::vector<PointType> buildObjectPoints(const double& w, const double& h) noexcept {
|
||||
if constexpr (N_LANDMARKS == 4) {
|
||||
return {
|
||||
PointType(0, w / 2, -h / 2), // 右下
|
||||
PointType(0, w / 2, h / 2), // 右上
|
||||
PointType(0, -w / 2, h / 2), // 左上
|
||||
PointType(0, -w / 2, -h / 2) // 左下
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
PointType(0, w / 2, -h / 2), // 右下
|
||||
PointType(0, w / 2, 0.0), // 右中
|
||||
PointType(0, w / 2, h / 2), // 右上
|
||||
|
||||
PointType(0, -w / 2, h / 2), // 左上
|
||||
PointType(0, -w / 2, 0.0), // 左中
|
||||
PointType(0, -w / 2, -h / 2) // 左下
|
||||
};
|
||||
}
|
||||
}
|
||||
template<typename IDType>
|
||||
static std::vector<std::pair<IDType, IDType>> buildSymPairs() noexcept {
|
||||
if constexpr (N_LANDMARKS == 4) {
|
||||
static const std::vector<std::pair<IDType, IDType>> pairs = {
|
||||
{ 0, 3 },
|
||||
{ 1, 2 },
|
||||
// { 0, 2 },
|
||||
// { 1, 3 }
|
||||
};
|
||||
return pairs;
|
||||
} else {
|
||||
static const std::vector<std::pair<IDType, IDType>> pairs = {
|
||||
{ 0, 5 },
|
||||
{ 1, 4 },
|
||||
{ 2, 3 },
|
||||
// { 0, 3 },
|
||||
// { 2, 5 }
|
||||
|
||||
};
|
||||
return pairs;
|
||||
}
|
||||
}
|
||||
std::vector<cv::Point2f> toPts() const noexcept;
|
||||
bool checkOkptsRight(double max_error) const noexcept;
|
||||
std::array<cv::Point2f, 4> sortCorners(const std::vector<cv::Point2f>& pts) const noexcept;
|
||||
|
||||
// Landmarks start from bottom left in clockwise order
|
||||
std::vector<cv::Point2f> landmarks() const noexcept;
|
||||
void addOffset(const cv::Point2f& offset) noexcept {
|
||||
for (auto& pt: pts) {
|
||||
pt += offset;
|
||||
}
|
||||
center += offset;
|
||||
box.x += offset.x;
|
||||
box.y += offset.y;
|
||||
for (auto& l: lights) {
|
||||
l.addOffset(offset);
|
||||
}
|
||||
}
|
||||
void transform(const Eigen::Matrix<float, 3, 3>& transform_matrix) noexcept {
|
||||
for (auto& l: lights) {
|
||||
l.transform(transform_matrix);
|
||||
}
|
||||
center = utils::transformPoint2D(transform_matrix, center);
|
||||
box = utils::transformRect(transform_matrix, box);
|
||||
for (auto& pt: pts) {
|
||||
pt = utils::transformPoint2D(transform_matrix, pt);
|
||||
}
|
||||
}
|
||||
ArmorObject(const Light& l1, const Light& l2);
|
||||
ArmorObject() = default;
|
||||
};
|
||||
|
||||
struct Armor {
|
||||
public:
|
||||
ArmorNumber number;
|
||||
std::string type;
|
||||
Eigen::Vector3d pos;
|
||||
Eigen::Quaterniond ori;
|
||||
Eigen::Vector3d target_pos;
|
||||
Eigen::Quaterniond target_ori;
|
||||
float distance_to_image_center;
|
||||
float yaw;
|
||||
std::chrono::steady_clock::time_point timestamp;
|
||||
bool is_ok = false;
|
||||
bool is_none_purple = false;
|
||||
int id = -1;
|
||||
std::vector<cv::Point2f> toPtsDebug(
|
||||
const cv::Mat& camera_intrinsic,
|
||||
const cv::Mat& camera_distortion
|
||||
) const noexcept;
|
||||
};
|
||||
struct Armors {
|
||||
public:
|
||||
std::vector<Armor> armors;
|
||||
std::chrono::steady_clock::time_point timestamp;
|
||||
int id;
|
||||
Eigen::Vector3d v;
|
||||
};
|
||||
|
||||
void transformArmorData(Armors& armors, Eigen::Matrix4d T_camera_to_odom) noexcept;
|
||||
void transformArmorData(Armor& armor, const Eigen::Matrix4d& T_camera_to_odom) noexcept;
|
||||
|
||||
} // namespace auto_aim
|
||||
} // namespace wust_vision
|
||||
Reference in New Issue
Block a user