diff options
-rw-r--r-- | src/libstore/filetransfer.cc | 85 | ||||
-rw-r--r-- | src/libstore/filetransfer.hh | 5 | ||||
-rw-r--r-- | tests/unit/libstore/filetransfer.cc | 123 |
3 files changed, 164 insertions, 49 deletions
diff --git a/src/libstore/filetransfer.cc b/src/libstore/filetransfer.cc index 67b9fef81..fd0a42cb3 100644 --- a/src/libstore/filetransfer.cc +++ b/src/libstore/filetransfer.cc @@ -48,7 +48,7 @@ struct curlFileTransfer : public FileTransfer FileTransferResult result; Activity act; bool done = false; // whether either the success or failure function has been called - Callback<FileTransferResult> callback; + std::packaged_task<FileTransferResult(std::exception_ptr, FileTransferResult)> callback; std::function<void(TransferItem &, std::string_view data)> dataCallback; CURL * req = 0; bool active = false; // whether the handle has been added to the multi object @@ -83,14 +83,17 @@ struct curlFileTransfer : public FileTransfer TransferItem(curlFileTransfer & fileTransfer, const FileTransferRequest & request, - Callback<FileTransferResult> && callback, + std::invocable<std::exception_ptr> auto callback, std::function<void(TransferItem &, std::string_view data)> dataCallback) : fileTransfer(fileTransfer) , request(request) , act(*logger, lvlTalkative, actFileTransfer, fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri), {request.uri}, request.parentAct) - , callback(std::move(callback)) + , callback([cb{std::move(callback)}] (std::exception_ptr ex, FileTransferResult r) { + cb(ex); + return r; + }) , dataCallback(std::move(dataCallback)) { requestHeaders = curl_slist_append(requestHeaders, "Accept-Encoding: zstd, br, gzip, deflate, bzip2, xz"); @@ -123,7 +126,7 @@ struct curlFileTransfer : public FileTransfer { assert(!done); done = true; - callback.rethrow(ex); + callback(ex, std::move(result)); } template<class T> @@ -369,7 +372,7 @@ struct curlFileTransfer : public FileTransfer result.cached = httpStatus == 304; act.progress(result.bodySize, result.bodySize); done = true; - callback(std::move(result)); + callback(nullptr, std::move(result)); } else { @@ -623,7 +626,7 @@ struct curlFileTransfer : public FileTransfer } } - void enqueueItem(std::shared_ptr<TransferItem> item) + std::shared_ptr<TransferItem> enqueueItem(std::shared_ptr<TransferItem> item) { if (item->request.data && !item->request.uri.starts_with("http://") @@ -637,10 +640,11 @@ struct curlFileTransfer : public FileTransfer state->incoming.push(item); } wakeup(); + return item; } #if ENABLE_S3 - std::tuple<std::string, std::string, Store::Params> parseS3Uri(std::string uri) + static std::tuple<std::string, std::string, Store::Params> parseS3Uri(std::string uri) { auto [path, params] = splitUriAndParams(uri); @@ -655,22 +659,29 @@ struct curlFileTransfer : public FileTransfer } #endif - void enqueueFileTransfer(const FileTransferRequest & request, - Callback<FileTransferResult> callback) override + std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request) override { - enqueueFileTransfer(request, std::move(callback), {}); + return enqueueFileTransfer( + request, + [](std::exception_ptr ex) { + if (ex) { + std::rethrow_exception(ex); + } + }, + {} + ); } - void enqueueFileTransfer(const FileTransferRequest & request, - Callback<FileTransferResult> callback, + std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request, + std::invocable<std::exception_ptr> auto callback, std::function<void(TransferItem &, std::string_view data)> dataCallback) { /* Ugly hack to support s3:// URIs. */ if (request.uri.starts_with("s3://")) { // FIXME: do this on a worker thread - try { + return std::async(std::launch::deferred, [uri{request.uri}] { #if ENABLE_S3 - auto [bucketName, key, params] = parseS3Uri(request.uri); + auto [bucketName, key, params] = parseS3Uri(uri); std::string profile = getOr(params, "profile", ""); std::string region = getOr(params, "region", Aws::Region::US_EAST_1); @@ -683,19 +694,19 @@ struct curlFileTransfer : public FileTransfer auto s3Res = s3Helper.getObject(bucketName, key); FileTransferResult res; if (!s3Res.data) - throw FileTransferError(NotFound, "S3 object '%s' does not exist", request.uri); + throw FileTransferError(NotFound, "S3 object '%s' does not exist", uri); res.data = std::move(*s3Res.data); - callback(std::move(res)); + return res; #else - throw nix::Error("cannot download '%s' because Lix is not built with S3 support", request.uri); + throw nix::Error("cannot download '%s' because Lix is not built with S3 support", uri); #endif - } catch (...) { callback.rethrow(); } - return; + }); } - enqueueItem(std::make_shared<TransferItem>( - *this, request, std::move(callback), std::move(dataCallback) - )); + return enqueueItem(std::make_shared<TransferItem>( + *this, request, std::move(callback), std::move(dataCallback) + )) + ->callback.get_future(); } void download(FileTransferRequest && request, Sink & sink) override @@ -724,18 +735,15 @@ struct curlFileTransfer : public FileTransfer state->request.notify_one(); }); - enqueueFileTransfer(request, - {[_state](std::future<FileTransferResult> fut) { + enqueueFileTransfer( + request, + [_state](std::exception_ptr ex) { auto state(_state->lock()); state->done = true; - try { - fut.get(); - } catch (...) { - state->exc = std::current_exception(); - } + state->exc = ex; state->avail.notify_one(); state->request.notify_one(); - }}, + }, [_state](TransferItem & transfer, std::string_view data) { auto state(_state->lock()); @@ -758,7 +766,8 @@ struct curlFileTransfer : public FileTransfer thread. */ state->data.append(data); state->avail.notify_one(); - }); + } + ); std::unique_ptr<FinishSink> decompressor; @@ -827,20 +836,6 @@ ref<FileTransfer> makeFileTransfer() return makeCurlFileTransfer(); } -std::future<FileTransferResult> FileTransfer::enqueueFileTransfer(const FileTransferRequest & request) -{ - auto promise = std::make_shared<std::promise<FileTransferResult>>(); - enqueueFileTransfer(request, - {[promise](std::future<FileTransferResult> fut) { - try { - promise->set_value(fut.get()); - } catch (...) { - promise->set_exception(std::current_exception()); - } - }}); - return promise->get_future(); -} - FileTransferResult FileTransfer::download(const FileTransferRequest & request) { return enqueueFileTransfer(request).get(); diff --git a/src/libstore/filetransfer.hh b/src/libstore/filetransfer.hh index e028d7f70..3f55995ef 100644 --- a/src/libstore/filetransfer.hh +++ b/src/libstore/filetransfer.hh @@ -95,10 +95,7 @@ struct FileTransfer * the download. The future may throw a FileTransferError * exception. */ - virtual void enqueueFileTransfer(const FileTransferRequest & request, - Callback<FileTransferResult> callback) = 0; - - std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request); + virtual std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request) = 0; /** * Synchronously download a file. diff --git a/tests/unit/libstore/filetransfer.cc b/tests/unit/libstore/filetransfer.cc index 192bf81ef..ebd38f19d 100644 --- a/tests/unit/libstore/filetransfer.cc +++ b/tests/unit/libstore/filetransfer.cc @@ -1,12 +1,114 @@ #include "filetransfer.hh" +#include <cstdint> +#include <exception> #include <future> #include <gtest/gtest.h> +#include <netinet/in.h> +#include <stdexcept> +#include <string_view> +#include <sys/poll.h> +#include <sys/socket.h> +#include <thread> +#include <unistd.h> + +// local server tests don't work on darwin without some incantations +// the horrors do not want to look up. contributions welcome though! +#if __APPLE__ +#define NOT_ON_DARWIN(n) DISABLED_##n +#else +#define NOT_ON_DARWIN(n) n +#endif using namespace std::chrono_literals; namespace nix { +static std::tuple<uint16_t, AutoCloseFD> +serveHTTP(std::string_view status, std::string_view headers, std::function<std::string_view()> content) +{ + AutoCloseFD listener(::socket(AF_INET6, SOCK_STREAM, 0)); + if (!listener) { + throw SysError(errno, "socket() failed"); + } + + Pipe trigger; + trigger.create(); + + sockaddr_in6 addr = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_LOOPBACK_INIT, + }; + socklen_t len = sizeof(addr); + if (::bind(listener.get(), reinterpret_cast<const sockaddr *>(&addr), sizeof(addr)) < 0) { + throw SysError(errno, "bind() failed"); + } + if (::getsockname(listener.get(), reinterpret_cast<sockaddr *>(&addr), &len) < 0) { + throw SysError(errno, "getsockname() failed"); + } + if (::listen(listener.get(), 1) < 0) { + throw SysError(errno, "listen() failed"); + } + + std::thread( + [status, headers, content](AutoCloseFD socket, AutoCloseFD trigger) { + while (true) { + pollfd pfds[2] = { + { + .fd = socket.get(), + .events = POLLIN, + }, + { + .fd = trigger.get(), + .events = POLLHUP, + }, + }; + + if (::poll(pfds, 2, -1) <= 0) { + throw SysError(errno, "poll() failed"); + } + if (pfds[1].revents & POLLHUP) { + return; + } + if (!(pfds[0].revents & POLLIN)) { + continue; + } + + AutoCloseFD conn(::accept(socket.get(), nullptr, nullptr)); + if (!conn) { + throw SysError(errno, "accept() failed"); + } + + auto send = [&](std::string_view bit) { + while (!bit.empty()) { + auto written = ::write(conn.get(), bit.data(), bit.size()); + if (written < 0) { + throw SysError(errno, "write() failed"); + } + bit.remove_prefix(written); + } + }; + + send("HTTP/1.1 "); + send(status); + send("\r\n"); + send(headers); + send("\r\n"); + send(content()); + ::shutdown(conn.get(), SHUT_RDWR); + } + }, + std::move(listener), + std::move(trigger.readSide) + ) + .detach(); + + return { + ntohs(addr.sin6_port), + std::move(trigger.writeSide), + }; +} + TEST(FileTransfer, exceptionAbortsDownload) { struct Done @@ -29,4 +131,25 @@ TEST(FileTransfer, exceptionAbortsDownload) (void) new auto(std::move(reset)); } } + +TEST(FileTransfer, NOT_ON_DARWIN(reportsSetupErrors)) +{ + auto [port, srv] = serveHTTP("404 not found", "", [] { return ""; }); + auto ft = makeFileTransfer(); + ASSERT_THROW( + ft->download(FileTransferRequest(fmt("http://[::1]:%d/index", port))), + FileTransferError); +} + +TEST(FileTransfer, NOT_ON_DARWIN(reportsTransferError)) +{ + auto [port, srv] = serveHTTP("200 ok", "content-length: 100\r\n", [] { + std::this_thread::sleep_for(10ms); + return ""; + }); + auto ft = makeFileTransfer(); + FileTransferRequest req(fmt("http://[::1]:%d/index", port)); + req.baseRetryTimeMs = 0; + ASSERT_THROW(ft->download(req), FileTransferError); +} } |