diff options
author | Aria Shrimpton <me@aria.rip> | 2024-03-01 16:06:44 +0000 |
---|---|---|
committer | Aria Shrimpton <me@aria.rip> | 2024-03-01 16:06:44 +0000 |
commit | eaa9cfa58d3508f6fec2255c8ff3f9b9c66a9d8f (patch) | |
tree | a97ff4c82a7a858999e7dc1336ab6b9a0502c7ae /src/crates/candelabra | |
parent | 86c0b95f93979fbb2df46f1da30d3e34924e7d53 (diff) |
make fitting code generic over # of coefficients
Diffstat (limited to 'src/crates/candelabra')
-rw-r--r-- | src/crates/candelabra/src/cost/fit.rs | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs index 6160842..c25974c 100644 --- a/src/crates/candelabra/src/cost/fit.rs +++ b/src/crates/candelabra/src/cost/fit.rs @@ -2,14 +2,19 @@ //! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master> use super::benchmark::Observation; -use na::{Dyn, MatrixXx4, OVector}; +use na::{Dyn, OVector}; +use nalgebra::{dimension, Matrix, VecStorage}; use serde::{Deserialize, Serialize}; +/// Number of coefficients to use +const COEFFICIENTS: usize = 4; + /// Estimates durations using a 3rd-order polynomial. /// Value i is multiplied by x^i #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Estimator { - pub coeffs: [f64; 4], + pub coeffs: [f64; COEFFICIENTS], + /// (shift, scale) pub transform_x: (f64, f64), @@ -86,9 +91,14 @@ impl Estimator { /// Estimate the cost of a given operation at the given `n`. pub fn estimatef(&self, mut n: f64) -> Cost { - let [a, b, c, d] = self.coeffs; n = (n + self.transform_x.0) * self.transform_x.1; - let raw = a + b * n + c * n.powi(2) + d * n.powi(3); + + let mut raw = 0.0; + self.coeffs + .iter() + .enumerate() + .for_each(|(pow, coeff)| raw += n.powi(pow as i32) * coeff); + ((raw / self.transform_y.1) - self.transform_y.0).max(0.0) // can't be below 0 } @@ -106,8 +116,16 @@ impl Estimator { /// Calculate a Vandermode matrix with 4 columns. /// https://en.wikipedia.org/wiki/Vandermonde_matrix -fn vandermonde(xs: &[f64]) -> MatrixXx4<f64> { - let mut mat = MatrixXx4::repeat(xs.len(), 1.0); +fn vandermonde( + xs: &[f64], +) -> Matrix< + f64, + dimension::Dyn, + dimension::Const<COEFFICIENTS>, + VecStorage<f64, dimension::Dyn, dimension::Const<COEFFICIENTS>>, +> { + let mut mat = + Matrix::<_, dimension::Dyn, dimension::Const<COEFFICIENTS>, _>::from_element(xs.len(), 1.0); for (row, x) in xs.iter().enumerate() { // First column is all 1s so skip |