aboutsummaryrefslogtreecommitdiff
path: root/src/libutil/async-semaphore.hh
blob: f8db31a683dbc4ee53e209a3e49cdfdc4f1e803c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#pragma once
/// @file
/// @brief A semaphore implementation usable from within a KJ event loop.

#include <cassert>
#include <kj/async.h>
#include <kj/common.h>
#include <kj/exception.h>
#include <kj/list.h>
#include <kj/source-location.h>
#include <memory>
#include <optional>

namespace nix {

class AsyncSemaphore
{
public:
    class [[nodiscard("destroying a semaphore guard releases the semaphore immediately")]] Token
    {
        struct Release
        {
            void operator()(AsyncSemaphore * sem) const
            {
                sem->unsafeRelease();
            }
        };

        std::unique_ptr<AsyncSemaphore, Release> parent;

    public:
        Token() = default;
        Token(AsyncSemaphore & parent, kj::Badge<AsyncSemaphore>) : parent(&parent) {}

        bool valid() const
        {
            return parent != nullptr;
        }
    };

private:
    struct Waiter
    {
        kj::PromiseFulfiller<Token> & fulfiller;
        kj::ListLink<Waiter> link;
        kj::List<Waiter, &Waiter::link> & list;

        Waiter(kj::PromiseFulfiller<Token> & fulfiller, kj::List<Waiter, &Waiter::link> & list)
            : fulfiller(fulfiller)
            , list(list)
        {
            list.add(*this);
        }

        ~Waiter()
        {
            if (link.isLinked()) {
                list.remove(*this);
            }
        }
    };

    const unsigned capacity_;
    unsigned used_ = 0;
    kj::List<Waiter, &Waiter::link> waiters;

    void unsafeRelease()
    {
        used_ -= 1;
        while (used_ < capacity_ && !waiters.empty()) {
            used_ += 1;
            auto & w = waiters.front();
            w.fulfiller.fulfill(Token{*this, {}});
            waiters.remove(w);
        }
    }

public:
    explicit AsyncSemaphore(unsigned capacity) : capacity_(capacity) {}

    KJ_DISALLOW_COPY_AND_MOVE(AsyncSemaphore);

    ~AsyncSemaphore()
    {
        assert(waiters.empty() && "destroyed a semaphore with active waiters");
    }

    std::optional<Token> tryAcquire()
    {
        if (used_ < capacity_) {
            used_ += 1;
            return Token{*this, {}};
        } else {
            return {};
        }
    }

    kj::Promise<Token> acquire()
    {
        if (auto t = tryAcquire()) {
            return std::move(*t);
        } else {
            return kj::newAdaptedPromise<Token, Waiter>(waiters);
        }
    }

    unsigned capacity() const
    {
        return capacity_;
    }

    unsigned used() const
    {
        return used_;
    }

    unsigned available() const
    {
        return capacity_ - used_;
    }
};
}