diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/libmain/shared.hh | 6 | ||||
-rw-r--r-- | src/libstore/remote-store.cc | 3 | ||||
-rw-r--r-- | src/libutil/serialise.cc | 7 | ||||
-rw-r--r-- | src/libutil/types.hh | 2 | ||||
-rw-r--r-- | src/libutil/util.cc | 17 | ||||
-rw-r--r-- | src/libutil/util.hh | 17 | ||||
-rw-r--r-- | src/nix-worker/main.cc | 183 |
7 files changed, 182 insertions, 53 deletions
diff --git a/src/libmain/shared.hh b/src/libmain/shared.hh index 2c574d148..fa45645fe 100644 --- a/src/libmain/shared.hh +++ b/src/libmain/shared.hh @@ -3,6 +3,8 @@ #include "types.hh" +#include <signal.h> + /* These are not implemented here, but must be implemented by a program linking against libmain. */ @@ -27,6 +29,10 @@ void printGCWarning(); /* Whether we're running setuid. */ extern bool setuidMode; +extern volatile ::sig_atomic_t blockInt; + +MakeError(UsageError, nix::Error) + } diff --git a/src/libstore/remote-store.cc b/src/libstore/remote-store.cc index 4d4189be0..b9ed1fdbc 100644 --- a/src/libstore/remote-store.cc +++ b/src/libstore/remote-store.cc @@ -39,10 +39,11 @@ RemoteStore::RemoteStore() /* Send the magic greeting, check for the reply. */ try { - processStderr(); writeInt(WORKER_MAGIC_1, to); + writeInt(verbosity, to); unsigned int magic = readInt(from); if (magic != WORKER_MAGIC_2) throw Error("protocol mismatch"); + processStderr(); } catch (Error & e) { throw Error(format("cannot start worker (%1%)") % e.msg()); diff --git a/src/libutil/serialise.cc b/src/libutil/serialise.cc index 969f638ef..c0e1c17af 100644 --- a/src/libutil/serialise.cc +++ b/src/libutil/serialise.cc @@ -85,10 +85,11 @@ unsigned int readInt(Source & source) string readString(Source & source) { unsigned int len = readInt(source); - char buf[len]; - source((unsigned char *) buf, len); + unsigned char * buf = new unsigned char[len]; + AutoDeleteArray<unsigned char> d(buf); + source(buf, len); readPadding(len, source); - return string(buf, len); + return string((char *) buf, len); } diff --git a/src/libutil/types.hh b/src/libutil/types.hh index 1de378961..257871a82 100644 --- a/src/libutil/types.hh +++ b/src/libutil/types.hh @@ -44,8 +44,6 @@ public: newClass(const format & f) : superClass(f) { }; \ }; -MakeError(UsageError, Error) - typedef list<string> Strings; typedef set<string> StringSet; diff --git a/src/libutil/util.cc b/src/libutil/util.cc index 7c1138720..08385e5d9 100644 --- a/src/libutil/util.cc +++ b/src/libutil/util.cc @@ -191,18 +191,6 @@ Strings readDirectory(const Path & path) } -template <class T> -struct AutoDeleteArray -{ - T * p; - AutoDeleteArray(T * p) : p(p) { } - ~AutoDeleteArray() - { - delete [] p; - } -}; - - string readFile(int fd) { struct stat st; @@ -468,7 +456,7 @@ void readFull(int fd, unsigned char * buf, size_t count) if (errno == EINTR) continue; throw SysError("reading from file"); } - if (res == 0) throw Error("unexpected end-of-file"); + if (res == 0) throw EndOfFile("unexpected end-of-file"); count -= res; buf += res; } @@ -707,6 +695,7 @@ int Pid::wait(bool block) if (res == 0 && !block) return -1; if (errno != EINTR) throw SysError("cannot get child exit status"); + checkInterrupt(); } } @@ -793,7 +782,7 @@ void _interrupted() kills the program! */ if (!std::uncaught_exception()) { _isInterrupted = 0; - throw Error("interrupted by the user"); + throw Interrupted("interrupted by the user"); } } diff --git a/src/libutil/util.hh b/src/libutil/util.hh index 0d39ffee9..b88508dec 100644 --- a/src/libutil/util.hh +++ b/src/libutil/util.hh @@ -139,6 +139,8 @@ extern void (*writeToStderr) (const unsigned char * buf, size_t count); void readFull(int fd, unsigned char * buf, size_t count); void writeFull(int fd, const unsigned char * buf, size_t count); +MakeError(EndOfFile, Error) + /* Read a file descriptor until EOF occurs. */ string drainFD(int fd); @@ -147,6 +149,19 @@ string drainFD(int fd); /* Automatic cleanup of resources. */ + +template <class T> +struct AutoDeleteArray +{ + T * p; + AutoDeleteArray(T * p) : p(p) { } + ~AutoDeleteArray() + { + delete [] p; + } +}; + + class AutoDelete { string path; @@ -229,6 +244,8 @@ void inline checkInterrupt() if (_isInterrupted) _interrupted(); } +MakeError(Interrupted, Error) + /* String packing / unpacking. */ string packStrings(const Strings & strings); diff --git a/src/nix-worker/main.cc b/src/nix-worker/main.cc index c8576ddb6..d104ea840 100644 --- a/src/nix-worker/main.cc +++ b/src/nix-worker/main.cc @@ -9,6 +9,10 @@ #include <iostream> #include <unistd.h> #include <signal.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <sys/socket.h> +#include <sys/un.h> #include <fcntl.h> using namespace nix; @@ -43,9 +47,6 @@ bool canSendStderr; socket. */ static void tunnelStderr(const unsigned char * buf, size_t count) { - if (canSendStderr) - writeFull(STDERR_FILENO, (unsigned char *) "L: ", 3); - writeFull(STDERR_FILENO, buf, count); if (canSendStderr) { try { writeInt(STDERR_NEXT, to); @@ -65,12 +66,20 @@ static void tunnelStderr(const unsigned char * buf, size_t count) socket. This handler is enabled at precisely those moments in the protocol when we're doing work and the client is supposed to be quiet. Thus, if we get a SIGIO signal, it means that the client - has quit. So we should quit as well. */ + has quit. So we should quit as well. + + Too bad most operating systems don't support the POLL_HUP value for + si_code in siginfo_t. That would make most of the SIGIO complexity + unnecessary, i.e., we could just enable SIGIO all the time and + wouldn't have to worry about races. */ static void sigioHandler(int sigNo) { - _isInterrupted = 1; - canSendStderr = false; - write(STDERR_FILENO, "SIGIO\n", 6); + if (!blockInt) { + _isInterrupted = 1; + blockInt = 1; + canSendStderr = false; + write(STDERR_FILENO, "SIGIO\n", 6); + } } @@ -97,14 +106,14 @@ static void startWork() fd_set fds; FD_ZERO(&fds); - FD_SET(STDIN_FILENO, &fds); + FD_SET(from.fd, &fds); - if (select(STDIN_FILENO + 1, &fds, 0, 0, &timeout) == -1) + if (select(from.fd + 1, &fds, 0, 0, &timeout) == -1) throw SysError("select()"); - if (FD_ISSET(STDIN_FILENO, &fds)) { + if (FD_ISSET(from.fd, &fds)) { char c; - if (read(STDIN_FILENO, &c, 1) != 0) + if (read(from.fd, &c, 1) != 0) throw Error("EOF expected (protocol error?)"); _isInterrupted = 1; checkInterrupt(); @@ -114,7 +123,7 @@ static void startWork() /* stopWork() means that we're done; stop sending stderr to the client. */ -static void stopWork() +static void stopWork(bool success = true, const string & msg = "") { /* Stop handling async client death; we're going to a state where we're either sending or receiving from the client, so we'll be @@ -123,7 +132,13 @@ static void stopWork() throw SysError("ignoring SIGIO"); canSendStderr = false; - writeInt(STDERR_LAST, to); + + if (success) + writeInt(STDERR_LAST, to); + else { + writeInt(STDERR_ERROR, to); + writeString(msg, to); + } } @@ -237,11 +252,17 @@ static void processConnection() /* Allow us to receive SIGIO for events on the client socket. */ signal(SIGIO, SIG_IGN); - if (fcntl(STDIN_FILENO, F_SETOWN, getpid()) == -1) + if (fcntl(from.fd, F_SETOWN, getpid()) == -1) throw SysError("F_SETOWN"); - if (fcntl(STDIN_FILENO, F_SETFL, fcntl(STDIN_FILENO, F_GETFL, 0) | FASYNC) == -1) + if (fcntl(from.fd, F_SETFL, fcntl(from.fd, F_GETFL, 0) | FASYNC) == -1) throw SysError("F_SETFL"); + /* Exchange the greeting. */ + unsigned int magic = readInt(from); + if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch"); + verbosity = (Verbosity) readInt(from); + writeInt(WORKER_MAGIC_2, to); + /* Send startup error messages to the client. */ startWork(); @@ -258,40 +279,137 @@ static void processConnection() stopWork(); } catch (Error & e) { - writeInt(STDERR_ERROR, to); - writeString(e.msg(), to); + stopWork(false, e.msg()); return; } - /* Exchange the greeting. */ - unsigned int magic = readInt(from); - if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch"); - writeInt(WORKER_MAGIC_2, to); - debug("greeting exchanged"); - /* Process client requests. */ - bool quit = false; - unsigned int opCount = 0; - do { - WorkerOp op = (WorkerOp) readInt(from); + while (true) { + WorkerOp op; + try { + op = (WorkerOp) readInt(from); + } catch (EndOfFile & e) { + break; + } opCount++; try { performOp(from, to, op); } catch (Error & e) { - writeInt(STDERR_ERROR, to); - writeString(e.msg(), to); + stopWork(false, e.msg()); } - - } while (!quit); + }; printMsg(lvlError, format("%1% worker operations") % opCount); } +static void setSigChldAction(bool ignore) +{ + struct sigaction act, oact; + act.sa_handler = ignore ? SIG_IGN : SIG_DFL; + sigfillset(&act.sa_mask); + act.sa_flags = 0; + if (sigaction(SIGCHLD, &act, &oact)) + throw SysError("setting SIGCHLD handler"); +} + + +static void daemonLoop() +{ + /* Get rid of children automatically; don't let them become + zombies. */ + setSigChldAction(true); + + /* Create and bind to a Unix domain socket. */ + AutoCloseFD fdSocket = socket(PF_UNIX, SOCK_STREAM, 0); + if (fdSocket == -1) + throw SysError("cannot create Unix domain socket"); + + string socketPath = nixStateDir + DEFAULT_SOCKET_PATH; + + struct sockaddr_un addr; + addr.sun_family = AF_UNIX; + if (socketPath.size() >= sizeof(addr.sun_path)) + throw Error(format("socket path `%1%' is too long") % socketPath); + strcpy(addr.sun_path, socketPath.c_str()); + + unlink(socketPath.c_str()); + + /* Make sure that the socket is created with 0666 permission + (everybody can connect). */ + mode_t oldMode = umask(0111); + int res = bind(fdSocket, (struct sockaddr *) &addr, sizeof(addr)); + umask(oldMode); + if (res == -1) + throw SysError(format("cannot bind to socket `%1%'") % socketPath); + + if (listen(fdSocket, 5) == -1) + throw SysError(format("cannot listen on socket `%1%'") % socketPath); + + /* Loop accepting connections. */ + while (1) { + + try { + /* Important: the server process *cannot* open the + Berkeley DB environment, because it doesn't like forks + very much. */ + assert(!store); + + /* Accept a connection. */ + struct sockaddr_un remoteAddr; + socklen_t remoteAddrLen = sizeof(remoteAddr); + + AutoCloseFD remote = accept(fdSocket, + (struct sockaddr *) &remoteAddr, &remoteAddrLen); + checkInterrupt(); + if (remote == -1) + throw SysError("accepting connection"); + + printMsg(lvlInfo, format("accepted connection %1%") % remote); + + /* Fork a child to handle the connection. */ + pid_t child; + child = fork(); + + switch (child) { + + case -1: + throw SysError("unable to fork"); + + case 0: + try { /* child */ + + /* Background the worker. */ + if (setsid() == -1) + throw SysError(format("creating a new session")); + + /* Restore normal handling of SIGCHLD. */ + setSigChldAction(false); + + /* Handle the connection. */ + from.fd = remote; + to.fd = remote; + processConnection(); + + } catch (std::exception & e) { + std::cerr << format("child error: %1%\n") % e.what(); + } + exit(0); + } + + } catch (Interrupted & e) { + throw; + } catch (Error & e) { + printMsg(lvlError, format("error processing connection: %1%") % e.msg()); + } + } +} + + void run(Strings args) { bool slave = false; @@ -315,8 +433,7 @@ void run(Strings args) else if (daemon) { if (setuidMode) throw Error("daemon cannot be started in setuid mode"); - - throw Error("daemon mode not implemented"); + daemonLoop(); } else |