diff options
author | eldritch horrors <pennae@lix.systems> | 2024-03-07 06:22:44 +0100 |
---|---|---|
committer | eldritch horrors <pennae@lix.systems> | 2024-03-07 00:43:51 -0700 |
commit | bac3c5ad97d74c941a79d9afbf201fad1c54c804 (patch) | |
tree | b12c432b605015487bdbaaa34dc6d418f25a8b88 /src | |
parent | 06e92450bd87baa9a1cc06e09f59e5d79bb4b707 (diff) |
Merge pull request #9787 from obsidiansystems/bind-proc-syserror
`bind`: give same treatment as `connect` in #8544, dedup
(cherry picked from commit 28674247ec792a981741198abc190a71bb254b82)
Change-Id: I1ac5fc43fa10ec5f37a226730c3d84033fdbfd52
Diffstat (limited to 'src')
-rw-r--r-- | src/libutil/util.cc | 68 |
1 files changed, 30 insertions, 38 deletions
diff --git a/src/libutil/util.cc b/src/libutil/util.cc index 6bcb069ba..ef2654a14 100644 --- a/src/libutil/util.cc +++ b/src/libutil/util.cc @@ -2005,46 +2005,24 @@ AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode) } -void bind(int fd, const std::string & path) +static void bindConnectProcHelper( + std::string_view operationName, auto && operation, + int fd, const std::string & path) { - unlink(path.c_str()); - struct sockaddr_un addr; addr.sun_family = AF_UNIX; - if (path.size() + 1 >= sizeof(addr.sun_path)) { - Pid pid = startProcess([&]() { - Path dir = dirOf(path); - if (chdir(dir.c_str()) == -1) - throw SysError("chdir to '%s' failed", dir); - std::string base(baseNameOf(path)); - if (base.size() + 1 >= sizeof(addr.sun_path)) - throw Error("socket path '%s' is too long", base); - memcpy(addr.sun_path, base.c_str(), base.size() + 1); - if (bind(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1) - throw SysError("cannot bind to socket '%s'", path); - _exit(0); - }); - int status = pid.wait(); - if (status != 0) - throw Error("cannot bind to socket '%s'", path); - } else { - memcpy(addr.sun_path, path.c_str(), path.size() + 1); - if (bind(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1) - throw SysError("cannot bind to socket '%s'", path); - } -} - - -void connect(int fd, const std::string & path) -{ - struct sockaddr_un addr; - addr.sun_family = AF_UNIX; + // Casting between types like these legacy C library interfaces + // require is forbidden in C++. To maintain backwards + // compatibility, the implementation of the bind/connect functions + // contains some hints to the compiler that allow for this + // special case. + auto * psaddr = reinterpret_cast<struct sockaddr *>(&addr); if (path.size() + 1 >= sizeof(addr.sun_path)) { Pipe pipe; pipe.create(); - Pid pid = startProcess([&]() { + Pid pid = startProcess([&] { try { pipe.readSide.close(); Path dir = dirOf(path); @@ -2054,8 +2032,8 @@ void connect(int fd, const std::string & path) if (base.size() + 1 >= sizeof(addr.sun_path)) throw Error("socket path '%s' is too long", base); memcpy(addr.sun_path, base.c_str(), base.size() + 1); - if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1) - throw SysError("cannot connect to socket at '%s'", path); + if (operation(fd, psaddr, sizeof(addr)) == -1) + throw SysError("cannot %s to socket at '%s'", operationName, path); writeFull(pipe.writeSide.get(), "0\n"); } catch (SysError & e) { writeFull(pipe.writeSide.get(), fmt("%d\n", e.errNo)); @@ -2066,19 +2044,33 @@ void connect(int fd, const std::string & path) pipe.writeSide.close(); auto errNo = string2Int<int>(chomp(drainFD(pipe.readSide.get()))); if (!errNo || *errNo == -1) - throw Error("cannot connect to socket at '%s'", path); + throw Error("cannot %s to socket at '%s'", operationName, path); else if (*errNo > 0) { errno = *errNo; - throw SysError("cannot connect to socket at '%s'", path); + throw SysError("cannot %s to socket at '%s'", operationName, path); } } else { memcpy(addr.sun_path, path.c_str(), path.size() + 1); - if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1) - throw SysError("cannot connect to socket at '%s'", path); + if (operation(fd, psaddr, sizeof(addr)) == -1) + throw SysError("cannot %s to socket at '%s'", operationName, path); } } +void bind(int fd, const std::string & path) +{ + unlink(path.c_str()); + + bindConnectProcHelper("bind", ::bind, fd, path); +} + + +void connect(int fd, const std::string & path) +{ + bindConnectProcHelper("connect", ::connect, fd, path); +} + + std::string showBytes(uint64_t bytes) { return fmt("%.2f MiB", bytes / (1024.0 * 1024.0)); |