From ed8f7c8a62862ab220f69367254084e1b3bcb8fc Mon Sep 17 00:00:00 2001 From: Aria Shrimpton Date: Wed, 20 Mar 2024 14:47:20 +0000 Subject: discard outliers before fitting --- src/crates/candelabra/src/cost/fit.rs | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) (limited to 'src/crates') diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs index a9262b9..a6dbce0 100644 --- a/src/crates/candelabra/src/cost/fit.rs +++ b/src/crates/candelabra/src/cost/fit.rs @@ -2,6 +2,7 @@ //! Based on code from al-jshen: use super::benchmark::Observation; +use log::trace; use na::{Dyn, OVector}; use nalgebra::{dimension, Matrix, VecStorage}; use serde::{Deserialize, Serialize}; @@ -19,8 +20,8 @@ pub type Cost = f64; impl Estimator { /// Fit from the given set of observations, using the least squared method. pub fn fit(results: &[Observation]) -> Self { - // Shift down so that ys start at 0 - let (xs, ys) = Self::to_data(results); + let results = Self::discard_outliers(results); + let (xs, ys) = Self::to_data(&results); let xv = Self::prepare_input(&xs); let xtx = xv.transpose() * xv.clone(); @@ -98,4 +99,28 @@ impl Estimator { (xs, ys) } + + fn discard_outliers(results: &[Observation]) -> Vec { + let mut ns = results.iter().map(|(n, _)| *n).collect::>(); + ns.dedup(); + let mut new_results = Vec::with_capacity(results.len()); + for &n in ns.iter() { + let mut n_results: Vec<_> = results.iter().filter(|(n2, _)| *n2 == n).collect(); + let old_len = n_results.len(); + n_results.sort_by(|(_, x1), (_, x2)| x1.partial_cmp(x2).unwrap()); + + let (_, lq) = n_results[n_results.len() / 4]; + let (_, uq) = n_results[(n_results.len() / 2) + (n_results.len() / 4)]; + + n_results.retain(|(_, x)| *x >= *lq && *x < *uq); + trace!( + "Discarded {} outliers for n = {n}", + old_len - n_results.len() + ); + + new_results.extend(n_results); + } + + new_results + } } -- cgit v1.2.3