aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAria <me@aria.rip>2023-11-11 16:56:19 +0000
committerAria <me@aria.rip>2023-11-11 16:56:19 +0000
commit9a556c208c972c3f5359b3cc7b4cee0c00ac52e0 (patch)
tree8d26fda2afcca0ca2edd3fada64c845f1f876739
parenta33e49c16bae9ce74ef5481a26a748997117f623 (diff)
feat(cli): building cost model from benchmark results
-rw-r--r--src/crates/cli/Cargo.toml2
-rw-r--r--src/crates/cli/src/cost/benchmark.rs2
-rw-r--r--src/crates/cli/src/cost/fit.rs44
-rw-r--r--src/crates/cli/src/main.rs14
4 files changed, 24 insertions, 38 deletions
diff --git a/src/crates/cli/Cargo.toml b/src/crates/cli/Cargo.toml
index d91b34c..6b3ae92 100644
--- a/src/crates/cli/Cargo.toml
+++ b/src/crates/cli/Cargo.toml
@@ -19,4 +19,4 @@ cargo_metadata = "0.18.1"
argh = "0.1.12"
glob = "0.3.1"
tempfile = "3"
-fitme = "1.1.0" \ No newline at end of file
+friedrich = "0.5.0"
diff --git a/src/crates/cli/src/cost/benchmark.rs b/src/crates/cli/src/cost/benchmark.rs
index 49e1919..a05d28a 100644
--- a/src/crates/cli/src/cost/benchmark.rs
+++ b/src/crates/cli/src/cost/benchmark.rs
@@ -18,7 +18,7 @@ use crate::paths::Paths;
pub const ELEM_TYPE: &str = "usize";
/// String representation of the array of N values we use for benchmarking
-pub const NS: &str = "[65536]";
+pub const NS: &str = "[8, 256, 1024, 65536]";
/// Run benchmarks for the given container type, returning the results.
/// Panics if the given name is not in the library specs.
diff --git a/src/crates/cli/src/cost/fit.rs b/src/crates/cli/src/cost/fit.rs
index 7e17379..e6b3e32 100644
--- a/src/crates/cli/src/cost/fit.rs
+++ b/src/crates/cli/src/cost/fit.rs
@@ -3,29 +3,22 @@
use std::time::Duration;
use candelabra_benchmarker::Observation;
-use fitme::{expr::v1::Eq, Data, Equation, Fit, Headers, Output};
+use friedrich::{gaussian_process::GaussianProcess, kernel::Kernel, prior::Prior};
/// Fit a curve to the given set of observations.
pub fn fit(results: &Vec<Observation>) -> impl Estimator {
- let headers = Headers::from_iter(&["N", "T"]);
- let eq = Eq::parse("m * T + c", &headers).unwrap();
- let fit = fitme::fit(
- eq.clone(),
- Data::new(
- headers,
- results
- .into_iter()
- .map(|(n, results)| dbg!([*n as f64, as_millis_f64(&results.avg)])),
- )
- .unwrap(),
- "N",
- )
- .unwrap();
-
- dbg!(&fit.parameter_names);
- dbg!(&fit.parameter_values);
-
- (eq, fit)
+ let xs = results
+ .iter()
+ .map(|(n, _)| vec![*n as f64])
+ .collect::<Vec<_>>();
+
+ let ys = results
+ .iter()
+ .map(|(_, results)| results.avg.as_nanos() as f64)
+ .collect::<Vec<_>>();
+
+ // TODO: Should be able to incorporate the min/max into this
+ GaussianProcess::default(xs, ys)
}
/// Can estimate a duration for a given `n`.
@@ -34,15 +27,8 @@ pub trait Estimator {
fn estimate(&self, n: usize) -> Duration;
}
-impl Estimator for (Eq, Fit) {
+impl<K: Kernel, P: Prior> Estimator for GaussianProcess<K, P> {
fn estimate(&self, n: usize) -> Duration {
- todo!()
+ Duration::from_nanos(self.predict(&vec![n as f64]) as u64)
}
}
-
-fn as_millis_f64(d: &Duration) -> f64 {
- let millis = d.as_millis() as f64;
- let exp = 10.0_f64.powf(6.0);
- let remainder_nanos = d.as_nanos() as f64 - (millis * exp);
- millis + (remainder_nanos / exp)
-}
diff --git a/src/crates/cli/src/main.rs b/src/crates/cli/src/main.rs
index bf9a084..b674ac1 100644
--- a/src/crates/cli/src/main.rs
+++ b/src/crates/cli/src/main.rs
@@ -1,4 +1,4 @@
-use std::collections::HashSet;
+use std::{collections::HashSet, io};
use anyhow::{anyhow, Context, Result};
use argh::FromArgs;
@@ -7,7 +7,10 @@ use project::Project;
use crate::{
candidates::CandidatesStore,
- cost::{fit::fit, ResultsStore},
+ cost::{
+ fit::{fit, Estimator},
+ ResultsStore,
+ },
paths::Paths,
};
@@ -56,13 +59,10 @@ fn main() -> Result<()> {
}
info!("Found all candidate types. Running benchmarks");
- for typ in seen_types
- .into_iter()
- .filter(|x| x == "primrose_library::EagerSortedVec")
- {
+ for typ in seen_types.into_iter() {
let results = benchmarks.get(&typ).context("Error running benchmark")?;
- for (op, results) in results.by_op.iter().filter(|(k, _)| **k == "insert") {
+ for (op, results) in results.by_op.iter() {
debug!("Fitting curve for op {}", op);
fit(results);
}