diff options
-rw-r--r-- | src/crates/candelabra/src/cost/fit.rs | 105 | ||||
-rw-r--r-- | src/crates/cli/src/model.rs | 2 |
2 files changed, 35 insertions, 72 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs index ca81118..a9262b9 100644 --- a/src/crates/candelabra/src/cost/fit.rs +++ b/src/crates/candelabra/src/cost/fit.rs @@ -6,20 +6,10 @@ use na::{Dyn, OVector}; use nalgebra::{dimension, Matrix, VecStorage}; use serde::{Deserialize, Serialize}; -/// Number of coefficients to use -const COEFFICIENTS: usize = 4; - -/// Estimates durations using a 3rd-order polynomial. -/// Value i is multiplied by x^i +/// Estimates costs using a linear regression model. #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Estimator { - pub coeffs: [f64; COEFFICIENTS], - - /// (shift, scale) - pub transform_x: (f64, f64), - - /// (shift, scale) - pub transform_y: (f64, f64), + pub coeffs: [f64; 4], } /// Approximate cost of an action. @@ -30,17 +20,9 @@ impl Estimator { /// Fit from the given set of observations, using the least squared method. pub fn fit(results: &[Observation]) -> Self { // Shift down so that ys start at 0 - let (mut xs, mut ys) = Self::to_data(results); - - let transform_x = Self::normalisation_transformation(xs.iter()); - let transform_y = Self::normalisation_transformation(ys.iter()); - - xs.iter_mut() - .for_each(|e| *e = (*e + transform_x.0) / transform_x.1); - ys.iter_mut() - .for_each(|e| *e = (*e + transform_y.0) / transform_y.1); + let (xs, ys) = Self::to_data(results); - let xv = vandermonde(&xs); + let xv = Self::prepare_input(&xs); let xtx = xv.transpose() * xv.clone(); let xtxinv = xtx.try_inverse().unwrap(); let xty = xv.transpose() * ys; @@ -48,26 +30,32 @@ impl Estimator { Self { coeffs: coeffs.into(), - transform_x, - transform_y, } } - pub fn normalisation_transformation<'a, I>(_: I) -> (f64, f64) - where - I: Iterator<Item = &'a f64>, - { - // let (min, max) = is.fold((f64::MAX, f64::MIN), |(min, max), f| { - // (min.min(*f), max.max(*f)) - // }); - // let shift = -min; - // let mut scale = 10.0 / (max - min); - // if !scale.is_normal() || scale.abs() - 1e-10 < 0.0 { - // scale = 1.0; - // } - // (-min, scale) - - (0.0, 1.0) + /// Reshape an input dataset into the required input for our model. + fn prepare_input( + xs: &[f64], + ) -> Matrix< + f64, + dimension::Dyn, + dimension::Const<4>, + VecStorage<f64, dimension::Dyn, dimension::Const<4>>, + > { + let mut mat = + Matrix::<_, dimension::Dyn, dimension::Const<4>, _>::from_element(xs.len(), 1.0); + + for (row, x) in xs.iter().enumerate() { + // First column is all 1s, for constant factor + // Linear, then powers + for col in 1..=2 { + mat[(row, col)] = x.powi(col as i32); + } + // Last column is logarithm + mat[(row, 3)] = x.log2(); + } + + mat } /// Get the normalised root mean square error with respect to some data points @@ -90,16 +78,14 @@ impl Estimator { } /// Estimate the cost of a given operation at the given `n`. - pub fn estimatef(&self, mut n: f64) -> Cost { - n = (n + self.transform_x.0) * self.transform_x.1; - - let mut raw = 0.0; - self.coeffs - .iter() - .enumerate() - .for_each(|(pow, coeff)| raw += n.powi(pow as i32) * coeff); + pub fn estimatef(&self, n: f64) -> Cost { + let mut raw = self.coeffs[0]; + for pow in 1..=2 { + raw += n.powi(pow as i32) * self.coeffs[pow]; + } + raw += n.log2() * self.coeffs[3]; - ((raw / self.transform_y.1) - self.transform_y.0).max(0.0) // can't be below 0 + raw.max(0.0) // can't be below 0 } /// Convert a list of observations to the format we use internally. @@ -113,26 +99,3 @@ impl Estimator { (xs, ys) } } - -/// Calculate a Vandermode matrix with 4 columns. -/// https://en.wikipedia.org/wiki/Vandermonde_matrix -fn vandermonde( - xs: &[f64], -) -> Matrix< - f64, - dimension::Dyn, - dimension::Const<COEFFICIENTS>, - VecStorage<f64, dimension::Dyn, dimension::Const<COEFFICIENTS>>, -> { - let mut mat = - Matrix::<_, dimension::Dyn, dimension::Const<COEFFICIENTS>, _>::from_element(xs.len(), 1.0); - - for (row, x) in xs.iter().enumerate() { - // First column is all 1s so skip - for col in 1..COEFFICIENTS { - mat[(row, col)] = x.powi(col as i32); - } - } - - mat -} diff --git a/src/crates/cli/src/model.rs b/src/crates/cli/src/model.rs index bf75fe2..bb8c367 100644 --- a/src/crates/cli/src/model.rs +++ b/src/crates/cli/src/model.rs @@ -28,7 +28,7 @@ impl State { est.coeffs .iter() .enumerate() - .map(|(pow, coeff)| (format!("x^{}", pow), *coeff)) + .map(|(pow, coeff)| (format!("coeff {}", pow), *coeff)) .chain(once(("nrmse".to_string(), est.nrmse(obvs)))), ) })); |