aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/libstore/filetransfer.cc85
-rw-r--r--src/libstore/filetransfer.hh5
-rw-r--r--tests/unit/libstore/filetransfer.cc123
3 files changed, 164 insertions, 49 deletions
diff --git a/src/libstore/filetransfer.cc b/src/libstore/filetransfer.cc
index 67b9fef81..fd0a42cb3 100644
--- a/src/libstore/filetransfer.cc
+++ b/src/libstore/filetransfer.cc
@@ -48,7 +48,7 @@ struct curlFileTransfer : public FileTransfer
FileTransferResult result;
Activity act;
bool done = false; // whether either the success or failure function has been called
- Callback<FileTransferResult> callback;
+ std::packaged_task<FileTransferResult(std::exception_ptr, FileTransferResult)> callback;
std::function<void(TransferItem &, std::string_view data)> dataCallback;
CURL * req = 0;
bool active = false; // whether the handle has been added to the multi object
@@ -83,14 +83,17 @@ struct curlFileTransfer : public FileTransfer
TransferItem(curlFileTransfer & fileTransfer,
const FileTransferRequest & request,
- Callback<FileTransferResult> && callback,
+ std::invocable<std::exception_ptr> auto callback,
std::function<void(TransferItem &, std::string_view data)> dataCallback)
: fileTransfer(fileTransfer)
, request(request)
, act(*logger, lvlTalkative, actFileTransfer,
fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri),
{request.uri}, request.parentAct)
- , callback(std::move(callback))
+ , callback([cb{std::move(callback)}] (std::exception_ptr ex, FileTransferResult r) {
+ cb(ex);
+ return r;
+ })
, dataCallback(std::move(dataCallback))
{
requestHeaders = curl_slist_append(requestHeaders, "Accept-Encoding: zstd, br, gzip, deflate, bzip2, xz");
@@ -123,7 +126,7 @@ struct curlFileTransfer : public FileTransfer
{
assert(!done);
done = true;
- callback.rethrow(ex);
+ callback(ex, std::move(result));
}
template<class T>
@@ -369,7 +372,7 @@ struct curlFileTransfer : public FileTransfer
result.cached = httpStatus == 304;
act.progress(result.bodySize, result.bodySize);
done = true;
- callback(std::move(result));
+ callback(nullptr, std::move(result));
}
else {
@@ -623,7 +626,7 @@ struct curlFileTransfer : public FileTransfer
}
}
- void enqueueItem(std::shared_ptr<TransferItem> item)
+ std::shared_ptr<TransferItem> enqueueItem(std::shared_ptr<TransferItem> item)
{
if (item->request.data
&& !item->request.uri.starts_with("http://")
@@ -637,10 +640,11 @@ struct curlFileTransfer : public FileTransfer
state->incoming.push(item);
}
wakeup();
+ return item;
}
#if ENABLE_S3
- std::tuple<std::string, std::string, Store::Params> parseS3Uri(std::string uri)
+ static std::tuple<std::string, std::string, Store::Params> parseS3Uri(std::string uri)
{
auto [path, params] = splitUriAndParams(uri);
@@ -655,22 +659,29 @@ struct curlFileTransfer : public FileTransfer
}
#endif
- void enqueueFileTransfer(const FileTransferRequest & request,
- Callback<FileTransferResult> callback) override
+ std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request) override
{
- enqueueFileTransfer(request, std::move(callback), {});
+ return enqueueFileTransfer(
+ request,
+ [](std::exception_ptr ex) {
+ if (ex) {
+ std::rethrow_exception(ex);
+ }
+ },
+ {}
+ );
}
- void enqueueFileTransfer(const FileTransferRequest & request,
- Callback<FileTransferResult> callback,
+ std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request,
+ std::invocable<std::exception_ptr> auto callback,
std::function<void(TransferItem &, std::string_view data)> dataCallback)
{
/* Ugly hack to support s3:// URIs. */
if (request.uri.starts_with("s3://")) {
// FIXME: do this on a worker thread
- try {
+ return std::async(std::launch::deferred, [uri{request.uri}] {
#if ENABLE_S3
- auto [bucketName, key, params] = parseS3Uri(request.uri);
+ auto [bucketName, key, params] = parseS3Uri(uri);
std::string profile = getOr(params, "profile", "");
std::string region = getOr(params, "region", Aws::Region::US_EAST_1);
@@ -683,19 +694,19 @@ struct curlFileTransfer : public FileTransfer
auto s3Res = s3Helper.getObject(bucketName, key);
FileTransferResult res;
if (!s3Res.data)
- throw FileTransferError(NotFound, "S3 object '%s' does not exist", request.uri);
+ throw FileTransferError(NotFound, "S3 object '%s' does not exist", uri);
res.data = std::move(*s3Res.data);
- callback(std::move(res));
+ return res;
#else
- throw nix::Error("cannot download '%s' because Lix is not built with S3 support", request.uri);
+ throw nix::Error("cannot download '%s' because Lix is not built with S3 support", uri);
#endif
- } catch (...) { callback.rethrow(); }
- return;
+ });
}
- enqueueItem(std::make_shared<TransferItem>(
- *this, request, std::move(callback), std::move(dataCallback)
- ));
+ return enqueueItem(std::make_shared<TransferItem>(
+ *this, request, std::move(callback), std::move(dataCallback)
+ ))
+ ->callback.get_future();
}
void download(FileTransferRequest && request, Sink & sink) override
@@ -724,18 +735,15 @@ struct curlFileTransfer : public FileTransfer
state->request.notify_one();
});
- enqueueFileTransfer(request,
- {[_state](std::future<FileTransferResult> fut) {
+ enqueueFileTransfer(
+ request,
+ [_state](std::exception_ptr ex) {
auto state(_state->lock());
state->done = true;
- try {
- fut.get();
- } catch (...) {
- state->exc = std::current_exception();
- }
+ state->exc = ex;
state->avail.notify_one();
state->request.notify_one();
- }},
+ },
[_state](TransferItem & transfer, std::string_view data) {
auto state(_state->lock());
@@ -758,7 +766,8 @@ struct curlFileTransfer : public FileTransfer
thread. */
state->data.append(data);
state->avail.notify_one();
- });
+ }
+ );
std::unique_ptr<FinishSink> decompressor;
@@ -827,20 +836,6 @@ ref<FileTransfer> makeFileTransfer()
return makeCurlFileTransfer();
}
-std::future<FileTransferResult> FileTransfer::enqueueFileTransfer(const FileTransferRequest & request)
-{
- auto promise = std::make_shared<std::promise<FileTransferResult>>();
- enqueueFileTransfer(request,
- {[promise](std::future<FileTransferResult> fut) {
- try {
- promise->set_value(fut.get());
- } catch (...) {
- promise->set_exception(std::current_exception());
- }
- }});
- return promise->get_future();
-}
-
FileTransferResult FileTransfer::download(const FileTransferRequest & request)
{
return enqueueFileTransfer(request).get();
diff --git a/src/libstore/filetransfer.hh b/src/libstore/filetransfer.hh
index e028d7f70..3f55995ef 100644
--- a/src/libstore/filetransfer.hh
+++ b/src/libstore/filetransfer.hh
@@ -95,10 +95,7 @@ struct FileTransfer
* the download. The future may throw a FileTransferError
* exception.
*/
- virtual void enqueueFileTransfer(const FileTransferRequest & request,
- Callback<FileTransferResult> callback) = 0;
-
- std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request);
+ virtual std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request) = 0;
/**
* Synchronously download a file.
diff --git a/tests/unit/libstore/filetransfer.cc b/tests/unit/libstore/filetransfer.cc
index 192bf81ef..ebd38f19d 100644
--- a/tests/unit/libstore/filetransfer.cc
+++ b/tests/unit/libstore/filetransfer.cc
@@ -1,12 +1,114 @@
#include "filetransfer.hh"
+#include <cstdint>
+#include <exception>
#include <future>
#include <gtest/gtest.h>
+#include <netinet/in.h>
+#include <stdexcept>
+#include <string_view>
+#include <sys/poll.h>
+#include <sys/socket.h>
+#include <thread>
+#include <unistd.h>
+
+// local server tests don't work on darwin without some incantations
+// the horrors do not want to look up. contributions welcome though!
+#if __APPLE__
+#define NOT_ON_DARWIN(n) DISABLED_##n
+#else
+#define NOT_ON_DARWIN(n) n
+#endif
using namespace std::chrono_literals;
namespace nix {
+static std::tuple<uint16_t, AutoCloseFD>
+serveHTTP(std::string_view status, std::string_view headers, std::function<std::string_view()> content)
+{
+ AutoCloseFD listener(::socket(AF_INET6, SOCK_STREAM, 0));
+ if (!listener) {
+ throw SysError(errno, "socket() failed");
+ }
+
+ Pipe trigger;
+ trigger.create();
+
+ sockaddr_in6 addr = {
+ .sin6_family = AF_INET6,
+ .sin6_addr = IN6ADDR_LOOPBACK_INIT,
+ };
+ socklen_t len = sizeof(addr);
+ if (::bind(listener.get(), reinterpret_cast<const sockaddr *>(&addr), sizeof(addr)) < 0) {
+ throw SysError(errno, "bind() failed");
+ }
+ if (::getsockname(listener.get(), reinterpret_cast<sockaddr *>(&addr), &len) < 0) {
+ throw SysError(errno, "getsockname() failed");
+ }
+ if (::listen(listener.get(), 1) < 0) {
+ throw SysError(errno, "listen() failed");
+ }
+
+ std::thread(
+ [status, headers, content](AutoCloseFD socket, AutoCloseFD trigger) {
+ while (true) {
+ pollfd pfds[2] = {
+ {
+ .fd = socket.get(),
+ .events = POLLIN,
+ },
+ {
+ .fd = trigger.get(),
+ .events = POLLHUP,
+ },
+ };
+
+ if (::poll(pfds, 2, -1) <= 0) {
+ throw SysError(errno, "poll() failed");
+ }
+ if (pfds[1].revents & POLLHUP) {
+ return;
+ }
+ if (!(pfds[0].revents & POLLIN)) {
+ continue;
+ }
+
+ AutoCloseFD conn(::accept(socket.get(), nullptr, nullptr));
+ if (!conn) {
+ throw SysError(errno, "accept() failed");
+ }
+
+ auto send = [&](std::string_view bit) {
+ while (!bit.empty()) {
+ auto written = ::write(conn.get(), bit.data(), bit.size());
+ if (written < 0) {
+ throw SysError(errno, "write() failed");
+ }
+ bit.remove_prefix(written);
+ }
+ };
+
+ send("HTTP/1.1 ");
+ send(status);
+ send("\r\n");
+ send(headers);
+ send("\r\n");
+ send(content());
+ ::shutdown(conn.get(), SHUT_RDWR);
+ }
+ },
+ std::move(listener),
+ std::move(trigger.readSide)
+ )
+ .detach();
+
+ return {
+ ntohs(addr.sin6_port),
+ std::move(trigger.writeSide),
+ };
+}
+
TEST(FileTransfer, exceptionAbortsDownload)
{
struct Done
@@ -29,4 +131,25 @@ TEST(FileTransfer, exceptionAbortsDownload)
(void) new auto(std::move(reset));
}
}
+
+TEST(FileTransfer, NOT_ON_DARWIN(reportsSetupErrors))
+{
+ auto [port, srv] = serveHTTP("404 not found", "", [] { return ""; });
+ auto ft = makeFileTransfer();
+ ASSERT_THROW(
+ ft->download(FileTransferRequest(fmt("http://[::1]:%d/index", port))),
+ FileTransferError);
+}
+
+TEST(FileTransfer, NOT_ON_DARWIN(reportsTransferError))
+{
+ auto [port, srv] = serveHTTP("200 ok", "content-length: 100\r\n", [] {
+ std::this_thread::sleep_for(10ms);
+ return "";
+ });
+ auto ft = makeFileTransfer();
+ FileTransferRequest req(fmt("http://[::1]:%d/index", port));
+ req.baseRetryTimeMs = 0;
+ ASSERT_THROW(ft->download(req), FileTransferError);
+}
}