summary refs log tree commit diff
path: root/io_uring/net.c
diff options
context:
space:
mode:
Diffstat (limited to 'io_uring/net.c')
-rw-r--r--io_uring/net.c180
1 files changed, 161 insertions, 19 deletions
diff --git a/io_uring/net.c b/io_uring/net.c
index 5bc3440a8290..616d5f04cc74 100644
--- a/io_uring/net.c
+++ b/io_uring/net.c
@@ -325,6 +325,21 @@ int io_send(struct io_kiocb *req, unsigned int issue_flags)
 	return IOU_OK;
 }
 
+static bool io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
+{
+	unsigned long hdr;
+
+	if (check_add_overflow(sizeof(struct io_uring_recvmsg_out),
+			       (unsigned long)iomsg->namelen, &hdr))
+		return true;
+	if (check_add_overflow(hdr, iomsg->controllen, &hdr))
+		return true;
+	if (hdr > INT_MAX)
+		return true;
+
+	return false;
+}
+
 static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
 				 struct io_async_msghdr *iomsg)
 {
@@ -352,6 +367,13 @@ static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
 			sr->len = iomsg->fast_iov[0].iov_len;
 			iomsg->free_iov = NULL;
 		}
+
+		if (req->flags & REQ_F_APOLL_MULTISHOT) {
+			iomsg->namelen = msg.msg_namelen;
+			iomsg->controllen = msg.msg_controllen;
+			if (io_recvmsg_multishot_overflow(iomsg))
+				return -EOVERFLOW;
+		}
 	} else {
 		iomsg->free_iov = iomsg->fast_iov;
 		ret = __import_iovec(READ, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
@@ -399,6 +421,13 @@ static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
 			sr->len = clen;
 			iomsg->free_iov = NULL;
 		}
+
+		if (req->flags & REQ_F_APOLL_MULTISHOT) {
+			iomsg->namelen = msg.msg_namelen;
+			iomsg->controllen = msg.msg_controllen;
+			if (io_recvmsg_multishot_overflow(iomsg))
+				return -EOVERFLOW;
+		}
 	} else {
 		iomsg->free_iov = iomsg->fast_iov;
 		ret = __import_iovec(READ, (struct iovec __user *)uiov, msg.msg_iovlen,
@@ -455,8 +484,6 @@ int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 	if (sr->msg_flags & MSG_ERRQUEUE)
 		req->flags |= REQ_F_CLEAR_POLLIN;
 	if (sr->flags & IORING_RECV_MULTISHOT) {
-		if (req->opcode == IORING_OP_RECVMSG)
-			return -EINVAL;
 		if (!(req->flags & REQ_F_BUFFER_SELECT))
 			return -EINVAL;
 		if (sr->msg_flags & MSG_WAITALL)
@@ -483,12 +510,13 @@ static inline void io_recv_prep_retry(struct io_kiocb *req)
 }
 
 /*
- * Finishes io_recv
+ * Finishes io_recv and io_recvmsg.
  *
  * Returns true if it is actually finished, or false if it should run
  * again (for multishot).
  */
-static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int cflags)
+static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
+				  unsigned int cflags, bool mshot_finished)
 {
 	if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
 		io_req_set_res(req, *ret, cflags);
@@ -496,7 +524,7 @@ static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int c
 		return true;
 	}
 
-	if (*ret > 0) {
+	if (!mshot_finished) {
 		if (io_post_aux_cqe(req->ctx, req->cqe.user_data, *ret,
 				    cflags | IORING_CQE_F_MORE, false)) {
 			io_recv_prep_retry(req);
@@ -518,6 +546,90 @@ static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int c
 	return true;
 }
 
+static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg,
+				     struct io_sr_msg *sr, void __user **buf,
+				     size_t *len)
+{
+	unsigned long ubuf = (unsigned long) *buf;
+	unsigned long hdr;
+
+	hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
+		kmsg->controllen;
+	if (*len < hdr)
+		return -EFAULT;
+
+	if (kmsg->controllen) {
+		unsigned long control = ubuf + hdr - kmsg->controllen;
+
+		kmsg->msg.msg_control_user = (void *) control;
+		kmsg->msg.msg_controllen = kmsg->controllen;
+	}
+
+	sr->buf = *buf; /* stash for later copy */
+	*buf = (void *) (ubuf + hdr);
+	kmsg->payloadlen = *len = *len - hdr;
+	return 0;
+}
+
+struct io_recvmsg_multishot_hdr {
+	struct io_uring_recvmsg_out msg;
+	struct sockaddr_storage addr;
+};
+
+static int io_recvmsg_multishot(struct socket *sock, struct io_sr_msg *io,
+				struct io_async_msghdr *kmsg,
+				unsigned int flags, bool *finished)
+{
+	int err;
+	int copy_len;
+	struct io_recvmsg_multishot_hdr hdr;
+
+	if (kmsg->namelen)
+		kmsg->msg.msg_name = &hdr.addr;
+	kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
+	kmsg->msg.msg_namelen = 0;
+
+	if (sock->file->f_flags & O_NONBLOCK)
+		flags |= MSG_DONTWAIT;
+
+	err = sock_recvmsg(sock, &kmsg->msg, flags);
+	*finished = err <= 0;
+	if (err < 0)
+		return err;
+
+	hdr.msg = (struct io_uring_recvmsg_out) {
+		.controllen = kmsg->controllen - kmsg->msg.msg_controllen,
+		.flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
+	};
+
+	hdr.msg.payloadlen = err;
+	if (err > kmsg->payloadlen)
+		err = kmsg->payloadlen;
+
+	copy_len = sizeof(struct io_uring_recvmsg_out);
+	if (kmsg->msg.msg_namelen > kmsg->namelen)
+		copy_len += kmsg->namelen;
+	else
+		copy_len += kmsg->msg.msg_namelen;
+
+	/*
+	 *      "fromlen shall refer to the value before truncation.."
+	 *                      1003.1g
+	 */
+	hdr.msg.namelen = kmsg->msg.msg_namelen;
+
+	/* ensure that there is no gap between hdr and sockaddr_storage */
+	BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
+		     sizeof(struct io_uring_recvmsg_out));
+	if (copy_to_user(io->buf, &hdr, copy_len)) {
+		*finished = true;
+		return -EFAULT;
+	}
+
+	return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
+			kmsg->controllen + err;
+}
+
 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 {
 	struct io_sr_msg *sr = io_kiocb_to_cmd(req);
@@ -527,6 +639,7 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 	unsigned flags;
 	int ret, min_ret = 0;
 	bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
+	bool mshot_finished = true;
 
 	sock = sock_from_file(req->file);
 	if (unlikely(!sock))
@@ -545,16 +658,27 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 	    (sr->flags & IORING_RECVSEND_POLL_FIRST))
 		return io_setup_async_msg(req, kmsg, issue_flags);
 
