diff options
| | | |
|---|---|---|
| author | Aria <me@aria.rip> | 2023-11-11 16:56:19 +0000 |
| committer | Aria <me@aria.rip> | 2023-11-11 16:56:19 +0000 |
| commit | 9a556c208c972c3f5359b3cc7b4cee0c00ac52e0 (patch) | |
| tree | 8d26fda2afcca0ca2edd3fada64c845f1f876739 | |
| parent | a33e49c16bae9ce74ef5481a26a748997117f623 (diff) | |
feat(cli): building cost model from benchmark results
-rw-r--r-- | src/crates/cli/Cargo.toml | 2 | ||||
-rw-r--r-- | src/crates/cli/src/cost/benchmark.rs | 2 | ||||
-rw-r--r-- | src/crates/cli/src/cost/fit.rs | 44 | ||||
-rw-r--r-- | src/crates/cli/src/main.rs | 14 |
4 files changed, 24 insertions, 38 deletions
diff --git a/src/crates/cli/Cargo.toml b/src/crates/cli/Cargo.toml index d91b34c..6b3ae92 100644 --- a/src/crates/cli/Cargo.toml +++ b/src/crates/cli/Cargo.toml @@ -19,4 +19,4 @@ cargo_metadata = "0.18.1" argh = "0.1.12" glob = "0.3.1" tempfile = "3" -fitme = "1.1.0"
\ No newline at end of file +friedrich = "0.5.0" diff --git a/src/crates/cli/src/cost/benchmark.rs b/src/crates/cli/src/cost/benchmark.rs index 49e1919..a05d28a 100644 --- a/src/crates/cli/src/cost/benchmark.rs +++ b/src/crates/cli/src/cost/benchmark.rs @@ -18,7 +18,7 @@ use crate::paths::Paths; pub const ELEM_TYPE: &str = "usize"; /// String representation of the array of N values we use for benchmarking -pub const NS: &str = "[65536]"; +pub const NS: &str = "[8, 256, 1024, 65536]"; /// Run benchmarks for the given container type, returning the results. /// Panics if the given name is not in the library specs. diff --git a/src/crates/cli/src/cost/fit.rs b/src/crates/cli/src/cost/fit.rs index 7e17379..e6b3e32 100644 --- a/src/crates/cli/src/cost/fit.rs +++ b/src/crates/cli/src/cost/fit.rs @@ -3,29 +3,22 @@ use std::time::Duration; use candelabra_benchmarker::Observation; -use fitme::{expr::v1::Eq, Data, Equation, Fit, Headers, Output}; +use friedrich::{gaussian_process::GaussianProcess, kernel::Kernel, prior::Prior}; /// Fit a curve to the given set of observations. pub fn fit(results: &Vec<Observation>) -> impl Estimator { - let headers = Headers::from_iter(&["N", "T"]); - let eq = Eq::parse("m * T + c", &headers).unwrap(); - let fit = fitme::fit( - eq.clone(), - Data::new( - headers, - results - .into_iter() - .map(|(n, results)| dbg!([*n as f64, as_millis_f64(&results.avg)])), - ) - .unwrap(), - "N", - ) - .unwrap(); - - dbg!(&fit.parameter_names); - dbg!(&fit.parameter_values); - - (eq, fit) + let xs = results + .iter() + .map(|(n, _)| vec![*n as f64]) + .collect::<Vec<_>>(); + + let ys = results + .iter() + .map(|(_, results)| results.avg.as_nanos() as f64) + .collect::<Vec<_>>(); + + // TODO: Should be able to incorporate the min/max into this + GaussianProcess::default(xs, ys) } /// Can estimate a duration for a given `n`. 
@@ -34,15 +27,8 @@ pub trait Estimator { fn estimate(&self, n: usize) -> Duration; } -impl Estimator for (Eq, Fit) { +impl<K: Kernel, P: Prior> Estimator for GaussianProcess<K, P> { fn estimate(&self, n: usize) -> Duration { - todo!() + Duration::from_nanos(self.predict(&vec![n as f64]) as u64) } } - -fn as_millis_f64(d: &Duration) -> f64 { - let millis = d.as_millis() as f64; - let exp = 10.0_f64.powf(6.0); - let remainder_nanos = d.as_nanos() as f64 - (millis * exp); - millis + (remainder_nanos / exp) -} diff --git a/src/crates/cli/src/main.rs b/src/crates/cli/src/main.rs index bf9a084..b674ac1 100644 --- a/src/crates/cli/src/main.rs +++ b/src/crates/cli/src/main.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::{collections::HashSet, io}; use anyhow::{anyhow, Context, Result}; use argh::FromArgs; @@ -7,7 +7,10 @@ use project::Project; use crate::{ candidates::CandidatesStore, - cost::{fit::fit, ResultsStore}, + cost::{ + fit::{fit, Estimator}, + ResultsStore, + }, paths::Paths, }; @@ -56,13 +59,10 @@ fn main() -> Result<()> { } info!("Found all candidate types. Running benchmarks"); - for typ in seen_types - .into_iter() - .filter(|x| x == "primrose_library::EagerSortedVec") - { + for typ in seen_types.into_iter() { let results = benchmarks.get(&typ).context("Error running benchmark")?; - for (op, results) in results.by_op.iter().filter(|(k, _)| **k == "insert") { + for (op, results) in results.by_op.iter() { debug!("Fitting curve for op {}", op); fit(results); } |