diff options
author | Aria Shrimpton <me@aria.rip> | 2024-03-20 14:47:20 +0000 |
---|---|---|
committer | Aria Shrimpton <me@aria.rip> | 2024-03-20 14:47:20 +0000 |
commit | ed8f7c8a62862ab220f69367254084e1b3bcb8fc (patch) | |
tree | be896352ccf00c8919362500961fd291c7e191dc | |
parent | d34cd4dd5d6a96c1dda6a7fc11a11bab94601928 (diff) |
discard outliers before fitting
-rw-r--r-- | src/crates/candelabra/src/cost/fit.rs | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs
index a9262b9..a6dbce0 100644
--- a/src/crates/candelabra/src/cost/fit.rs
+++ b/src/crates/candelabra/src/cost/fit.rs
@@ -2,6 +2,7 @@
 //! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master>
 
 use super::benchmark::Observation;
+use log::trace;
 use na::{Dyn, OVector};
 use nalgebra::{dimension, Matrix, VecStorage};
 use serde::{Deserialize, Serialize};
@@ -19,8 +20,8 @@ pub type Cost = f64;
 impl Estimator {
     /// Fit from the given set of observations, using the least squared method.
     pub fn fit(results: &[Observation]) -> Self {
-        // Shift down so that ys start at 0
-        let (xs, ys) = Self::to_data(results);
+        let results = Self::discard_outliers(results);
+        let (xs, ys) = Self::to_data(&results);
         let xv = Self::prepare_input(&xs);
 
         let xtx = xv.transpose() * xv.clone();
@@ -98,4 +99,54 @@ impl Estimator {
 
         (xs, ys)
     }
+
+    /// Discard outlying observations before fitting.
+    ///
+    /// Observations are grouped by their `n` value. Each group is sorted by
+    /// cost, and only the observations whose cost lies in the half-open range
+    /// `[lq, uq)` are kept, where `lq` and `uq` are the costs found roughly one
+    /// and three quarters of the way through the sorted group. If `lq == uq`
+    /// that range would be empty, so the observations equal to `lq` are kept
+    /// instead of dropping the whole group.
+    ///
+    /// # Panics
+    ///
+    /// Panics if two costs cannot be ordered, i.e. `partial_cmp` returns `None`.
+    fn discard_outliers(results: &[Observation]) -> Vec<Observation> {
+        // Collect each distinct `n` once, in order of first appearance.
+        // `Vec::dedup` alone only merges *adjacent* repeats, so unsorted input
+        // would make a group be processed (and emitted) several times.
+        let mut ns = Vec::new();
+        for (n, _) in results {
+            if !ns.contains(n) {
+                ns.push(*n);
+            }
+        }
+
+        let mut new_results = Vec::with_capacity(results.len());
+        for &n in ns.iter() {
+            let mut n_results: Vec<_> = results.iter().filter(|(n2, _)| *n2 == n).collect();
+            let old_len = n_results.len();
+            n_results.sort_by(|(_, x1), (_, x2)| {
+                x1.partial_cmp(x2)
+                    .expect("observation costs must be comparable")
+            });
+
+            // `n_results` is never empty here (every `n` came from `results`),
+            // so both indices are in bounds.
+            let (_, lq) = n_results[old_len / 4];
+            let (_, uq) = n_results[(old_len / 2) + (old_len / 4)];
+
+            // `*x == *lq` only changes the outcome when `lq == uq`.
+            n_results.retain(|(_, x)| *x >= *lq && (*x < *uq || *x == *lq));
+            trace!(
+                "Discarded {} outliers for n = {n}",
+                old_len - n_results.len()
+            );
+
+            new_results.extend(n_results);
+        }
+
+        new_results
+    }
 }