aboutsummaryrefslogtreecommitdiff
path: root/incria/src/thunk.rs
blob: f076ba1a6bdf5dcc78678b0b9824a1a2894ba78e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//! A mapper based on a function from keys to values.
use std::{collections::HashMap, hash::Hash, pin::Pin, sync::Arc};

use tokio::sync::{Mutex, Notify, RwLock, RwLockReadGuard};

use crate::{
    deps::{self, NodeId},
    Mapper,
};

/// A [`Mapper`] that lazily computes values with a given async thunk and memoizes them.
///
/// A value is produced by calling the thunk `T` for a key and is then cached
/// together with the [`NodeId`] it was computed under, so that later lookups
/// can ask the dependency tracker (`deps`) whether the cached value is dirty.
#[derive(Debug, Default)]
pub struct ThunkMapper<K, V, T> {
    /// The thunk we use to compute values
    thunk: T,

    /// Map of all values we've calculated already, each paired with the
    /// dependency node id it was computed under.
    calculated: RwLock<HashMap<K, (V, NodeId)>>,

    /// Values we're currently calculating, keyed like `calculated`.
    /// The `Notify` will be triggered when computation is done.
    waiting: Mutex<HashMap<K, Arc<Notify>>>,
}

impl<K: Clone + Eq + Hash + Send + Sync, V: Send + Sync, T: Thunk<K, V>> ThunkMapper<K, V, T> {
    /// Create a new instance of the computing store.
    pub fn new(thunk: T) -> Self {
        Self {
            thunk,
            calculated: RwLock::default(),
            waiting: Mutex::default(),
        }
    }
}

impl<K: Clone + Eq + Hash + Send + Sync, V: Send + Sync, T: Thunk<K, V>> Mapper
    for ThunkMapper<K, V, T>
{
    type Key = K;
    type Value = V;
    type Wrapper<'a> = RwLockReadGuard<'a, V> where V: 'a;

    /// Returns a read guard over the value for `key`, computing it if needed.
    ///
    /// The lookup proceeds in a retry loop:
    ///
    /// 1. **Fast path** – if a cached value exists and its node is not dirty,
    ///    register it as a dependency and return it.
    /// 2. **Claim or wait** – under the `waiting` lock, either find an
    ///    in-flight computation for `key` and wait for it (then retry from
    ///    step 1), or atomically register ourselves as the one task that
    ///    computes `key`.
    /// 3. **Re-check** – once we own the computation, look at the cache again:
    ///    another task may have published a fresh value between step 1 and our
    ///    claim. A dirty entry is evicted and its node id reused.
    /// 4. **Compute and publish** – run the thunk, insert the result, wake all
    ///    waiters and return a guard on the freshly inserted entry.
    ///
    /// Only the task that holds the claim in `waiting` ever removes or inserts
    /// the entry for `key`, which is what makes the lookups below safe.
    ///
    /// NOTE(review): if the future returned by this method is dropped while
    /// the thunk is running, the `waiting` entry is never removed and later
    /// callers for the same key would wait indefinitely — confirm whether
    /// callers may cancel `get`.
    async fn get<'a>(&'a self, key: &Self::Key) -> Self::Wrapper<'a> {
        loop {
            // 1. Fast path: reuse a cached value that is still clean.
            {
                let finished = self.calculated.read().await;
                if let Some((_, dep)) = finished.get(key) {
                    let dep = *dep;
                    if !deps::is_dirty(dep) {
                        deps::add_dep(dep);
                        return RwLockReadGuard::map(finished, |hm| {
                            &hm.get(key).expect("key seen under this guard").0
                        });
                    }
                }
            }

            // 2. Either wait for an in-flight computation or claim the key.
            //    The check and the insert happen under one lock acquisition so
            //    that two tasks can never both decide to compute.
            let notify = {
                let mut waiting = self.waiting.lock().await;
                let in_flight = waiting.get(key).cloned();
                if let Some(barrier) = in_flight {
                    // Create the `Notified` future *before* releasing the lock:
                    // the computing task must take this lock to unregister
                    // before it calls `notify_waiters`, so the wakeup cannot
                    // slip in between and be lost.
                    let notified = barrier.notified();
                    drop(waiting);
                    notified.await;
                    // Re-evaluate from scratch; the fast path registers the
                    // dependency and handles a value that is already dirty.
                    continue;
                }
                let notify = Arc::new(Notify::new());
                waiting.insert(key.clone(), Arc::clone(&notify));
                notify
            };

            // 3. We are the only task allowed to touch this key's cache entry.
            //    Look again, since it may have changed since step 1.
            let cached_dep = self.calculated.read().await.get(key).map(|entry| entry.1);
            let dep = match cached_dep {
                Some(dep) if !deps::is_dirty(dep) => {
                    // A fresh value was published before we claimed the key:
                    // give the claim back and return it via the fast path.
                    self.waiting.lock().await.remove(key);
                    notify.notify_waiters();
                    continue;
                }
                Some(dep) => {
                    // Dirty: evict the stale value before clearing its node so
                    // no reader can observe the old value as clean.
                    self.calculated.write().await.remove(key);
                    deps::clear(dep);
                    dep
                }
                None => deps::next_node_id(),
            };

            // 4. Compute under the node id, then register it as a dependency
            //    of whoever asked for this value.
            let val = deps::with_node_id(dep, self.thunk.compute(key.clone())).await;
            deps::add_dep(dep);

            // Publish and downgrade without releasing the lock, so the entry
            // cannot be evicted before we hand out a guard pointing at it.
            let mut finished = self.calculated.write().await;
            finished.insert(key.clone(), (val, dep));
            let published = finished.downgrade();

            self.waiting.lock().await.remove(key);
            notify.notify_waiters();

            return RwLockReadGuard::map(published, |hm| {
                &hm.get(key).expect("value inserted under this guard").0
            });
        }
    }
}

/// A function from keys to values.
///
/// Should be pure, except for use of other mappings. This ensures recomputation is done when needed.
pub trait Thunk<K, V>: Send + 'static {
    /// Computes the value for `key`.
    ///
    /// Takes the key by value and returns a boxed, pinned, `Send` future that
    /// may borrow from `self` and resolves to the computed value.
    fn compute(&self, key: K) -> Pin<Box<dyn std::future::Future<Output = V> + Send + '_>>;
}