diff --git a/src/libsysprof-capture/sysprof-collector.c b/src/libsysprof-capture/sysprof-collector.c index 2919d8b7..1d0c57b7 100644 --- a/src/libsysprof-capture/sysprof-collector.c +++ b/src/libsysprof-capture/sysprof-collector.c @@ -59,9 +59,8 @@ #endif #include -#include -#include -#include +#include +#include #include #ifdef __linux__ # include @@ -71,7 +70,9 @@ #include #include #include +#include #include +#include #include #include "mapped-ring-buffer.h" @@ -84,6 +85,14 @@ #define CREATRING "CreatRing\0" #define CREATRING_LEN 10 +#ifndef MSG_NOSIGNAL +#define MSG_NOSIGNAL 0 +#endif + +#ifndef MSG_CMSG_CLOEXEC +#define MSG_CMSG_CLOEXEC 0 +#endif + typedef struct { MappedRingBuffer *buffer; @@ -127,16 +136,186 @@ realign (size_t size) return (size + SYSPROF_CAPTURE_ALIGN - 1) & ~(SYSPROF_CAPTURE_ALIGN - 1); } +static bool +set_fd_blocking (int fd) +{ +#ifdef F_GETFL + long fcntl_flags; + fcntl_flags = fcntl (peer_fd, F_GETFL); + + if (fcntl_flags == -1) + return false; + +#ifdef O_NONBLOCK + fcntl_flags &= ~O_NONBLOCK; +#else + fcntl_flags &= ~O_NDELAY; +#endif + + if (fcntl (peer_fd, F_SETFL, fcntl_flags) == -1) + return false; + + return true; +#else + return false; +#endif +} + +static bool +block_on_poll (int fd, + int condition) +{ + struct pollfd poll_fd; + + poll_fd.fd = fd; + poll_fd.events = condition; + + return (TEMP_FAILURE_RETRY (poll (&poll_fd, 1, -1)) == 1); +} + +static ssize_t +send_blocking (int fd, + const uint8_t *buffer, + size_t buffer_len) +{ + ssize_t res; + + while ((res = TEMP_FAILURE_RETRY (send (fd, buffer, buffer_len, MSG_NOSIGNAL))) < 0) + { + int errsv = errno; + + if (errsv == EWOULDBLOCK || + errsv == EAGAIN) + { + if (!block_on_poll (fd, POLLOUT)) + return -1; + } + else + { + return -1; + } + } + + return res; +} + +static bool +send_all_blocking (int fd, + const uint8_t *buffer, + size_t buffer_len, + size_t *bytes_written) +{ + size_t _bytes_written; + + _bytes_written = 0; + while (_bytes_written < buffer_len) + { + ssize_t res = send_blocking (fd, buffer + _bytes_written, buffer_len - _bytes_written); + if (res == -1) + { + if (bytes_written != NULL) + *bytes_written = _bytes_written; + return false; + } + assert (res > 0); + + _bytes_written += res; + } + + if (bytes_written != NULL) + *bytes_written = _bytes_written; + return true; +} + +static int +receive_fd_blocking (int peer_fd) +{ + ssize_t res; + struct msghdr msg; + struct iovec one_vector; + char one_byte; + uint8_t cmsg_buffer[CMSG_SPACE (sizeof (int))]; + struct cmsghdr *cmsg; + const int *fds = NULL; + size_t n_fds = 0; + + one_vector.iov_base = &one_byte; + one_vector.iov_len = 1; + + msg.msg_name = NULL; + msg.msg_namelen = 0; + msg.msg_iov = &one_vector; + msg.msg_iovlen = 1; + msg.msg_flags = MSG_CMSG_CLOEXEC; + msg.msg_control = &cmsg_buffer; + msg.msg_controllen = sizeof (cmsg_buffer); + + while ((res = TEMP_FAILURE_RETRY (recvmsg (peer_fd, &msg, msg.msg_flags))) < 0) + { + int errsv = errno; + + if (errsv == EWOULDBLOCK || + errsv == EAGAIN) + { + if (!block_on_poll (peer_fd, POLLIN)) + return -1; + } + else + { + return -1; + } + } + + /* Decode the returned control message */ + cmsg = CMSG_FIRSTHDR (&msg); + if (cmsg == NULL) + return -1; + + if (cmsg->cmsg_level != SOL_SOCKET || + cmsg->cmsg_type != SCM_RIGHTS) + return -1; + + /* non-integer number of FDs */ + if ((cmsg->cmsg_len - ((char *)CMSG_DATA (cmsg) - (char *)cmsg)) % 4 != 0) + return -1; + + fds = (const int *) CMSG_DATA (cmsg); + n_fds = (cmsg->cmsg_len - ((char *)CMSG_DATA (cmsg) - (char *)cmsg)) / sizeof (*fds); + + /* only expecting one FD */ + if (n_fds != 1) + goto close_fds_err; + + for (size_t i = 0; i < n_fds; i++) + { + if (fds[i] < 0) + goto close_fds_err; + } + + /* only expecting one control message */ + cmsg = CMSG_NXTHDR (&msg, cmsg); + if (cmsg != NULL) + goto close_fds_err; + + return fds[0]; + +close_fds_err: + for (size_t i = 0; i < n_fds; i++) + close (fds[i]); + + return -1; +} + +/* Called with @control_fd_lock held. */ static MappedRingBuffer * request_writer (void) { - static GUnixConnection *conn; + static int peer_fd = -1; MappedRingBuffer *buffer = NULL; - if (conn == NULL) + if (peer_fd == -1) { const char *fdstr = getenv ("SYSPROF_CONTROL_FD"); - int peer_fd = -1; if (fdstr != NULL) peer_fd = atoi (fdstr); @@ -145,36 +324,24 @@ request_writer (void) if (peer_fd > 0) { - g_autoptr(GSocket) sock = NULL; + (void) set_fd_blocking (peer_fd); - g_unix_set_fd_nonblocking (peer_fd, FALSE, NULL); - - if ((sock = g_socket_new_from_fd (peer_fd, NULL))) - { - g_autoptr(GSocketConnection) scon = NULL; - - g_socket_set_blocking (sock, TRUE); - - if ((scon = g_socket_connection_factory_create_connection (sock)) && - G_IS_UNIX_CONNECTION (scon)) - conn = g_object_ref (G_UNIX_CONNECTION (scon)); - } +#ifdef SO_NOSIGPIPE + { + int opt_value = 1; + (void) setsockopt (peer_fd, SOL_SOCKET, SO_NOSIGPIPE, &opt_value, sizeof (opt_value)); + } +#endif } } - if (conn != NULL) + if (peer_fd >= 0) { - GOutputStream *out_stream; - gsize len; - - out_stream = g_io_stream_get_output_stream (G_IO_STREAM (conn)); - - if (g_output_stream_write_all (G_OUTPUT_STREAM (out_stream), CREATRING, CREATRING_LEN, &len, NULL, NULL) && - len == CREATRING_LEN) + if (send_all_blocking (peer_fd, (const uint8_t *) CREATRING, CREATRING_LEN, NULL)) { - int ring_fd = g_unix_connection_receive_fd (conn, NULL, NULL); + int ring_fd = receive_fd_blocking (peer_fd); - if (ring_fd > -1) + if (ring_fd >= 0) { buffer = mapped_ring_buffer_new_writer (ring_fd); close (ring_fd);