aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAria Shrimpton <me@aria.rip>2024-03-20 14:47:20 +0000
committerAria Shrimpton <me@aria.rip>2024-03-20 14:47:20 +0000
commited8f7c8a62862ab220f69367254084e1b3bcb8fc (patch)
treebe896352ccf00c8919362500961fd291c7e191dc
parentd34cd4dd5d6a96c1dda6a7fc11a11bab94601928 (diff)
discard outliers before fitting
-rw-r--r--src/crates/candelabra/src/cost/fit.rs29
1 files changed, 27 insertions, 2 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs
index a9262b9..a6dbce0 100644
--- a/src/crates/candelabra/src/cost/fit.rs
+++ b/src/crates/candelabra/src/cost/fit.rs
@@ -2,6 +2,7 @@
//! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master>
use super::benchmark::Observation;
+use log::trace;
use na::{Dyn, OVector};
use nalgebra::{dimension, Matrix, VecStorage};
use serde::{Deserialize, Serialize};
@@ -19,8 +20,8 @@ pub type Cost = f64;
impl Estimator {
/// Fit from the given set of observations, using the least squares method.
pub fn fit(results: &[Observation]) -> Self {
- // Shift down so that ys start at 0
- let (xs, ys) = Self::to_data(results);
+ let results = Self::discard_outliers(results);
+ let (xs, ys) = Self::to_data(&results);
let xv = Self::prepare_input(&xs);
let xtx = xv.transpose() * xv.clone();
@@ -98,4 +99,28 @@ impl Estimator {
(xs, ys)
}
+
+    /// Discard, for each distinct `n`, the observations whose cost falls outside the
+    /// quartile bounds of that group, and return copies of the ones that remain.
+    ///
+    /// Observations are grouped by their first element (`n`). Each group is sorted by
+    /// its second element (the cost); the entries at index `len / 4` and
+    /// `len / 2 + len / 4` serve as lower and upper bounds, and only observations
+    /// within those bounds (inclusive) are kept. Groups are emitted in order of first
+    /// appearance of `n`, each sorted by ascending cost.
+    ///
+    /// # Panics
+    ///
+    /// Panics if two costs within a group cannot be ordered (`partial_cmp` returns `None`).
+    fn discard_outliers(results: &[Observation]) -> Vec<Observation> {
+        // `Vec::dedup` only merges *adjacent* duplicates, so an `n` that reappears later
+        // in `results` would be processed (and its observations emitted) twice. Collect
+        // each distinct `n` exactly once instead, keeping first-seen order.
+        let mut ns = Vec::new();
+        for (n, _) in results {
+            if !ns.contains(n) {
+                ns.push(*n);
+            }
+        }
+
+        let mut new_results = Vec::with_capacity(results.len());
+        for n in ns {
+            let mut n_results: Vec<_> = results.iter().filter(|(n2, _)| *n2 == n).collect();
+            let old_len = n_results.len();
+            n_results.sort_by(|(_, x1), (_, x2)| {
+                x1.partial_cmp(x2)
+                    .expect("observation costs must be comparable")
+            });
+
+            // `n` was taken from `results`, so the group is non-empty and both indices
+            // are in bounds (`len / 2 + len / 4 < len` for every `len >= 1`).
+            let (_, lq) = n_results[n_results.len() / 4];
+            let (_, uq) = n_results[(n_results.len() / 2) + (n_results.len() / 4)];
+
+            // Both bounds are inclusive: with an exclusive upper bound, a group whose
+            // bounds coincide (a single observation, or all-equal costs) would be
+            // emptied entirely, leaving no data for that `n`.
+            n_results.retain(|(_, x)| *x >= *lq && *x <= *uq);
+            trace!(
+                "Discarded {} outliers for n = {n}",
+                old_len - n_results.len()
+            );
+
+            new_results.extend(n_results);
+        }
+
+        new_results
+    }
}