Test calc_position

This commit is contained in:
Torbjørn Ludvigsen 2025-12-01 20:15:04 +01:00
parent 90400bda3f
commit 1a7d858e2e
3 changed files with 349 additions and 12 deletions

View file

@ -69,7 +69,18 @@ class SolverSummary:
SOLVERS = {
"quadratic": {"path": BUILD / "solver_quadratic", "supports_no_flex": False},
"quadratic": {
"path": BUILD / "solver_quadratic",
"supports_no_flex": False,
"label": "quadratic (reference)",
"type": "binary",
},
"calc_position": {
"path": BUILD / "solver_calc_position",
"supports_no_flex": False,
"label": "calc_position",
"type": "binary",
},
}
@ -349,6 +360,58 @@ def summarise(samples: Sequence[Sample], results: Sequence[SolverResult]) -> Sol
return SolverSummary(overall=overall, by_geo=by_geo)
def summarise_difference(samples: Sequence[Sample], ref: Sequence[SolverResult], alt: Sequence[SolverResult]) -> SolverSummary:
records_by_geo: dict[str, list[tuple[bool, float | None, float, int, float]]] = defaultdict(list)
skipped_by_geo: dict[str, int] = defaultdict(int)
skipped_total = 0
for sample, r_ref, r_alt in zip(samples, ref, alt):
geo = sample.geometry
if r_ref.unsupported or r_alt.unsupported or sample.unsupported:
skipped_total += 1
skipped_by_geo[geo] += 1
continue
if not (r_ref.ok and r_alt.ok):
records_by_geo[geo].append((False, None, 0.0, 0, 0.0))
continue
dx = r_ref.pos[0] - r_alt.pos[0]
dy = r_ref.pos[1] - r_alt.pos[1]
dz = r_ref.pos[2] - r_alt.pos[2]
err = math.sqrt(dx * dx + dy * dy + dz * dz)
ok = err <= MAX_SUCCESS_ERR_MM
cost = abs(r_ref.cost - r_alt.cost)
iters = abs(r_ref.iterations - r_alt.iterations)
dt_ms = abs(r_ref.runtime_ms - r_alt.runtime_ms)
records_by_geo[geo].append((ok, err, cost, iters, dt_ms))
def build_stats(records: list[tuple[bool, float | None, float, int, float]], skipped: int) -> Stats:
supported = len(records)
success = sum(1 for r in records if r[0])
errs = [r[1] for r in records if r[0] and r[1] is not None]
costs = [r[2] for r in records if r[0]]
iters = [r[3] for r in records if r[0]]
runtimes = [r[4] for r in records if r[0]]
return Stats(
total=supported + skipped,
supported=supported,
success_rate=(success / supported) * 100 if supported else 0.0,
mae=statistics.mean(errs) if errs else float("nan"),
med_err=statistics.median(errs) if errs else float("nan"),
std_err=statistics.pstdev(errs) if errs else float("nan"),
mean_cost=statistics.mean(costs) if costs else float("nan"),
mean_iters=statistics.mean(iters) if iters else float("nan"),
mean_ms=statistics.mean(runtimes) if runtimes else float("nan"),
skipped=skipped,
)
overall_records: list[tuple[bool, float | None, float, int, float]] = []
for recs in records_by_geo.values():
overall_records.extend(recs)
overall = build_stats(overall_records, skipped_total)
by_geo = {geo: build_stats(recs, skipped_by_geo.get(geo, 0)) for geo, recs in records_by_geo.items()}
return SolverSummary(overall=overall, by_geo=by_geo)
def print_summary(title: str, summaries: dict):
print(f"\n{title}")
for solver, stats in summaries.items():
@ -374,9 +437,21 @@ def run_suite(name: str, sample_patterns: Iterable[str], use_flex: bool = True,
combined.extend(load_samples(pat))
samples = combined
summaries = {}
for solver in SOLVERS:
results = run_solver(samples, solver, use_flex=use_flex)
summaries[solver + (" (noflex)" if not use_flex and SOLVERS[solver]["supports_no_flex"] else "")] = summarise(samples, results)
results_store: dict[str, List[SolverResult]] = {}
for key, meta in SOLVERS.items():
results = run_solver(samples, key, use_flex=use_flex)
label = meta["label"]
if not use_flex and meta["supports_no_flex"]:
label = f"{label} (noflex)"
summaries[label] = summarise(samples, results)
results_store[label] = results
# Pairwise difference between reference and calc_position if available
ref_label = SOLVERS["quadratic"]["label"]
alt_label = SOLVERS.get("calc_position", {}).get("label")
if ref_label in results_store and alt_label in results_store:
diff_stats = summarise_difference(samples, results_store[ref_label], results_store[alt_label])
summaries["difference"] = diff_stats
print_summary(name, summaries)
return summaries
@ -406,15 +481,16 @@ def main() -> int:
run_suite("Per-line bias (+/-5 mm)", [], samples=biased)
# Flex vs no-flex (Pott supports the toggle; other solvers keep flex enabled)
pott_only = ["pott"]
flex_summaries = {}
for use_flex in (True, False):
summaries = {}
for solver in SOLVERS:
if not use_flex and not SOLVERS[solver]["supports_no_flex"]:
for key, meta in SOLVERS.items():
if not use_flex and not meta["supports_no_flex"]:
continue
results = run_solver(clean_samples, solver, use_flex=use_flex)
label = solver if use_flex or solver != "pott" else f"{solver} (noflex)"
results = run_solver(clean_samples, key, use_flex=use_flex)
label = meta["label"]
if not use_flex and meta["supports_no_flex"]:
label = f"{label} (noflex)"
summaries[label] = summarise(clean_samples, results)
print_summary(f"Flex toggle (clean baseline, use_flex={use_flex})", summaries)
flex_summaries.update(summaries)
@ -423,9 +499,9 @@ def main() -> int:
perf_samples = load_samples("clean_baseline_*.jsonl") + load_samples("larger_baseline_*.jsonl")
perf_samples = perf_samples[:1000]
perf_summaries = {}
for solver in SOLVERS:
results = run_solver(perf_samples, solver, use_flex=True)
perf_summaries[solver] = summarise(perf_samples, results)
for key, meta in SOLVERS.items():
results = run_solver(perf_samples, key, use_flex=True)
perf_summaries[meta["label"]] = summarise(perf_samples, results)
print_summary("Performance microbench (<=1000 samples)", perf_summaries)
return 0

View file

@ -0,0 +1,259 @@
// CLI wrapper around klipper's winch_forward_solve for dataset evaluation.
// Matches the input format used by solver_quadratic: optional CFG line
// followed by rows of: N use_flex motor_deg... anchor_xyz...
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "../klippy/chelper/kin_winch.c"
// Stub out move_get_coord so kin_winch.c links; not used by this tool.
struct coord
move_get_coord(struct move *m, double move_time)
{
struct coord c = { 0.0, 0.0, 0.0 };
return c;
}
struct cfg_state {
size_t count;
double spool_buildup;
double spring_k_per_unit_length;
double mover_weight;
double spool_gear;
double motor_gear;
double steps_per_rev;
int use_flex;
int ignore_gravity;
int ignore_pretension;
double lambda_reg;
double tol;
int max_iters_target;
double g;
double spool_r[WINCH_MAX_ANCHORS];
double mech_adv[WINCH_MAX_ANCHORS];
double lines_per_spool[WINCH_MAX_ANCHORS];
double min_force[WINCH_MAX_ANCHORS];
double max_force[WINCH_MAX_ANCHORS];
double guy_wires[WINCH_MAX_ANCHORS];
int has_cfg;
};
static void
init_cfg(struct cfg_state *cfg)
{
memset(cfg, 0, sizeof(*cfg));
cfg->count = 0;
cfg->spool_buildup = 0.043003;
cfg->spring_k_per_unit_length = 20000.0;
cfg->mover_weight = 2.0;
cfg->spool_gear = 255.0;
cfg->motor_gear = 20.0;
cfg->steps_per_rev = 360.0;
cfg->use_flex = 1;
cfg->ignore_gravity = 0;
cfg->ignore_pretension = 0;
cfg->lambda_reg = 1e-3;
cfg->tol = 1e-3;
cfg->max_iters_target = 100;
cfg->g = 9.81;
for (int i = 0; i < WINCH_MAX_ANCHORS; ++i) {
cfg->spool_r[i] = 75.0;
cfg->mech_adv[i] = 2.0;
cfg->lines_per_spool[i] = 1.0;
cfg->min_force[i] = 3.0;
cfg->max_force[i] = 120.0;
cfg->guy_wires[i] = 0.0;
}
cfg->has_cfg = 0;
}
static void
broadcast_tail(double *arr, size_t count)
{
if (!count || count >= WINCH_MAX_ANCHORS)
return;
double last = arr[count - 1];
for (size_t i = count; i < WINCH_MAX_ANCHORS; ++i)
arr[i] = last;
}
static int
parse_cfg_line(const char *line, struct cfg_state *cfg)
{
init_cfg(cfg);
char tag[8] = {0};
double steps_per_rev = 0.;
int use_flex = 1, ignore_grav = 0, ignore_pre = 0;
int max_iters = 0;
int nread = sscanf(line,
"%7s %zu %lf %lf %lf %lf %lf %lf %d %d %d %lf %lf %d %lf",
tag, &cfg->count, &cfg->spool_buildup,
&cfg->spring_k_per_unit_length, &cfg->mover_weight,
&cfg->spool_gear, &cfg->motor_gear, &steps_per_rev,
&use_flex, &ignore_grav, &ignore_pre, &cfg->lambda_reg,
&cfg->tol, &max_iters, &cfg->g);
if (nread != 15 || strcmp(tag, "CFG") != 0)
return 0;
cfg->steps_per_rev = steps_per_rev;
cfg->use_flex = use_flex;
cfg->ignore_gravity = ignore_grav;
cfg->ignore_pretension = ignore_pre;
cfg->max_iters_target = max_iters;
size_t limit = cfg->count > WINCH_MAX_ANCHORS ? WINCH_MAX_ANCHORS : cfg->count;
const char *p = line;
for (int i = 0; i < 15; ++i) {
p = strchr(p, ' ');
if (!p)
return 0;
while (*p == ' ')
p++;
}
for (size_t i = 0; i < limit; ++i) {
cfg->spool_r[i] = strtod(p, (char **)&p);
}
for (size_t i = 0; i < limit; ++i) {
cfg->mech_adv[i] = strtod(p, (char **)&p);
}
for (size_t i = 0; i < limit; ++i) {
cfg->lines_per_spool[i] = strtod(p, (char **)&p);
}
for (size_t i = 0; i < limit; ++i) {
cfg->min_force[i] = strtod(p, (char **)&p);
}
for (size_t i = 0; i < limit; ++i) {
cfg->max_force[i] = strtod(p, (char **)&p);
}
for (size_t i = 0; i < limit; ++i) {
cfg->guy_wires[i] = strtod(p, (char **)&p);
}
broadcast_tail(cfg->spool_r, limit);
broadcast_tail(cfg->mech_adv, limit);
broadcast_tail(cfg->lines_per_spool, limit);
broadcast_tail(cfg->min_force, limit);
broadcast_tail(cfg->max_force, limit);
broadcast_tail(cfg->guy_wires, limit);
cfg->has_cfg = 1;
return 1;
}
static double
rotation_distance_for_axis(const struct cfg_state *cfg, size_t idx)
{
double r = cfg->spool_r[idx];
double ma = cfg->mech_adv[idx];
if (ma == 0.)
ma = 1.;
double gear = cfg->spool_gear / cfg->motor_gear;
return (2.0 * M_PI * r) / (gear * ma);
}
static int
solve_sample(const struct cfg_state *cfg, size_t num, int use_flex,
const double *motor_deg, const double *anchors, double *pos_out,
int *iters_out, double *cost_out, double *runtime_ms_out)
{
struct winch_flex *wf = winch_flex_alloc();
if (!wf)
return 0;
double min_force[WINCH_MAX_ANCHORS], max_force[WINCH_MAX_ANCHORS];
double guy[WINCH_MAX_ANCHORS];
int mech_adv[WINCH_MAX_ANCHORS];
for (size_t i = 0; i < num; ++i) {
min_force[i] = cfg->min_force[i];
max_force[i] = cfg->max_force[i];
guy[i] = cfg->guy_wires[i];
mech_adv[i] = (int)(cfg->mech_adv[i] + 0.5);
if (mech_adv[i] <= 0)
mech_adv[i] = 1;
}
winch_flex_configure(
wf, (int)num, anchors, cfg->spool_buildup, cfg->mover_weight,
cfg->spring_k_per_unit_length, min_force, max_force, guy,
WINCH_FORCE_ALGO_QP, cfg->ignore_gravity, cfg->ignore_pretension,
mech_adv);
winch_flex_set_enabled(wf, use_flex ? 1 : 0);
double motor_mm[WINCH_MAX_ANCHORS];
for (size_t i = 0; i < num; ++i) {
double rd = rotation_distance_for_axis(cfg, i);
winch_flex_set_spool_params(wf, (int)i, rd, cfg->steps_per_rev);
motor_mm[i] = motor_deg[i] / cfg->steps_per_rev * rd;
}
struct timespec t0, t1;
clock_gettime(CLOCK_MONOTONIC, &t0);
int iters = 0;
double cost = 0.0;
int ok = winch_forward_solve(
wf, motor_mm, NULL, 1e-3, 1e-3, 3, 30, pos_out, &cost, &iters);
clock_gettime(CLOCK_MONOTONIC, &t1);
double dt_ms = (t1.tv_sec - t0.tv_sec) * 1000.0
+ (t1.tv_nsec - t0.tv_nsec) / 1e6;
winch_flex_free(wf);
if (iters_out)
*iters_out = iters;
if (cost_out)
*cost_out = cost;
if (runtime_ms_out)
*runtime_ms_out = dt_ms;
return ok;
}
int
main(void)
{
char buf[4096];
struct cfg_state cfg;
init_cfg(&cfg);
while (fgets(buf, sizeof(buf), stdin)) {
if (buf[0] == '\0' || buf[0] == '\n')
continue;
if (strncmp(buf, "CFG", 3) == 0) {
parse_cfg_line(buf, &cfg);
continue;
}
size_t num = 0;
int use_flex = 1;
const char *p = buf;
if (sscanf(p, "%zu %d", &num, &use_flex) != 2 || num == 0
|| num > WINCH_MAX_ANCHORS) {
printf("fail 0 0 0 0 0 0\n");
continue;
}
// advance past num/use_flex
for (int i = 0; i < 2; ++i) {
p = strchr(p, ' ');
if (!p)
break;
while (*p == ' ')
p++;
}
double motor_deg[WINCH_MAX_ANCHORS];
for (size_t i = 0; i < num; ++i) {
motor_deg[i] = strtod(p, (char **)&p);
}
double anchors[WINCH_MAX_ANCHORS * 3];
for (size_t i = 0; i < num * 3; ++i) {
anchors[i] = strtod(p, (char **)&p);
}
double pos[3] = {0., 0., 0.};
int iters = 0;
double cost = 0., ms = 0.;
int ok = solve_sample(&cfg, num, use_flex, motor_deg, anchors, pos,
&iters, &cost, &ms);
printf("%s %.9g %.9g %.9g %d %.9g %.9g\n",
ok ? "ok" : "fail", pos[0], pos[1], pos[2], iters, cost, ms);
}
return 0;
}

View file

@ -13,6 +13,8 @@ compile() {
}
compile motorstepstocartesiantest_quadratic.cpp solver_quadratic
echo "Building solver_calc_position from solver_calc_position.c"
gcc -O2 -std=gnu11 -D_GNU_SOURCE -I../klippy/chelper solver_calc_position.c -lm -o build/solver_calc_position
echo "Running solver comparison..."
python3 run_tests.py