summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--fs/io_uring.c177
-rw-r--r--include/linux/io_uring.h5
-rw-r--r--include/linux/socket.h1
-rw-r--r--include/trace/events/io_uring.h36
-rw-r--r--include/uapi/linux/io_uring.h3
-rw-r--r--net/socket.c52
6 files changed, 248 insertions, 26 deletions
diff --git a/fs/io_uring.c b/fs/io_uring.c
index e4c9f776919b..d9529275a030 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -600,6 +600,16 @@ struct io_accept {
 	unsigned long			nofile;
 };
 
+struct io_socket {
+	struct file			*file;
+	int				domain;
+	int				type;
+	int				protocol;
+	int				flags;
+	u32				file_slot;
+	unsigned long			nofile;
+};
+
 struct io_sync {
 	struct file			*file;
 	loff_t				len;
@@ -981,6 +991,7 @@ struct io_kiocb {
 		struct io_hardlink	hardlink;
 		struct io_msg		msg;
 		struct io_xattr		xattr;
+		struct io_socket	sock;
 	};
 
 	u8				opcode;
@@ -1297,6 +1308,9 @@ static const struct io_op_def io_op_defs[] = {
 		.needs_file = 1
 	},
 	[IORING_OP_GETXATTR] = {},
+	[IORING_OP_SOCKET] = {
+		.audit_skip		= 1,
+	},
 };
 
 /* requests with any of those set should undergo io_disarm_next() */
@@ -1341,6 +1355,107 @@ static struct kmem_cache *req_cachep;
 
 static const struct file_operations io_uring_fops;
 
+const char *io_uring_get_opcode(u8 opcode)
+{
+	switch ((enum io_uring_op)opcode) {
+	case IORING_OP_NOP:
+		return "NOP";
+	case IORING_OP_READV:
+		return "READV";
+	case IORING_OP_WRITEV:
+		return "WRITEV";
+	case IORING_OP_FSYNC:
+		return "FSYNC";
+	case IORING_OP_READ_FIXED:
+		return "READ_FIXED";
+	case IORING_OP_WRITE_FIXED:
+		return "WRITE_FIXED";
+	case IORING_OP_POLL_ADD:
+		return "POLL_ADD";
+	case IORING_OP_POLL_REMOVE:
+		return "POLL_REMOVE";
+	case IORING_OP_SYNC_FILE_RANGE:
+		return "SYNC_FILE_RANGE";
+	case IORING_OP_SENDMSG:
+		return "SENDMSG";
+	case IORING_OP_RECVMSG:
+		return "RECVMSG";
+	case IORING_OP_TIMEOUT:
+		return "TIMEOUT";
+	case IORING_OP_TIMEOUT_REMOVE:
+		return "TIMEOUT_REMOVE";
+	case IORING_OP_ACCEPT:
+		return "ACCEPT";
+	case IORING_OP_ASYNC_CANCEL:
+		return "ASYNC_CANCEL";
+	case IORING_OP_LINK_TIMEOUT:
+		return "LINK_TIMEOUT";
+	case IORING_OP_CONNECT:
+		return "CONNECT";
+	case IORING_OP_FALLOCATE:
+		return "FALLOCATE";
+	case IORING_OP_OPENAT:
+		return "OPENAT";
+	case IORING_OP_CLOSE:
+		return "CLOSE";
+	case IORING_OP_FILES_UPDATE:
+		return "FILES_UPDATE";
+	case IORING_OP_STATX:
+		return "STATX";
+	case IORING_OP_READ:
+		return "READ";
+	case IORING_OP_WRITE:
+		return "WRITE";
+	case IORING_OP_FADVISE:
+		return "FADVISE";
+	case IORING_OP_MADVISE:
+		return "MADVISE";
+	case IORING_OP_SEND:
+		return "SEND";
+	case IORING_OP_RECV:
+		return "RECV";
+	case IORING_OP_OPENAT2:
+		return "OPENAT2";
+	case IORING_OP_EPOLL_CTL:
+		return "EPOLL_CTL";
+	case IORING_OP_SPLICE:
+		return "SPLICE";
+	case IORING_OP_PROVIDE_BUFFERS:
+		return "PROVIDE_BUFFERS";
+	case IORING_OP_REMOVE_BUFFERS:
+		return "REMOVE_BUFFERS";
+	case IORING_OP_TEE:
+		return "TEE";
+	case IORING_OP_SHUTDOWN:
+		return "SHUTDOWN";
+	case IORING_OP_RENAMEAT:
+		return "RENAMEAT";
+	case IORING_OP_UNLINKAT:
+		return "UNLINKAT";
+	case IORING_OP_MKDIRAT:
+		return "MKDIRAT";
+	case IORING_OP_SYMLINKAT:
+		return "SYMLINKAT";
+	case IORING_OP_LINKAT:
+		return "LINKAT";
+	case IORING_OP_MSG_RING:
+		return "MSG_RING";
+	case IORING_OP_FSETXATTR:
+		return "FSETXATTR";
+	case IORING_OP_SETXATTR:
+		return "SETXATTR";
+	case IORING_OP_FGETXATTR:
+		return "FGETXATTR";
+	case IORING_OP_GETXATTR:
+		return "GETXATTR";
+	case IORING_OP_SOCKET:
+		return "SOCKET";
+	case IORING_OP_LAST:
+		return "INVALID";
+	}
+	return "INVALID";
+}
+
 struct sock *io_uring_get_socket(struct file *file)
 {
 #if defined(CONFIG_UNIX)
@@ -6237,6 +6352,62 @@ retry:
 	return ret;
 }
 
