aboutsummaryrefslogtreecommitdiff
path: root/src/crates
diff options
context:
space:
mode:
authorAria Shrimpton <me@aria.rip>2024-01-19 21:20:03 +0000
committerAria Shrimpton <me@aria.rip>2024-01-19 21:20:03 +0000
commit501d58741dc80b41e7f456f89bc0d5b5ede740de (patch)
treebbc1c72b335e98abacca5bd219e7c2393e6d60b0 /src/crates
parent5396c0840ee458a1dbc265afa6fe9a00d4156b86 (diff)
feat(fit): add (unused) pre-transformations to fit
Diffstat (limited to 'src/crates')
-rw-r--r--src/crates/candelabra/src/cost/fit.rs146
1 files changed, 140 insertions, 6 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs
index c4b8850..57bee78 100644
--- a/src/crates/candelabra/src/cost/fit.rs
+++ b/src/crates/candelabra/src/cost/fit.rs
@@ -1,13 +1,23 @@
//! Fitting a 3rd-order polynomial to benchmark results
//! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master>
+use std::cmp;
+
use super::benchmark::Observation;
use na::{Dyn, MatrixXx4, OVector};
use serde::{Deserialize, Serialize};
/// Estimates durations using a 3rd-order polynomial.
+/// Value i is multiplied by x^i
#[derive(Debug, Clone, Deserialize, Serialize)]
-pub struct Estimator(pub [f64; 4]);
+pub struct Estimator {
+ pub coeffs: [f64; 4],
+ /// (shift, scale)
+ pub transform_x: (f64, f64),
+
+ /// (shift, scale)
+ pub transform_y: (f64, f64),
+}
/// Approximate cost of an action.
/// This is an approximation for the number of nanoseconds it would take.
@@ -16,7 +26,16 @@ pub type Cost = f64;
impl Estimator {
/// Fit from the given set of observations, using the least squared method.
pub fn fit(results: &[Observation]) -> Self {
- let (xs, ys) = Self::to_data(results);
+ // Shift down so that ys start at 0
+ let (mut xs, mut ys) = Self::to_data(results);
+
+ let transform_x = Self::normalisation_transformation(xs.iter());
+ let transform_y = Self::normalisation_transformation(ys.iter());
+
+ xs.iter_mut()
+ .for_each(|e| *e = (*e + transform_x.0) / transform_x.1);
+ ys.iter_mut()
+ .for_each(|e| *e = (*e + transform_y.0) / transform_y.1);
let xv = vandermonde(&xs);
let xtx = xv.transpose() * xv.clone();
@@ -24,7 +43,38 @@ impl Estimator {
let xty = xv.transpose() * ys;
let coeffs = xtxinv * xty;
- Self(coeffs.into())
+ Self {
+ coeffs: coeffs.into(),
+ transform_x,
+ transform_y,
+ }
+ }
+
+ pub fn normalisation_transformation<'a, I>(is: I) -> (f64, f64)
+ where
+ I: Iterator<Item = &'a f64>,
+ {
+ // let (min, max) = is.fold((f64::MAX, f64::MIN), |(min, max), f| {
+ // (min.min(*f), max.max(*f))
+ // });
+ // let shift = -min;
+ // let mut scale = 10.0 / (max - min);
+ // if !scale.is_normal() || scale.abs() - 1e-10 < 0.0 {
+ // scale = 1.0;
+ // }
+ // (-min, scale)
+
+ (0.0, 1.0)
+ }
+
+ /// Get the mean squared error with respect to some data points
+ pub fn mse(&self, results: &[Observation]) {
+ let (xs, ys) = Self::to_data(results);
+ xs.iter()
+ .zip(ys.iter())
+ .map(|(x, y)| (y - self.estimatef(y)).powi(2))
+ .sum()
+ / xs.len()
}
/// Estimate the cost of a given operation at the given `n`.
@@ -33,9 +83,11 @@ impl Estimator {
}
/// Estimate the cost of a given operation at the given `n`.
- pub fn estimatef(&self, n: f64) -> Cost {
- let [a, b, c, d] = self.0;
- a + b * n + c * n.powi(2) + d * n.powi(3)
+ pub fn estimatef(&self, mut n: f64) -> Cost {
+ let [a, b, c, d] = self.coeffs;
+ n = (n + self.transform_x.0) * self.transform_x.1;
+ let raw = a + b * n + c * n.powi(2) + d * n.powi(3);
+ (raw / self.transform_y.1) - self.transform_y.0
}
/// Convert a list of observations to the format we use internally.
@@ -66,3 +118,85 @@ fn vandermonde(xs: &[f64]) -> MatrixXx4<f64> {
mat
}
+
+#[cfg(test)]
+mod tests {
+ use std::time::Duration;
+
+ use crate::cost::{benchmark::Observation, BenchmarkResult, Estimator};
+
+ const EPSILON: f64 = 0.1e-3;
+
+ fn create_observations(points: &[(usize, u64)]) -> Vec<Observation> {
+ points
+ .iter()
+ .map(|(n, p)| {
+ (
+ *n,
+ BenchmarkResult {
+ min: Duration::from_nanos(*p),
+ max: Duration::from_nanos(*p),
+ avg: Duration::from_nanos(*p),
+ },
+ )
+ })
+ .collect()
+ }
+
+ fn assert_close_fit(points: &[(usize, u64)], msg: &'static str) {
+ let data = create_observations(points);
+ let estimator = Estimator::fit(&data);
+ let mse = estimator.mse(&data);
+ dbg!(&estimator, mse);
+
+ assert!(rss.abs() < EPSILON, "{} has too high mse", msg);
+ }
+
+ #[test]
+ fn test_fit_basic() {
+ assert_close_fit(&[(1, 1), (2, 1), (3, 1), (4, 1)], "constant");
+ assert_close_fit(&[(1, 1), (2, 2), (3, 3), (4, 4)], "straight line");
+ assert_close_fit(&[(1, 1), (2, 4), (3, 9), (4, 16)], "square");
+ assert_close_fit(&[(1, 1), (2, 8), (3, 27), (4, 64)], "cubic");
+ }
+
+ #[test]
+ fn test_fit_basic_largenum() {
+ assert_close_fit(
+ &[
+ (100_000, 100_000),
+ (200_000, 100_000),
+ (300_000, 100_000),
+ (400_000, 100_000),
+ ],
+ "constant",
+ );
+ assert_close_fit(
+ &[
+ (100_000, 100_000),
+ (200_000, 200_000),
+ (300_000, 300_000),
+ (400_000, 400_000),
+ ],
+ "straight line",
+ );
+ assert_close_fit(
+ &[
+ (100_000, 100_000),
+ (200_000, 400_000),
+ (300_000, 900_000),
+ (400_000, 1_600_000),
+ ],
+ "square",
+ );
+ assert_close_fit(
+ &[
+ (100_000, 100_000),
+ (200_000, 800_000),
+ (300_000, 2_700_000),
+ (400_000, 6_400_000),
+ ],
+ "cubic",
+ );
+ }
+}