aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoreldritch horrors <pennae@lix.systems>2024-09-26 17:23:52 +0000
committerGerrit Code Review <gerrit@localhost>2024-09-26 17:23:52 +0000
commit619a93bd54386714241fc0edfba2aa9b4c40355b (patch)
tree5ef0f99ef1cee0917bf9576896b5142a04a333a7
parent5dc7671d81845cc9832752209daa591b401ae0c9 (diff)
parent531d040e8c2d211408c84ae23421aaa45b3b5a7a (diff)
Merge "libutil: add async collection mechanism" into main
-rw-r--r--src/libutil/async-collect.hh101
-rw-r--r--src/libutil/meson.build1
-rw-r--r--tests/unit/libutil/async-collect.cc104
-rw-r--r--tests/unit/meson.build1
4 files changed, 207 insertions, 0 deletions
diff --git a/src/libutil/async-collect.hh b/src/libutil/async-collect.hh
new file mode 100644
index 000000000..9e0b8bad9
--- /dev/null
+++ b/src/libutil/async-collect.hh
@@ -0,0 +1,101 @@
+#pragma once
+/// @file
+
+#include <kj/async.h>
+#include <kj/common.h>
+#include <kj/vector.h>
+#include <list>
+#include <optional>
+#include <type_traits>
+
+namespace nix {
+
+template<typename K, typename V>
+class AsyncCollect
+{
+public:
+ using Item = std::conditional_t<std::is_void_v<V>, K, std::pair<K, V>>;
+
+private:
+ kj::ForkedPromise<void> allPromises;
+ std::list<Item> results;
+ size_t remaining;
+
+ kj::ForkedPromise<void> signal;
+ kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> notify;
+
+ void oneDone(Item item)
+ {
+ results.emplace_back(std::move(item));
+ remaining -= 1;
+ KJ_IF_MAYBE (n, notify) {
+ (*n)->fulfill();
+ notify = nullptr;
+ }
+ }
+
+ kj::Promise<void> collectorFor(K key, kj::Promise<V> promise)
+ {
+ if constexpr (std::is_void_v<V>) {
+ return promise.then([this, key{std::move(key)}] { oneDone(std::move(key)); });
+ } else {
+ return promise.then([this, key{std::move(key)}](V v) {
+ oneDone(Item{std::move(key), std::move(v)});
+ });
+ }
+ }
+
+ kj::ForkedPromise<void> waitForAll(kj::Array<std::pair<K, kj::Promise<V>>> & promises)
+ {
+ kj::Vector<kj::Promise<void>> wrappers;
+ for (auto & [key, promise] : promises) {
+ wrappers.add(collectorFor(std::move(key), std::move(promise)));
+ }
+
+ return kj::joinPromisesFailFast(wrappers.releaseAsArray()).fork();
+ }
+
+public:
+ AsyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> && promises)
+ : allPromises(waitForAll(promises))
+ , remaining(promises.size())
+ , signal{nullptr}
+ {
+ }
+
+ kj::Promise<std::optional<Item>> next()
+ {
+ if (remaining == 0 && results.empty()) {
+ return {std::nullopt};
+ }
+
+ if (!results.empty()) {
+ auto result = std::move(results.front());
+ results.pop_front();
+ return {{std::move(result)}};
+ }
+
+ if (notify == nullptr) {
+ auto pair = kj::newPromiseAndFulfiller<void>();
+ notify = std::move(pair.fulfiller);
+ signal = pair.promise.fork();
+ }
+
+ return signal.addBranch().exclusiveJoin(allPromises.addBranch()).then([this] {
+ return next();
+ });
+ }
+};
+
+/**
+ * Collect the results of a list of promises, in order of completion.
+ * Once any input promise is rejected all promises that have not been
+ * resolved or rejected will be cancelled and the exception rethrown.
+ */
+template<typename K, typename V>
+AsyncCollect<K, V> asyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> promises)
+{
+ return AsyncCollect<K, V>(std::move(promises));
+}
+
+}
diff --git a/src/libutil/meson.build b/src/libutil/meson.build
index 89eeed133..afca4e021 100644
--- a/src/libutil/meson.build
+++ b/src/libutil/meson.build
@@ -53,6 +53,7 @@ libutil_headers = files(
'archive.hh',
'args/root.hh',
'args.hh',
+ 'async-collect.hh',
'async-semaphore.hh',
'backed-string-view.hh',
'box_ptr.hh',
diff --git a/tests/unit/libutil/async-collect.cc b/tests/unit/libutil/async-collect.cc
new file mode 100644
index 000000000..770374d21
--- /dev/null
+++ b/tests/unit/libutil/async-collect.cc
@@ -0,0 +1,104 @@
+#include "async-collect.hh"
+
+#include <gtest/gtest.h>
+#include <kj/array.h>
+#include <kj/async.h>
+#include <kj/exception.h>
+#include <stdexcept>
+
+namespace nix {
+
+TEST(AsyncCollect, void)
+{
+ kj::EventLoop loop;
+ kj::WaitScope waitScope(loop);
+
+ auto a = kj::newPromiseAndFulfiller<void>();
+ auto b = kj::newPromiseAndFulfiller<void>();
+ auto c = kj::newPromiseAndFulfiller<void>();
+ auto d = kj::newPromiseAndFulfiller<void>();
+
+ auto collect = asyncCollect(kj::arr(
+ std::pair(1, std::move(a.promise)),
+ std::pair(2, std::move(b.promise)),
+ std::pair(3, std::move(c.promise)),
+ std::pair(4, std::move(d.promise))
+ ));
+
+ auto p = collect.next();
+ ASSERT_FALSE(p.poll(waitScope));
+
+ // collection is ordered
+ c.fulfiller->fulfill();
+ b.fulfiller->fulfill();
+
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_EQ(p.wait(waitScope), 3);
+
+ p = collect.next();
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_EQ(p.wait(waitScope), 2);
+
+ p = collect.next();
+ ASSERT_FALSE(p.poll(waitScope));
+
+ // exceptions propagate
+ a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
+
+ p = collect.next();
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_THROW(p.wait(waitScope), kj::Exception);
+
+ // first exception aborts collection
+ p = collect.next();
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_THROW(p.wait(waitScope), kj::Exception);
+}
+
+TEST(AsyncCollect, nonVoid)
+{
+ kj::EventLoop loop;
+ kj::WaitScope waitScope(loop);
+
+ auto a = kj::newPromiseAndFulfiller<int>();
+ auto b = kj::newPromiseAndFulfiller<int>();
+ auto c = kj::newPromiseAndFulfiller<int>();
+ auto d = kj::newPromiseAndFulfiller<int>();
+
+ auto collect = asyncCollect(kj::arr(
+ std::pair(1, std::move(a.promise)),
+ std::pair(2, std::move(b.promise)),
+ std::pair(3, std::move(c.promise)),
+ std::pair(4, std::move(d.promise))
+ ));
+
+ auto p = collect.next();
+ ASSERT_FALSE(p.poll(waitScope));
+
+ // collection is ordered
+ c.fulfiller->fulfill(1);
+ b.fulfiller->fulfill(2);
+
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_EQ(p.wait(waitScope), std::pair(3, 1));
+
+ p = collect.next();
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_EQ(p.wait(waitScope), std::pair(2, 2));
+
+ p = collect.next();
+ ASSERT_FALSE(p.poll(waitScope));
+
+ // exceptions propagate
+ a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
+
+ p = collect.next();
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_THROW(p.wait(waitScope), kj::Exception);
+
+ // first exception aborts collection
+ p = collect.next();
+ ASSERT_TRUE(p.poll(waitScope));
+ ASSERT_THROW(p.wait(waitScope), kj::Exception);
+}
+}
diff --git a/tests/unit/meson.build b/tests/unit/meson.build
index 3d3930731..8b0c66dd8 100644
--- a/tests/unit/meson.build
+++ b/tests/unit/meson.build
@@ -39,6 +39,7 @@ liblixutil_test_support = declare_dependency(
)
libutil_tests_sources = files(
+ 'libutil/async-collect.cc',
'libutil/async-semaphore.cc',
'libutil/canon-path.cc',
'libutil/checked-arithmetic.cc',