aboutsummaryrefslogtreecommitdiff
path: root/incria/benches/pascal.rs
blob: 01befe5b9472bf870654f6eb37f79d735897596e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use std::{future::Future, pin::Pin, sync::OnceLock};

use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use incria::{
    deps,
    thunk::{Thunk, ThunkMapper},
    Mapper,
};

/// Cell coordinates in Pascal's triangle as `(row, column)`. Signed so that
/// neighbours outside the triangle (row or column `-1`, ...) can be represented.
type Key = (isize, isize);
/// Value stored in each cell of the triangle.
type Value = usize;

/// Thunk that computes one cell of Pascal's triangle; see its `Thunk` impl.
#[derive(Debug, Default)]
struct PascalThunk;

/// Computes a cell of Pascal's triangle addressed as `(row, column)`.
///
/// Cells outside the triangle are 0 and the apex `(0, 0)` is 1. Every other cell is
/// the sum of its two upper neighbours, which are looked up through the shared
/// [`pascal_mapping`] rather than by direct recursion.
impl Thunk<Key, Value> for PascalThunk {
    fn compute(&self, key: Key) -> Pin<Box<dyn Future<Output = Value> + Send + '_>> {
        Box::pin(async move {
            let (row, col) = key;
            if row < 0 || col < 0 || col > row {
                // Anything outside the triangle is 0. The `col < 0` test matters: without
                // it, cells left of the triangle still evaluate to 0, but only after
                // recursively requesting a whole mirror triangle of zero-valued cells.
                0
            } else if key == (0, 0) {
                1
            } else {
                let (upper_left, upper) = ((row - 1, col - 1), (row - 1, col));

                // The two dependencies are requested one after the other.
                let upper_left = *pascal_mapping().get(&upper_left).await;
                let upper = *pascal_mapping().get(&upper).await;
                let val = upper_left + upper;

                // NOTE(review): sums above 10_000 are replaced by 1, so deep rows are not
                // true binomial coefficients. Presumably this keeps values bounded so large
                // `n` cannot overflow `usize` — confirm intent.
                if val > 10_000 {
                    1
                } else {
                    val
                }
            }
        })
    }
}

/// `ThunkMapper` from `Key` to `Value` whose cells are computed by [`PascalThunk`].
type PascalMapping = ThunkMapper<Key, Value, PascalThunk>;

/// Single process-wide mapping, shared by the thunk, the benchmarks and the test.
static PASCAL_MAPPING: OnceLock<PascalMapping> = OnceLock::new();

/// Returns the shared mapping, building it from `Default` on first access.
fn pascal_mapping() -> &'static PascalMapping {
    PASCAL_MAPPING.get_or_init(Default::default)
}

/// Criterion entry point: registers the Pascal benchmark suite for rows 5, 10 and 40.
fn criterion_benchmark(c: &mut Criterion) {
    for n in [5, 10, 40] {
        benchmarks_with_n(c, n);
    }
}

fn benchmarks_with_n(c: &mut Criterion, n: isize) {
    let target_cell = (n, n / 2);
    c.bench_function(&format!("pascal fresh (n = {})", n), |b| {
        b.to_async(tokio::runtime::Runtime::new().unwrap())
            .iter(|| do_calc(&target_cell));
    });
    c.bench_function(&format!("pascal fully invalidated (n = {})", n), |b| {
        b.to_async(tokio::runtime::Runtime::new().unwrap())
            .iter_batched(
                || do_then_invalidate(target_cell, (0, 0)),
                |_| do_calc(&target_cell),
                BatchSize::SmallInput,
            );
    });
    c.bench_function(&format!("pascal 1/2 invalidated (n = {})", n), |b| {
        b.to_async(tokio::runtime::Runtime::new().unwrap())
            .iter_batched(
                || do_then_invalidate(target_cell, (n / 2, n / 2)),
                |_| do_calc(&target_cell),
                BatchSize::SmallInput,
            );
    });
    c.bench_function(&format!("pascal 1 invalidated (n = {})", n), |b| {
        b.to_async(tokio::runtime::Runtime::new().unwrap())
            .iter_batched(
                || do_then_invalidate(target_cell, target_cell),
                |_| do_calc(&target_cell),
                BatchSize::SmallInput,
            );
    });
}

/// Requests cell `key` from the shared mapping under a freshly allocated node id,
/// discarding the resulting value.
#[inline(always)]
async fn do_calc(key: &Key) {
    let root_node = deps::next_node_id();
    let lookup = pascal_mapping().get(key);
    deps::with_node_id(root_node, lookup).await;
}

/// Evaluates cell `eval`, then marks the dependency node of cell `inval` dirty.
///
/// Used as the unmeasured setup step of the invalidation benchmarks.
///
/// # Panics
///
/// Panics if `pascal_mapping().dep_id(&inval)` yields no id. Presumably this happens
/// when `inval` was not reached while evaluating `eval` — confirm against `incria`.
#[inline(always)]
async fn do_then_invalidate(eval: Key, inval: Key) {
    // Evaluate the target cell first.
    do_calc(&eval).await;

    // Invalidate the requested cell. This is not necessarily the root of the computation:
    // callers pass the apex, a middle cell, or the target itself.
    let dep = pascal_mapping()
        .dep_id(&inval)
        .await
        .expect("invalidated cell should have been computed by the preceding evaluation");
    deps::mark_dirty(dep);
}

// Group `criterion_benchmark` under `benches` with Criterion's default configuration,
// and generate the bench binary's `main` from that group.
criterion_group! {
    name = benches;
    config = Criterion::default();
    targets = criterion_benchmark
}
criterion_main!(benches);

#[tokio::test]
async fn test_pascal_triangle_correct() {
    assert_eq!(
        *deps::with_node_id(deps::next_node_id(), pascal_mapping().get(&(0, 0))).await,
        1,
        "unique non-zero entry at (0, 0)"
    );
    assert_eq!(
        *deps::with_node_id(deps::next_node_id(), pascal_mapping().get(&(0, 1))).await,
        0,
        "numbers outside triangle are 0"
    );
    assert_eq!(
        *deps::with_node_id(deps::next_node_id(), pascal_mapping().get(&(4, 2))).await,
        6,
        "correct value for (4, 2)"
    );
    assert_eq!(
        *deps::with_node_id(deps::next_node_id(), pascal_mapping().get(&(8, 4))).await,
        70,
        "correct value for (8, 4)"
    );
}