aboutsummaryrefslogtreecommitdiff
path: root/src/crates/candelabra
diff options
context:
space:
mode:
authorAria Shrimpton <me@aria.rip>2024-03-01 16:06:44 +0000
committerAria Shrimpton <me@aria.rip>2024-03-01 16:06:44 +0000
commiteaa9cfa58d3508f6fec2255c8ff3f9b9c66a9d8f (patch)
treea97ff4c82a7a858999e7dc1336ab6b9a0502c7ae /src/crates/candelabra
parent86c0b95f93979fbb2df46f1da30d3e34924e7d53 (diff)
make fitting code generic over # of coefficients
Diffstat (limited to 'src/crates/candelabra')
-rw-r--r--src/crates/candelabra/src/cost/fit.rs30
1 files changed, 24 insertions, 6 deletions
diff --git a/src/crates/candelabra/src/cost/fit.rs b/src/crates/candelabra/src/cost/fit.rs
index 6160842..c25974c 100644
--- a/src/crates/candelabra/src/cost/fit.rs
+++ b/src/crates/candelabra/src/cost/fit.rs
@@ -2,14 +2,19 @@
//! Based on code from al-jshen: <https://github.com/al-jshen/compute/tree/master>
use super::benchmark::Observation;
-use na::{Dyn, MatrixXx4, OVector};
+use na::{Dyn, OVector};
+use nalgebra::{dimension, Matrix, VecStorage};
use serde::{Deserialize, Serialize};
+/// Number of coefficients to use
+const COEFFICIENTS: usize = 4;
+
/// Estimates durations using a 3rd-order polynomial.
/// Value i is multiplied by x^i
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Estimator {
- pub coeffs: [f64; 4],
+ pub coeffs: [f64; COEFFICIENTS],
+
/// (shift, scale)
pub transform_x: (f64, f64),
@@ -86,9 +91,14 @@ impl Estimator {
/// Estimate the cost of a given operation at the given `n`.
pub fn estimatef(&self, mut n: f64) -> Cost {
- let [a, b, c, d] = self.coeffs;
n = (n + self.transform_x.0) * self.transform_x.1;
- let raw = a + b * n + c * n.powi(2) + d * n.powi(3);
+
+ let mut raw = 0.0;
+ self.coeffs
+ .iter()
+ .enumerate()
+ .for_each(|(pow, coeff)| raw += n.powi(pow as i32) * coeff);
+
((raw / self.transform_y.1) - self.transform_y.0).max(0.0) // can't be below 0
}
@@ -106,8 +116,16 @@ impl Estimator {
/// Calculate a Vandermode matrix with 4 columns.
/// https://en.wikipedia.org/wiki/Vandermonde_matrix
-fn vandermonde(xs: &[f64]) -> MatrixXx4<f64> {
- let mut mat = MatrixXx4::repeat(xs.len(), 1.0);
+fn vandermonde(
+ xs: &[f64],
+) -> Matrix<
+ f64,
+ dimension::Dyn,
+ dimension::Const<COEFFICIENTS>,
+ VecStorage<f64, dimension::Dyn, dimension::Const<COEFFICIENTS>>,
+> {
+ let mut mat =
+ Matrix::<_, dimension::Dyn, dimension::Const<COEFFICIENTS>, _>::from_element(xs.len(), 1.0);
for (row, x) in xs.iter().enumerate() {
// First column is all 1s so skip