diff options
author | Aria Shrimpton <me@aria.rip> | 2024-03-01 16:06:44 +0000 |
---|---|---|
committer | Aria Shrimpton <me@aria.rip> | 2024-03-01 16:06:44 +0000 |
commit | eaa9cfa58d3508f6fec2255c8ff3f9b9c66a9d8f (patch) | |
tree | a97ff4c82a7a858999e7dc1336ab6b9a0502c7ae | |
parent | 86c0b95f93979fbb2df46f1da30d3e34924e7d53 (diff) |
make fitting code generic over # of coefficients
-rw-r--r-- | src/crates/candelabra/src/cost/fit.rs | 30 | ||||
-rw-r--r-- | src/crates/cli/src/display.rs | 5 | ||||
-rw-r--r-- | src/crates/cli/src/model.rs | 29 |
3 files changed, 40 insertions, 24 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs index 6160842..c25974c 100644 --- a/src/crates/candelabra/src/cost/fit.rs +++ b/src/crates/candelabra/src/cost/fit.rs @@ -2,14 +2,19 @@ //! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master> use super::benchmark::Observation; -use na::{Dyn, MatrixXx4, OVector}; +use na::{Dyn, OVector}; +use nalgebra::{dimension, Matrix, VecStorage}; use serde::{Deserialize, Serialize}; +/// Number of coefficients to use +const COEFFICIENTS: usize = 4; + /// Estimates durations using a 3rd-order polynomial. /// Value i is multiplied by x^i #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Estimator { - pub coeffs: [f64; 4], + pub coeffs: [f64; COEFFICIENTS], + /// (shift, scale) pub transform_x: (f64, f64), @@ -86,9 +91,14 @@ impl Estimator { /// Estimate the cost of a given operation at the given `n`. pub fn estimatef(&self, mut n: f64) -> Cost { - let [a, b, c, d] = self.coeffs; n = (n + self.transform_x.0) * self.transform_x.1; - let raw = a + b * n + c * n.powi(2) + d * n.powi(3); + + let mut raw = 0.0; + self.coeffs + .iter() + .enumerate() + .for_each(|(pow, coeff)| raw += n.powi(pow as i32) * coeff); + ((raw / self.transform_y.1) - self.transform_y.0).max(0.0) // can't be below 0 } @@ -106,8 +116,16 @@ impl Estimator { /// Calculate a Vandermode matrix with 4 columns. /// https://en.wikipedia.org/wiki/Vandermonde_matrix -fn vandermonde(xs: &[f64]) -> MatrixXx4<f64> { - let mut mat = MatrixXx4::repeat(xs.len(), 1.0); +fn vandermonde( + xs: &[f64], +) -> Matrix< + f64, + dimension::Dyn, + dimension::Const<COEFFICIENTS>, + VecStorage<f64, dimension::Dyn, dimension::Const<COEFFICIENTS>>, +> { + let mut mat = + Matrix::<_, dimension::Dyn, dimension::Const<COEFFICIENTS>, _>::from_element(xs.len(), 1.0); for (row, x) in xs.iter().enumerate() { // First column is all 1s so skip diff --git a/src/crates/cli/src/display.rs b/src/crates/cli/src/display.rs index d8e2821..295e40e 100644 --- a/src/crates/cli/src/display.rs +++ b/src/crates/cli/src/display.rs @@ -12,7 +12,7 @@ impl State { I1: IntoIterator<Item = (A, I2)>, I2: IntoIterator<Item = (B, C)>, A: Display, - B: Display + Eq + Hash, + B: Display + Eq + Hash + Ord, C: Display, { // Collect everything up @@ -22,12 +22,13 @@ impl State { .collect::<Vec<(_, _)>>(); // Get keys - let keys = map + let mut keys = map .iter() .flat_map(|(_, vs)| vs.iter().map(|(k, _)| k)) .collect::<HashSet<_>>() // dedup .into_iter() .collect::<Vec<_>>(); // consistent order + keys.sort(); let mut builder = Builder::new(); builder.set_header(once("".to_string()).chain(keys.iter().map(|s| s.to_string()))); diff --git a/src/crates/cli/src/model.rs b/src/crates/cli/src/model.rs index 9379cc3..bf75fe2 100644 --- a/src/crates/cli/src/model.rs +++ b/src/crates/cli/src/model.rs @@ -1,7 +1,8 @@ +use std::iter::once; + use anyhow::Result; use argh::FromArgs; use log::info; -use tabled::{builder::Builder, settings::Style}; use crate::State; @@ -20,21 +21,17 @@ impl State { let (model, results) = self.inner.cost_info(&args.name)?; // Table of parameters - let mut builder = Builder::default(); - builder.set_header(["op", "x^0", "x^1", "x^2", "x^3", "nrmse"]); - for (k, v) in model.by_op.iter() { - let obvs = results.by_op.get(k).unwrap(); - builder.push_record(&[ - k.to_string(), - format!("{0}", v.coeffs[0]), - format!("{0}", v.coeffs[1]), - format!("{0}", v.coeffs[2]), - format!("{0}", v.coeffs[3]), - format!("{0}", v.nrmse(obvs)), - ]); - } - - self.print_table_raw(builder.build()); + self.print_table(model.by_op.iter().map(|(op, est)| { + let obvs = results.by_op.get(op).unwrap(); + ( + op, + est.coeffs + .iter() + .enumerate() + .map(|(pow, coeff)| (format!("x^{}", pow), *coeff)) + .chain(once(("nrmse".to_string(), est.nrmse(obvs)))), + ) + })); Ok(()) } |