+retry_multishot:
 	if (io_do_buffer_select(req)) {
 		void __user *buf;
+		size_t len = sr->len;
 
-		buf = io_buffer_select(req, &sr->len, issue_flags);
+		buf = io_buffer_select(req, &len, issue_flags);
 		if (!buf)
 			return -ENOBUFS;
+
+		if (req->flags & REQ_F_APOLL_MULTISHOT) {
+			ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
+			if (ret) {
+				io_kbuf_recycle(req, issue_flags);
+				return ret;
+			}
+		}
+
 		kmsg->fast_iov[0].iov_base = buf;
-		kmsg->fast_iov[0].iov_len = sr->len;
+		kmsg->fast_iov[0].iov_len = len;
 		iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->fast_iov, 1,
-				sr->len);
+				len);
 	}
 
 	flags = sr->msg_flags;
@@ -564,10 +688,23 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 		min_ret = iov_iter_count(&kmsg->msg.msg_iter);
 
 	kmsg->msg.msg_get_inq = 1;
-	ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg, kmsg->uaddr, flags);
+	if (req->flags & REQ_F_APOLL_MULTISHOT)
+		ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
+					   &mshot_finished);
+	else
+		ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg,
+					 kmsg->uaddr, flags);
+
 	if (ret < min_ret) {
-		if (ret == -EAGAIN && force_nonblock)
-			return io_setup_async_msg(req, kmsg, issue_flags);
+		if (ret == -EAGAIN && force_nonblock) {
+			ret = io_setup_async_msg(req, kmsg, issue_flags);
+			if (ret == -EAGAIN && (req->flags & IO_APOLL_MULTI_POLLED) ==
+					       IO_APOLL_MULTI_POLLED) {
+				io_kbuf_recycle(req, issue_flags);
+				return IOU_ISSUE_SKIP_COMPLETE;
+			}
+			return ret;
+		}
 		if (ret == -ERESTARTSYS)
 			ret = -EINTR;
 		if (ret > 0 && io_net_retry(sock, flags)) {
@@ -580,11 +717,6 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 		req_set_fail(req);
 	}
 
-	/* fast path, check for non-NULL to avoid function call */
-	if (kmsg->free_iov)
-		kfree(kmsg->free_iov);
-	io_netmsg_recycle(req, issue_flags);
-	req->flags &= ~REQ_F_NEED_CLEANUP;
 	if (ret > 0)
 		ret += sr->done_io;
 	else if (sr->done_io)
@@ -596,8 +728,18 @@ int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 	if (kmsg->msg.msg_inq)
 		cflags |= IORING_CQE_F_SOCK_NONEMPTY;
 
-	io_req_set_res(req, ret, cflags);
-	return IOU_OK;
+	if (!io_recv_finish(req, &ret, cflags, mshot_finished))
+		goto retry_multishot;
+
+	if (mshot_finished) {
+		io_netmsg_recycle(req, issue_flags);
+		/* fast path, check for non-NULL to avoid function call */
+		if (kmsg->free_iov)
+			kfree(kmsg->free_iov);
+		req->flags &= ~REQ_F_NEED_CLEANUP;
+	}
+
+	return ret;
 }
 
 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
@@ -684,7 +826,7 @@ out_free:
 	if (msg.msg_inq)
 		cflags |= IORING_CQE_F_SOCK_NONEMPTY;
 
-	if (!io_recv_finish(req, &ret, cflags))
+	if (!io_recv_finish(req, &ret, cflags, ret <= 0))
 		goto retry_multishot;
 
 	return ret;