diff options
author | Aria Shrimpton <me@aria.rip> | 2024-01-19 21:20:03 +0000 |
---|---|---|
committer | Aria Shrimpton <me@aria.rip> | 2024-01-19 21:20:03 +0000 |
commit | 501d58741dc80b41e7f456f89bc0d5b5ede740de (patch) | |
tree | bbc1c72b335e98abacca5bd219e7c2393e6d60b0 /src/crates | |
parent | 5396c0840ee458a1dbc265afa6fe9a00d4156b86 (diff) |
feat(fit): add (unused) pre-transformations to fit
Diffstat (limited to 'src/crates')
-rw-r--r-- | src/crates/candelabra/src/cost/fit.rs | 146 |
1 files changed, 140 insertions, 6 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs index c4b8850..57bee78 100644 --- a/src/crates/candelabra/src/cost/fit.rs +++ b/src/crates/candelabra/src/cost/fit.rs @@ -1,13 +1,23 @@ //! Fitting a 3rd-order polynomial to benchmark results //! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master> +use std::cmp; + use super::benchmark::Observation; use na::{Dyn, MatrixXx4, OVector}; use serde::{Deserialize, Serialize}; /// Estimates durations using a 3rd-order polynomial. +/// Value i is multiplied by x^i #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Estimator(pub [f64; 4]); +pub struct Estimator { + pub coeffs: [f64; 4], + /// (shift, scale) + pub transform_x: (f64, f64), + + /// (shift, scale) + pub transform_y: (f64, f64), +} /// Approximate cost of an action. /// This is an approximation for the number of nanoseconds it would take. @@ -16,7 +26,16 @@ pub type Cost = f64; impl Estimator { /// Fit from the given set of observations, using the least squared method. pub fn fit(results: &[Observation]) -> Self { - let (xs, ys) = Self::to_data(results); + // Shift down so that ys start at 0 + let (mut xs, mut ys) = Self::to_data(results); + + let transform_x = Self::normalisation_transformation(xs.iter()); + let transform_y = Self::normalisation_transformation(ys.iter()); + + xs.iter_mut() + .for_each(|e| *e = (*e + transform_x.0) / transform_x.1); + ys.iter_mut() + .for_each(|e| *e = (*e + transform_y.0) / transform_y.1); let xv = vandermonde(&xs); let xtx = xv.transpose() * xv.clone(); @@ -24,7 +43,38 @@ impl Estimator { let xty = xv.transpose() * ys; let coeffs = xtxinv * xty; - Self(coeffs.into()) + Self { + coeffs: coeffs.into(), + transform_x, + transform_y, + } + } + + pub fn normalisation_transformation<'a, I>(is: I) -> (f64, f64) + where + I: Iterator<Item = &'a f64>, + { + // let (min, max) = is.fold((f64::MAX, f64::MIN), |(min, max), f| { + // (min.min(*f), max.max(*f)) + // }); + // let shift = -min; + // let mut scale = 10.0 / (max - min); + // if !scale.is_normal() || scale.abs() - 1e-10 < 0.0 { + // scale = 1.0; + // } + // (-min, scale) + + (0.0, 1.0) + } + + /// Get the mean squared error with respect to some data points + pub fn mse(&self, results: &[Observation]) { + let (xs, ys) = Self::to_data(results); + xs.iter() + .zip(ys.iter()) + .map(|(x, y)| (y - self.estimatef(y)).powi(2)) + .sum() + / xs.len() } /// Estimate the cost of a given operation at the given `n`. @@ -33,9 +83,11 @@ impl Estimator { } /// Estimate the cost of a given operation at the given `n`. - pub fn estimatef(&self, n: f64) -> Cost { - let [a, b, c, d] = self.0; - a + b * n + c * n.powi(2) + d * n.powi(3) + pub fn estimatef(&self, mut n: f64) -> Cost { + let [a, b, c, d] = self.coeffs; + n = (n + self.transform_x.0) * self.transform_x.1; + let raw = a + b * n + c * n.powi(2) + d * n.powi(3); + (raw / self.transform_y.1) - self.transform_y.0 } /// Convert a list of observations to the format we use internally. @@ -66,3 +118,85 @@ fn vandermonde(xs: &[f64]) -> MatrixXx4<f64> { mat } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use crate::cost::{benchmark::Observation, BenchmarkResult, Estimator}; + + const EPSILON: f64 = 0.1e-3; + + fn create_observations(points: &[(usize, u64)]) -> Vec<Observation> { + points + .iter() + .map(|(n, p)| { + ( + *n, + BenchmarkResult { + min: Duration::from_nanos(*p), + max: Duration::from_nanos(*p), + avg: Duration::from_nanos(*p), + }, + ) + }) + .collect() + } + + fn assert_close_fit(points: &[(usize, u64)], msg: &'static str) { + let data = create_observations(points); + let estimator = Estimator::fit(&data); + let mse = estimator.mse(&data); + dbg!(&estimator, mse); + + assert!(rss.abs() < EPSILON, "{} has too high mse", msg); + } + + #[test] + fn test_fit_basic() { + assert_close_fit(&[(1, 1), (2, 1), (3, 1), (4, 1)], "constant"); + assert_close_fit(&[(1, 1), (2, 2), (3, 3), (4, 4)], "straight line"); + assert_close_fit(&[(1, 1), (2, 4), (3, 9), (4, 16)], "square"); + assert_close_fit(&[(1, 1), (2, 8), (3, 27), (4, 64)], "cubic"); + } + + #[test] + fn test_fit_basic_largenum() { + assert_close_fit( + &[ + (100_000, 100_000), + (200_000, 100_000), + (300_000, 100_000), + (400_000, 100_000), + ], + "constant", + ); + assert_close_fit( + &[ + (100_000, 100_000), + (200_000, 200_000), + (300_000, 300_000), + (400_000, 400_000), + ], + "straight line", + ); + assert_close_fit( + &[ + (100_000, 100_000), + (200_000, 400_000), + (300_000, 900_000), + (400_000, 1_600_000), + ], + "square", + ); + assert_close_fit( + &[ + (100_000, 100_000), + (200_000, 800_000), + (300_000, 2_700_000), + (400_000, 6_400_000), + ], + "cubic", + ); + } +} |