aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAria Shrimpton <me@aria.rip>2024-03-07 14:49:12 +0000
committerAria Shrimpton <me@aria.rip>2024-03-07 14:49:12 +0000
commit8b9a33abc37f85a65db46f22de2ad6e8807cc24a (patch)
tree7c702b5281185399ff1efa4bc54759694632c636 /src
parent68942eb703238c7288dea129cdb96dcf213696b0 (diff)
use logarithmic factor in cost model
Diffstat (limited to 'src')
-rw-r--r--src/crates/candelabra/src/cost/fit.rs105
-rw-r--r--src/crates/cli/src/model.rs2
2 files changed, 35 insertions, 72 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs
index ca81118..a9262b9 100644
--- a/src/crates/candelabra/src/cost/fit.rs
+++ b/src/crates/candelabra/src/cost/fit.rs
@@ -6,20 +6,10 @@ use na::{Dyn, OVector};
use nalgebra::{dimension, Matrix, VecStorage};
use serde::{Deserialize, Serialize};
-/// Number of coefficients to use
-const COEFFICIENTS: usize = 4;
-
-/// Estimates durations using a 3rd-order polynomial.
-/// Value i is multiplied by x^i
+/// Estimates costs using a linear regression model.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Estimator {
- pub coeffs: [f64; COEFFICIENTS],
-
- /// (shift, scale)
- pub transform_x: (f64, f64),
-
- /// (shift, scale)
- pub transform_y: (f64, f64),
+ pub coeffs: [f64; 4],
}
/// Approximate cost of an action.
@@ -30,17 +20,9 @@ impl Estimator {
/// Fit from the given set of observations, using the least squared method.
pub fn fit(results: &[Observation]) -> Self {
// Shift down so that ys start at 0
- let (mut xs, mut ys) = Self::to_data(results);
-
- let transform_x = Self::normalisation_transformation(xs.iter());
- let transform_y = Self::normalisation_transformation(ys.iter());
-
- xs.iter_mut()
- .for_each(|e| *e = (*e + transform_x.0) / transform_x.1);
- ys.iter_mut()
- .for_each(|e| *e = (*e + transform_y.0) / transform_y.1);
+ let (xs, ys) = Self::to_data(results);
- let xv = vandermonde(&xs);
+ let xv = Self::prepare_input(&xs);
let xtx = xv.transpose() * xv.clone();
let xtxinv = xtx.try_inverse().unwrap();
let xty = xv.transpose() * ys;
@@ -48,26 +30,32 @@ impl Estimator {
Self {
coeffs: coeffs.into(),
- transform_x,
- transform_y,
}
}
- pub fn normalisation_transformation<'a, I>(_: I) -> (f64, f64)
- where
- I: Iterator<Item = &'a f64>,
- {
- // let (min, max) = is.fold((f64::MAX, f64::MIN), |(min, max), f| {
- // (min.min(*f), max.max(*f))
- // });
- // let shift = -min;
- // let mut scale = 10.0 / (max - min);
- // if !scale.is_normal() || scale.abs() - 1e-10 < 0.0 {
- // scale = 1.0;
- // }
- // (-min, scale)
-
- (0.0, 1.0)
+ /// Reshape an input dataset into the required input for our model.
+ fn prepare_input(
+ xs: &[f64],
+ ) -> Matrix<
+ f64,
+ dimension::Dyn,
+ dimension::Const<4>,
+ VecStorage<f64, dimension::Dyn, dimension::Const<4>>,
+ > {
+ let mut mat =
+ Matrix::<_, dimension::Dyn, dimension::Const<4>, _>::from_element(xs.len(), 1.0);
+
+ for (row, x) in xs.iter().enumerate() {
+ // First column is all 1s, for constant factor
+ // Linear, then powers
+ for col in 1..=2 {
+ mat[(row, col)] = x.powi(col as i32);
+ }
+ // Last column is logarithm
+ mat[(row, 3)] = x.log2();
+ }
+
+ mat
}
/// Get the normalised root mean square error with respect to some data points
@@ -90,16 +78,14 @@ impl Estimator {
}
/// Estimate the cost of a given operation at the given `n`.
- pub fn estimatef(&self, mut n: f64) -> Cost {
- n = (n + self.transform_x.0) * self.transform_x.1;
-
- let mut raw = 0.0;
- self.coeffs
- .iter()
- .enumerate()
- .for_each(|(pow, coeff)| raw += n.powi(pow as i32) * coeff);
+ pub fn estimatef(&self, n: f64) -> Cost {
+ let mut raw = self.coeffs[0];
+ for pow in 1..=2 {
+ raw += n.powi(pow as i32) * self.coeffs[pow];
+ }
+ raw += n.log2() * self.coeffs[3];
- ((raw / self.transform_y.1) - self.transform_y.0).max(0.0) // can't be below 0
+ raw.max(0.0) // can't be below 0
}
/// Convert a list of observations to the format we use internally.
@@ -113,26 +99,3 @@ impl Estimator {
(xs, ys)
}
}
-
-/// Calculate a Vandermode matrix with 4 columns.
-/// https://en.wikipedia.org/wiki/Vandermonde_matrix
-fn vandermonde(
- xs: &[f64],
-) -> Matrix<
- f64,
- dimension::Dyn,
- dimension::Const<COEFFICIENTS>,
- VecStorage<f64, dimension::Dyn, dimension::Const<COEFFICIENTS>>,
-> {
- let mut mat =
- Matrix::<_, dimension::Dyn, dimension::Const<COEFFICIENTS>, _>::from_element(xs.len(), 1.0);
-
- for (row, x) in xs.iter().enumerate() {
- // First column is all 1s so skip
- for col in 1..COEFFICIENTS {
- mat[(row, col)] = x.powi(col as i32);
- }
- }
-
- mat
-}
diff --git a/src/crates/cli/src/model.rs b/src/crates/cli/src/model.rs
index bf75fe2..bb8c367 100644
--- a/src/crates/cli/src/model.rs
+++ b/src/crates/cli/src/model.rs
@@ -28,7 +28,7 @@ impl State {
est.coeffs
.iter()
.enumerate()
- .map(|(pow, coeff)| (format!("x^{}", pow), *coeff))
+ .map(|(pow, coeff)| (format!("coeff {}", pow), *coeff))
.chain(once(("nrmse".to_string(), est.nrmse(obvs)))),
)
}));