+static int io_socket_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
+{
+	struct io_socket *sock = &req->sock;
+
+	if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
+		return -EINVAL;
+	if (sqe->ioprio || sqe->addr || sqe->rw_flags || sqe->buf_index)
+		return -EINVAL;
+
+	sock->domain = READ_ONCE(sqe->fd);
+	sock->type = READ_ONCE(sqe->off);
+	sock->protocol = READ_ONCE(sqe->len);
+	sock->file_slot = READ_ONCE(sqe->file_index);
+	sock->nofile = rlimit(RLIMIT_NOFILE);
+
+	sock->flags = sock->type & ~SOCK_TYPE_MASK;
+	if (sock->file_slot && (sock->flags & SOCK_CLOEXEC))
+		return -EINVAL;
+	if (sock->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
+		return -EINVAL;
+	return 0;
+}
+
+static int io_socket(struct io_kiocb *req, unsigned int issue_flags)
+{
+	struct io_socket *sock = &req->sock;
+	bool fixed = !!sock->file_slot;
+	struct file *file;
+	int ret, fd;
+
+	if (!fixed) {
+		fd = __get_unused_fd_flags(sock->flags, sock->nofile);
+		if (unlikely(fd < 0))
+			return fd;
+	}
+	file = __sys_socket_file(sock->domain, sock->type, sock->protocol);
+	if (IS_ERR(file)) {
+		if (!fixed)
+			put_unused_fd(fd);
+		ret = PTR_ERR(file);
+		if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
+			return -EAGAIN;
+		if (ret == -ERESTARTSYS)
+			ret = -EINTR;
+		req_set_fail(req);
+	} else if (!fixed) {
+		fd_install(fd, file);
+		ret = fd;
+	} else {
+		ret = io_install_fixed_file(req, file, issue_flags,
+					    sock->file_slot - 1);
+	}
+	__io_req_complete(req, issue_flags, ret, 0);
+	return 0;
+}
+
 static int io_connect_prep_async(struct io_kiocb *req)
 {
 	struct io_async_connect *io = req->async_data;
@@ -6322,6 +6493,7 @@ IO_NETOP_PREP_ASYNC(sendmsg);
 IO_NETOP_PREP_ASYNC(recvmsg);
 IO_NETOP_PREP_ASYNC(connect);
 IO_NETOP_PREP(accept);
+IO_NETOP_PREP(socket);
 IO_NETOP_FN(send);
 IO_NETOP_FN(recv);
 #endif /* CONFIG_NET */
@@ -7651,6 +7823,8 @@ static int io_req_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 		return io_fgetxattr_prep(req, sqe);
 	case IORING_OP_GETXATTR:
 		return io_getxattr_prep(req, sqe);
+	case IORING_OP_SOCKET:
+		return io_socket_prep(req, sqe);
 	}
 
 	printk_once(KERN_WARNING "io_uring: unhandled opcode %d\n",
@@ -7974,6 +8148,9 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
 	case IORING_OP_GETXATTR:
 		ret = io_getxattr(req, issue_flags);
 		break;
+	case IORING_OP_SOCKET:
+		ret = io_socket(req, issue_flags);
+		break;
 	default:
 		ret = -EINVAL;
 		break;
diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
index 1814e698d861..24651c229ed2 100644
--- a/include/linux/io_uring.h
+++ b/include/linux/io_uring.h
@@ -10,6 +10,7 @@ struct sock *io_uring_get_socket(struct file *file);
 void __io_uring_cancel(bool cancel_all);
 void __io_uring_free(struct task_struct *tsk);
 void io_uring_unreg_ringfd(void);
+const char *io_uring_get_opcode(u8 opcode);
 
 static inline void io_uring_files_cancel(void)
 {
@@ -42,6 +43,10 @@ static inline void io_uring_files_cancel(void)
 static inline void io_uring_free(struct task_struct *tsk)
 {
 }
+static inline const char *io_uring_get_opcode(u8 opcode)
+{
+	return "";
+}
 #endif
 
 #endif
diff --git a/include/linux/socket.h b/include/linux/socket.h
index 6f85f5d957ef..a1882e1e71d2 100644
--- a/include/linux/socket.h
+++ b/include/linux/socket.h
@@ -434,6 +434,7 @@ extern struct file *do_accept(struct file *file, unsigned file_flags,
 extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
 			 int __user *upeer_addrlen, int flags);
 extern int __sys_socket(int family, int type, int protocol);
+extern struct file *__sys_socket_file(int family, int type, int protocol);
 extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen);
 extern int __sys_connect_file(struct file *file, struct sockaddr_storage *addr,
 			      int addrlen, int file_flags);
diff --git a/include/trace/events/io_uring.h b/include/trace/events/io_uring.h
index f74edfc75f41..610e493234db 100644
--- a/include/trace/events/io_uring.h
+++ b/include/trace/events/io_uring.h
@@ -7,6 +7,7 @@
 
 #include <linux/tracepoint.h>
 #include <uapi/linux/io_uring.h>
+#include <linux/io_uring.h>
 
 struct io_wq_work;
 
@@ -169,8 +170,9 @@ TRACE_EVENT(io_uring_queue_async_work,
 		__entry->rw		= rw;
 	),
 
-	TP_printk("ring %p, request %p, user_data 0x%llx, opcode %d, flags 0x%x, %s queue, work %p",
-		__entry->ctx, __entry->req, __entry->user_data, __entry->opcode,
+	TP_printk("ring %p, request %p, user_data 0x%llx, opcode %s, flags 0x%x, %s queue, work %p",
+		__entry->ctx, __entry->req, __entry->user_data,
+		io_uring_get_opcode(__entry->opcode),
 		__entry->flags, __entry->rw ? "hashed" : "normal", __entry->work)
 );
 
@@ -205,8 +207,9 @@ TRACE_EVENT(io_uring_defer,
 		__entry->opcode	= opcode;
 	),
 
-	TP_printk("ring %p, request %p, user_data 0x%llx, opcode %d",
-		__entry->ctx, __entry->req, __entry->data, __entry->opcode)
+	TP_printk("ring %p, request %p, user_data 0x%llx, opcode %s",
+		__entry->ctx, __entry->req, __entry->data,
+		io_uring_get_opcode(__entry->opcode))
 );
 
 /**
@@ -305,9 +308,9 @@ TRACE_EVENT(io_uring_fail_link,
 		__entry->link		= link;
 	),
 
-	TP_printk("ring %p, request %p, user_data 0x%llx, opcode %d, link %p",
-		__entry->ctx, __entry->req, __entry->user_data, __entry->opcode,
-		__entry->link)
+	TP_printk("ring %p, request %p, user_data 0x%llx, opcode %s, link %p",
+		__entry->ctx, __entry->req, __entry->user_data,
+		io_uring_get_opcode(__entry->opcode), __entry->link)
 );
 
 /**
@@ -389,9 +392,9 @@ TRACE_EVENT(io_uring_submit_sqe,
 		__entry->sq_thread	= sq_thread;
 	),
 
-	TP_printk("ring %p, req %p, user_data 0x%llx, opcode %d, flags 0x%x, "
+	TP_printk("ring %p, req %p, user_data 0x%llx, opcode %s, flags 0x%x, "
 		  "non block %d, sq_thread %d", __entry->ctx, __entry->req,
-		  __entry->user_data, __entry->opcode,
+		  __entry->user_data, io_uring_get_opcode(__entry->opcode),
 		  __entry->flags, __entry->force_nonblock, __entry->sq_thread)
 );
 
@@ -433,8 +436,9 @@ TRACE_EVENT(io_uring_poll_arm,
 		__entry->events		= events;
 	),
 
-	TP_printk("ring %p, req %p, user_data 0x%llx, opcode %d, mask 0x%x, events 0x%x",
-		  __entry->ctx, __entry->req, __entry->user_data, __entry->opcode,
+	TP_printk("ring %p, req %p, user_data 0x%llx, opcode %s, mask 0x%x, events 0x%x",
+		  __entry->ctx, __entry->req, __entry->user_data,
+		  io_uring_get_opcode(__entry->opcode),
 		  __entry->mask, __entry->events)
 );
 
@@ -470,8 +474,9 @@ TRACE_EVENT(io_uring_task_add,
 		__entry->mask		= mask;
 	),
 
-	TP_printk("ring %p, req %p, user_data 0x%llx, opcode %d, mask %x",
-		__entry->ctx, __entry->req, __entry->user_data, __entry->opcode,
+	TP_printk("ring %p, req %p, user_data 0x%llx, opcode %s, mask %x",
+		__entry->ctx, __entry->req, __entry->user_data,
+		io_uring_get_opcode(__entry->opcode),
 		__entry->mask)
 );
 
@@ -530,12 +535,13 @@ TRACE_EVENT(io_uring_req_failed,
 	),
 
 	TP_printk("ring %p, req %p, user_data 0x%llx, "
-		  "op %d, flags 0x%x, prio=%d, off=%llu, addr=%llu, "
+		  "opcode %s, flags 0x%x, prio=%d, off=%llu, addr=%llu, "
 		  "len=%u, rw_flags=0x%x, buf_index=%d, "
 		  "personality=%d, file_index=%d, pad=0x%llx, addr3=%llx, "
 		  "error=%d",
 		  __entry->ctx, __entry->req, __entry->user_data,
-		  __entry->opcode, __entry->flags, __entry->ioprio,
+		  io_uring_get_opcode(__entry->opcode),
+		  __entry->flags, __entry->ioprio,
 		  (unsigned long long)__entry->off,
 		  (unsigned long long) __entry->addr, __entry->len,
 		  __entry->op_flags,
diff --git a/include/uapi/linux/io_uring.h b/include/uapi/linux/io_uring.h
index 199bbf73a752..804916814e51 100644
--- a/include/uapi/linux/io_uring.h
+++ b/include/uapi/linux/io_uring.h
@@ -128,7 +128,7 @@ enum {
  */
 #define IORING_SETUP_TASKRUN_FLAG	(1U << 9)
 
-enum {
+enum io_uring_op {
 	IORING_OP_NOP,
 	IORING_OP_READV,
 	IORING_OP_WRITEV,
@@ -174,6 +174,7 @@ enum {
 	IORING_OP_SETXATTR,
 	IORING_OP_FGETXATTR,
 	IORING_OP_GETXATTR,
+	IORING_OP_SOCKET,
 
 	/* this goes last, obviously */
 	IORING_OP_LAST,
diff --git a/net/socket.c b/net/socket.c
index 6887840682bb..bb6a1a12fbde 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -504,7 +504,7 @@ static int sock_map_fd(struct socket *sock, int flags)
 struct socket *sock_from_file(struct file *file)
 {
 	if (file->f_op == &socket_file_ops)
-		return file->private_data;	/* set in sock_map_fd */
+		return file->private_data;	/* set in sock_alloc_file */
 
 	return NULL;
 }
@@ -1538,11 +1538,10 @@ int sock_create_kern(struct net *net, int family, int type, int protocol, struct
 }
 EXPORT_SYMBOL(sock_create_kern);
 
-int __sys_socket(int family, int type, int protocol)
+static struct socket *__sys_socket_create(int family, int type, int protocol)
 {
-	int retval;
 	struct socket *sock;
-	int flags;
+	int retval;
 
 	/* Check the SOCK_* constants for consistency.  */
 	BUILD_BUG_ON(SOCK_CLOEXEC != O_CLOEXEC);
@@ -1550,17 +1549,50 @@ int __sys_socket(int family, int type, int protocol)
 	BUILD_BUG_ON(SOCK_CLOEXEC & SOCK_TYPE_MASK);
 	BUILD_BUG_ON(SOCK_NONBLOCK & SOCK_TYPE_MASK);
 
-	flags = type & ~SOCK_TYPE_MASK;
-	if (flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
-		return -EINVAL;
+	if ((type & ~SOCK_TYPE_MASK) & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
+		return ERR_PTR(-EINVAL);
 	type &= SOCK_TYPE_MASK;
 
+	retval = sock_create(family, type, protocol, &sock);
+	if (retval < 0)
+		return ERR_PTR(retval);
+
+	return sock;
+}
+
+struct file *__sys_socket_file(int family, int type, int protocol)
+{
+	struct socket *sock;
+	struct file *file;
+	int flags;
+
+	sock = __sys_socket_create(family, type, protocol);
+	if (IS_ERR(sock))
+		return ERR_CAST(sock);
+
+	flags = type & ~SOCK_TYPE_MASK;
 	if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK))
 		flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
 
-	retval = sock_create(family, type, protocol, &sock);
-	if (retval < 0)
-		return retval;
+	file = sock_alloc_file(sock, flags, NULL);
+	if (IS_ERR(file))
+		sock_release(sock);
+
+	return file;
+}
+
+int __sys_socket(int family, int type, int protocol)
+{
+	struct socket *sock;
+	int flags;
+
+	sock = __sys_socket_create(family, type, protocol);
+	if (IS_ERR(sock))
+		return PTR_ERR(sock);
+
+	flags = type & ~SOCK_TYPE_MASK;
+	if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK))
+		flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
 
 	return sock_map_fd(sock, flags & (O_CLOEXEC | O_NONBLOCK));
 }