aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAria Shrimpton <me@aria.rip>2024-03-01 16:06:44 +0000
committerAria Shrimpton <me@aria.rip>2024-03-01 16:06:44 +0000
commiteaa9cfa58d3508f6fec2255c8ff3f9b9c66a9d8f (patch)
treea97ff4c82a7a858999e7dc1336ab6b9a0502c7ae
parent86c0b95f93979fbb2df46f1da30d3e34924e7d53 (diff)
make fitting code generic over # of coefficients
-rw-r--r--src/crates/candelabra/src/cost/fit.rs30
-rw-r--r--src/crates/cli/src/display.rs5
-rw-r--r--src/crates/cli/src/model.rs29
3 files changed, 40 insertions, 24 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs
index 6160842..c25974c 100644
--- a/src/crates/candelabra/src/cost/fit.rs
+++ b/src/crates/candelabra/src/cost/fit.rs
@@ -2,14 +2,19 @@
//! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master>
use super::benchmark::Observation;
-use na::{Dyn, MatrixXx4, OVector};
+use na::{Dyn, OVector};
+use nalgebra::{dimension, Matrix, VecStorage};
use serde::{Deserialize, Serialize};
+/// Number of coefficients to use
+const COEFFICIENTS: usize = 4;
+
/// Estimates durations using a 3rd-order polynomial.
/// Value i is multiplied by x^i
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Estimator {
- pub coeffs: [f64; 4],
+ pub coeffs: [f64; COEFFICIENTS],
+
/// (shift, scale)
pub transform_x: (f64, f64),
@@ -86,9 +91,14 @@ impl Estimator {
/// Estimate the cost of a given operation at the given `n`.
pub fn estimatef(&self, mut n: f64) -> Cost {
- let [a, b, c, d] = self.coeffs;
n = (n + self.transform_x.0) * self.transform_x.1;
- let raw = a + b * n + c * n.powi(2) + d * n.powi(3);
+
+ let mut raw = 0.0;
+ self.coeffs
+ .iter()
+ .enumerate()
+ .for_each(|(pow, coeff)| raw += n.powi(pow as i32) * coeff);
+
((raw / self.transform_y.1) - self.transform_y.0).max(0.0) // can't be below 0
}
@@ -106,8 +116,16 @@ impl Estimator {
/// Calculate a Vandermode matrix with 4 columns.
/// https://en.wikipedia.org/wiki/Vandermonde_matrix
-fn vandermonde(xs: &[f64]) -> MatrixXx4<f64> {
- let mut mat = MatrixXx4::repeat(xs.len(), 1.0);
+fn vandermonde(
+ xs: &[f64],
+) -> Matrix<
+ f64,
+ dimension::Dyn,
+ dimension::Const<COEFFICIENTS>,
+ VecStorage<f64, dimension::Dyn, dimension::Const<COEFFICIENTS>>,
+> {
+ let mut mat =
+ Matrix::<_, dimension::Dyn, dimension::Const<COEFFICIENTS>, _>::from_element(xs.len(), 1.0);
for (row, x) in xs.iter().enumerate() {
// First column is all 1s so skip
diff --git a/src/crates/cli/src/display.rs b/src/crates/cli/src/display.rs
index d8e2821..295e40e 100644
--- a/src/crates/cli/src/display.rs
+++ b/src/crates/cli/src/display.rs
@@ -12,7 +12,7 @@ impl State {
I1: IntoIterator<Item = (A, I2)>,
I2: IntoIterator<Item = (B, C)>,
A: Display,
- B: Display + Eq + Hash,
+ B: Display + Eq + Hash + Ord,
C: Display,
{
// Collect everything up
@@ -22,12 +22,13 @@ impl State {
.collect::<Vec<(_, _)>>();
// Get keys
- let keys = map
+ let mut keys = map
.iter()
.flat_map(|(_, vs)| vs.iter().map(|(k, _)| k))
.collect::<HashSet<_>>() // dedup
.into_iter()
.collect::<Vec<_>>(); // consistent order
+ keys.sort();
let mut builder = Builder::new();
builder.set_header(once("".to_string()).chain(keys.iter().map(|s| s.to_string())));
diff --git a/src/crates/cli/src/model.rs b/src/crates/cli/src/model.rs
index 9379cc3..bf75fe2 100644
--- a/src/crates/cli/src/model.rs
+++ b/src/crates/cli/src/model.rs
@@ -1,7 +1,8 @@
+use std::iter::once;
+
use anyhow::Result;
use argh::FromArgs;
use log::info;
-use tabled::{builder::Builder, settings::Style};
use crate::State;
@@ -20,21 +21,17 @@ impl State {
let (model, results) = self.inner.cost_info(&args.name)?;
// Table of parameters
- let mut builder = Builder::default();
- builder.set_header(["op", "x^0", "x^1", "x^2", "x^3", "nrmse"]);
- for (k, v) in model.by_op.iter() {
- let obvs = results.by_op.get(k).unwrap();
- builder.push_record(&[
- k.to_string(),
- format!("{0}", v.coeffs[0]),
- format!("{0}", v.coeffs[1]),
- format!("{0}", v.coeffs[2]),
- format!("{0}", v.coeffs[3]),
- format!("{0}", v.nrmse(obvs)),
- ]);
- }
-
- self.print_table_raw(builder.build());
+ self.print_table(model.by_op.iter().map(|(op, est)| {
+ let obvs = results.by_op.get(op).unwrap();
+ (
+ op,
+ est.coeffs
+ .iter()
+ .enumerate()
+ .map(|(pow, coeff)| (format!("x^{}", pow), *coeff))
+ .chain(once(("nrmse".to_string(), est.nrmse(obvs)))),
+ )
+ }));
Ok(())
}