summary refs log tree commit diff
path: root/fs/io_uring.c
diff options
context:
space:
mode:
Diffstat (limited to 'fs/io_uring.c')
-rw-r--r--fs/io_uring.c837
1 files changed, 455 insertions, 382 deletions
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 92c25b5f1349..a4bce17af506 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -258,12 +258,10 @@ enum {
 
 struct io_sq_data {
 	refcount_t		refs;
-	struct mutex		lock;
+	struct rw_semaphore	rw_lock;
 
 	/* ctx's that are using this sqd */
 	struct list_head	ctx_list;
-	struct list_head	ctx_new_list;
-	struct mutex		ctx_lock;
 
 	struct task_struct	*thread;
 	struct wait_queue_head	wait;
@@ -271,10 +269,9 @@ struct io_sq_data {
 	unsigned		sq_thread_idle;
 	int			sq_cpu;
 	pid_t			task_pid;
+	pid_t			task_tgid;
 
 	unsigned long		state;
-	struct completion	startup;
-	struct completion	parked;
 	struct completion	exited;
 };
 
@@ -336,7 +333,6 @@ struct io_ring_ctx {
 		unsigned int		drain_next: 1;
 		unsigned int		eventfd_async: 1;
 		unsigned int		restricted: 1;
-		unsigned int		sqo_exec: 1;
 
 		/*
 		 * Ring buffer of indices into array of io_uring_sqe, which is
@@ -380,6 +376,7 @@ struct io_ring_ctx {
 	/* Only used for accounting purposes */
 	struct mm_struct	*mm_account;
 
+	const struct cred	*sq_creds;	/* cred used for __io_sq_thread() */
 	struct io_sq_data	*sq_data;	/* if using sq thread polling */
 
 	struct wait_queue_head	sqo_sq_wait;
@@ -400,7 +397,6 @@ struct io_ring_ctx {
 	struct user_struct	*user;
 
 	struct completion	ref_comp;
-	struct completion	sq_thread_comp;
 
 #if defined(CONFIG_UNIX)
 	struct socket		*ring_sock;
@@ -408,7 +404,8 @@ struct io_ring_ctx {
 
 	struct idr		io_buffer_idr;
 
-	struct idr		personality_idr;
+	struct xarray		personalities;
+	u32			pers_next;
 
 	struct {
 		unsigned		cached_cq_tail;
@@ -454,6 +451,7 @@ struct io_ring_ctx {
 
 	/* Keep this last, we don't need it for the fast path */
 	struct work_struct		exit_work;
+	struct list_head		tctx_list;
 };
 
 /*
@@ -805,6 +803,12 @@ struct io_kiocb {
 	struct io_wq_work		work;
 };
 
+struct io_tctx_node {
+	struct list_head	ctx_node;
+	struct task_struct	*task;
+	struct io_ring_ctx	*ctx;
+};
+
 struct io_defer_entry {
 	struct list_head	list;
 	struct io_kiocb		*req;
@@ -979,6 +983,8 @@ static const struct io_op_def io_op_defs[] = {
 	[IORING_OP_UNLINKAT] = {},
 };
 
+static bool io_disarm_next(struct io_kiocb *req);
+static void io_uring_del_task_file(unsigned long index);
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 					 struct task_struct *task,
 					 struct files_struct *files);
@@ -1129,9 +1135,8 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 	init_waitqueue_head(&ctx->cq_wait);
 	INIT_LIST_HEAD(&ctx->cq_overflow_list);
 	init_completion(&ctx->ref_comp);
-	init_completion(&ctx->sq_thread_comp);
 	idr_init(&ctx->io_buffer_idr);
-	idr_init(&ctx->personality_idr);
+	xa_init_flags(&ctx->personalities, XA_FLAGS_ALLOC1);
 	mutex_init(&ctx->uring_lock);
 	init_waitqueue_head(&ctx->wait);
 	spin_lock_init(&ctx->completion_lock);
@@ -1144,6 +1149,7 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 	INIT_LIST_HEAD(&ctx->rsrc_ref_list);
 	INIT_DELAYED_WORK(&ctx->rsrc_put_work, io_rsrc_put_work);
 	init_llist_head(&ctx->rsrc_put_llist);
+	INIT_LIST_HEAD(&ctx->tctx_list);
 	INIT_LIST_HEAD(&ctx->submit_state.comp.free_list);
 	INIT_LIST_HEAD(&ctx->submit_state.comp.locked_free_list);
 	return ctx;
@@ -1183,6 +1189,9 @@ static void io_prep_async_work(struct io_kiocb *req)
 	const struct io_op_def *def = &io_op_defs[req->opcode];
 	struct io_ring_ctx *ctx = req->ctx;
 
+	if (!req->work.creds)
+		req->work.creds = get_current_cred();
+
 	if (req->flags & REQ_F_FORCE_ASYNC)
 		req->work.flags |= IO_WQ_WORK_CONCURRENT;
 
@@ -1514,15 +1523,14 @@ static void io_cqring_fill_event(struct io_kiocb *req, long res)
 	__io_cqring_fill_event(req, res, 0);
 }
 
-static inline void io_req_complete_post(struct io_kiocb *req, long res,
-					unsigned int cflags)
+static void io_req_complete_post(struct io_kiocb *req, long res,
+				 unsigned int cflags)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 	unsigned long flags;
 
 	spin_lock_irqsave(&ctx->completion_lock, flags);
 	__io_cqring_fill_event(req, res, cflags);
-	io_commit_cqring(ctx);
 	/*
 	 * If we're the last reference to this request, add to our locked
 	 * free_list cache.
@@ -1530,19 +1538,26 @@ static inline void io_req_complete_post(struct io_kiocb *req, long res,
 	if (refcount_dec_and_test(&req->refs)) {
 		struct io_comp_state *cs = &ctx->submit_state.comp;
 
+		if (req->flags & (REQ_F_LINK | REQ_F_HARDLINK)) {
+			if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK))
+				io_disarm_next(req);
+			if (req->link) {
+				io_req_task_queue(req->link);
+				req->link = NULL;
+			}
+		}
 		io_dismantle_req(req);
 		io_put_task(req->task, 1);
 		list_add(&req->compl.list, &cs->locked_free_list);
 		cs->locked_free_nr++;
 	} else
 		req = NULL;
+	io_commit_cqring(ctx);
 	spin_unlock_irqrestore(&ctx->completion_lock, flags);
-
 	io_cqring_ev_posted(ctx);
-	if (req) {
-		io_queue_next(req);
+
+	if (req)
 		percpu_ref_put(&ctx->refs);
-	}
 }
 
 static void io_req_complete_state(struct io_kiocb *req, long res,
@@ -1648,6 +1663,10 @@ static void io_dismantle_req(struct io_kiocb *req)
 		io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE));
 	if (req->fixed_rsrc_refs)
 		percpu_ref_put(req->fixed_rsrc_refs);
+	if (req->work.creds) {
+		put_cred(req->work.creds);
+		req->work.creds = NULL;
+	}
 
 	if (req->flags & REQ_F_INFLIGHT) {
 		struct io_ring_ctx *ctx = req->ctx;
@@ -1690,15 +1709,11 @@ static inline void io_remove_next_linked(struct io_kiocb *req)
 	nxt->link = NULL;
 }
 
-static void io_kill_linked_timeout(struct io_kiocb *req)
+static bool io_kill_linked_timeout(struct io_kiocb *req)
+	__must_hold(&req->ctx->completion_lock)
 {
-	struct io_ring_ctx *ctx = req->ctx;
-	struct io_kiocb *link;
+	struct io_kiocb *link = req->link;
 	bool cancelled = false;
-	unsigned long flags;
-
-	spin_lock_irqsave(&ctx->completion_lock, flags);
-	link = req->link;
 
 	/*
 	 * Can happen if a linked timeout fired and link had been like
@@ -1713,50 +1728,48 @@ static void io_kill_linked_timeout(struct io_kiocb *req)
 		ret = hrtimer_try_to_cancel(&io->timer);
 		if (ret != -1) {
 			io_cqring_fill_event(link, -ECANCELED);
-			io_commit_cqring(ctx);
+			io_put_req_deferred(link, 1);
 			cancelled = true;
 		}
 	}
 	req->flags &= ~REQ_F_LINK_TIMEOUT;
-	spin_unlock_irqrestore(&ctx->completion_lock, flags);
-
-	if (cancelled) {
-		io_cqring_ev_posted(ctx);
-		io_put_req(link);
-	}
+	return cancelled;
 }
 
-
 static void io_fail_links(struct io_kiocb *req)
+	__must_hold(&req->ctx->completion_lock)
 {
-	struct io_kiocb *link, *nxt;
-	struct io_ring_ctx *ctx = req->ctx;
-	unsigned long flags;
+	struct io_kiocb *nxt, *link = req->link;
 
-	spin_lock_irqsave(&ctx->completion_lock, flags);
-	link = req->link;
 	req->link = NULL;
-
 	while (link) {
 		nxt = link->link;
 		link->link = NULL;
 
 		trace_io_uring_fail_link(req, link);
 		io_cqring_fill_event(link, -ECANCELED);
-
 		io_put_req_deferred(link, 2);
 		link = nxt;
 	}
-	io_commit_cqring(ctx);
-	spin_unlock_irqrestore(&ctx->completion_lock, flags);
+}
 
-	io_cqring_ev_posted(ctx);
+static bool io_disarm_next(struct io_kiocb *req)
+	__must_hold(&req->ctx->completion_lock)
+{
+	bool posted = false;
+
+	if (likely(req->flags & REQ_F_LINK_TIMEOUT))
+		posted = io_kill_linked_timeout(req);
+	if (unlikely(req->flags & REQ_F_FAIL_LINK)) {
+		posted |= (req->link != NULL);
+		io_fail_links(req);
+	}
+	return posted;
 }
 
 static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
 {
-	if (req->flags & REQ_F_LINK_TIMEOUT)
-		io_kill_linked_timeout(req);
+	struct io_kiocb *nxt;
 
 	/*
 	 * If LINK is set, we have dependent requests in this chain. If we
@@ -1764,14 +1777,22 @@ static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
 	 * dependencies to the next request. In case of failure, fail the rest
 	 * of the chain.
 	 */
-	if (likely(!(req->flags & REQ_F_FAIL_LINK))) {
-		struct io_kiocb *nxt = req->link;
+	if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL_LINK)) {
+		struct io_ring_ctx *ctx = req->ctx;
+		unsigned long flags;
+		bool posted;
 
-		req->link = NULL;
-		return nxt;
+		spin_lock_irqsave(&ctx->completion_lock, flags);
+		posted = io_disarm_next(req);
+		if (posted)
+			io_commit_cqring(req->ctx);
+		spin_unlock_irqrestore(&ctx->completion_lock, flags);
+		if (posted)
+			io_cqring_ev_posted(ctx);
 	}
-	io_fail_links(req);
-	return NULL;
+	nxt = req->link;
+	req->link = NULL;
+	return nxt;
 }
 
 static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
@@ -5559,22 +5580,30 @@ add:
 	return 0;
 }
 
+struct io_cancel_data {
+	struct io_ring_ctx *ctx;
+	u64 user_data;
+};
+
 static bool io_cancel_cb(struct io_wq_work *work, void *data)
 {
 	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+	struct io_cancel_data *cd = data;
 
-	return req->user_data == (unsigned long) data;
+	return req->ctx == cd->ctx && req->user_data == cd->user_data;
 }
 
-static int io_async_cancel_one(struct io_uring_task *tctx, void *sqe_addr)
+static int io_async_cancel_one(struct io_uring_task *tctx, u64 user_data,
+			       struct io_ring_ctx *ctx)
 {
+	struct io_cancel_data data = { .ctx = ctx, .user_data = user_data, };
 	enum io_wq_cancel cancel_ret;
 	int ret = 0;
 
-	if (!tctx->io_wq)
+	if (!tctx || !tctx->io_wq)
 		return -ENOENT;
 
-	cancel_ret = io_wq_cancel_cb(tctx->io_wq, io_cancel_cb, sqe_addr, false);
+	cancel_ret = io_wq_cancel_cb(tctx->io_wq, io_cancel_cb, &data, false);
 	switch (cancel_ret) {
 	case IO_WQ_CANCEL_OK:
 		ret = 0;
@@ -5597,8 +5626,7 @@ static void io_async_find_and_cancel(struct io_ring_ctx *ctx,
 	unsigned long flags;
 	int ret;
 
-	ret = io_async_cancel_one(req->task->io_uring,
-					(void *) (unsigned long) sqe_addr);
+	ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
 	if (ret != -ENOENT) {
 		spin_lock_irqsave(&ctx->completion_lock, flags);
 		goto done;
@@ -5639,8 +5667,47 @@ static int io_async_cancel_prep(struct io_kiocb *req,
 static int io_async_cancel(struct io_kiocb *req, unsigned int issue_flags)
 {
 	struct io_ring_ctx *ctx = req->ctx;
+	u64 sqe_addr = req->cancel.addr;
+	struct io_tctx_node *node;
+	int ret;
+
+	/* tasks should wait for their io-wq threads, so safe w/o sync */
+	ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
+	spin_lock_irq(&ctx->completion_lock);
+	if (ret != -ENOENT)
+		goto done;
+	ret = io_timeout_cancel(ctx, sqe_addr);
+	if (ret != -ENOENT)
+		goto done;
+	ret = io_poll_cancel(ctx, sqe_addr);
+	if (ret != -ENOENT)
+		goto done;
+	spin_unlock_irq(&ctx->completion_lock);
+
+	/* slow path, try all io-wq's */
+	io_ring_submit_lock(ctx, !(issue_flags & IO_URING_F_NONBLOCK));
+	ret = -ENOENT;
+	list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
+		struct io_uring_task *tctx = node->task->io_uring;
+
+		if (!tctx || !tctx->io_wq)
+			continue;
+		ret = io_async_cancel_one(tctx, req->cancel.addr, ctx);
+		if (ret != -ENOENT)
+			break;
+	}
+	io_ring_submit_unlock(ctx, !(issue_flags & IO_URING_F_NONBLOCK));
 
-	io_async_find_and_cancel(ctx, req, req->cancel.addr, 0);
+	spin_lock_irq(&ctx->completion_lock);
+done:
+	io_cqring_fill_event(req, ret);
+	io_commit_cqring(ctx);
+	spin_unlock_irq(&ctx->completion_lock);
+	io_cqring_ev_posted(ctx);
+
+	if (ret < 0)
+		req_set_fail_links(req);
+	io_put_req(req);
 	return 0;
 }
 
@@ -5916,18 +5983,8 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
 	const struct cred *creds = NULL;
 	int ret;
 
-	if (req->work.personality) {
-		const struct cred *new_creds;
-
-		if (!(issue_flags & IO_URING_F_NONBLOCK))
-			mutex_lock(&ctx->uring_lock);
-		new_creds = idr_find(&ctx->personality_idr, req->work.personality);
-		if (!(issue_flags & IO_URING_F_NONBLOCK))
-			mutex_unlock(&ctx->uring_lock);
-		if (!new_creds)
-			return -EINVAL;
-		creds = override_creds(new_creds);
-	}
+	if (req->work.creds && req->work.creds != current_cred())
+		creds = override_creds(req->work.creds);
 
 	switch (req->opcode) {
 	case IORING_OP_NOP:
@@ -6291,7 +6348,7 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 {
 	struct io_submit_state *state;
 	unsigned int sqe_flags;
-	int ret = 0;
+	int personality, ret = 0;
 
 	req->opcode = READ_ONCE(sqe->opcode);
 	/* same numerical values with corresponding REQ_F_*, safe to copy */
@@ -6306,6 +6363,9 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 	refcount_set(&req->refs, 2);
 	req->task = current;
 	req->result = 0;
+	req->work.list.next = NULL;
+	req->work.creds = NULL;
+	req->work.flags = 0;
 
 	/* enforce forwards compatibility on users */
 	if (unlikely(sqe_flags & ~SQE_VALID_FLAGS)) {
@@ -6323,9 +6383,13 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 	    !io_op_defs[req->opcode].buffer_select)
 		return -EOPNOTSUPP;
 
-	req->work.list.next = NULL;
-	req->work.flags = 0;
-	req->work.personality = READ_ONCE(sqe->personality);
+	personality = READ_ONCE(sqe->personality);
+	if (personality) {
+		req->work.creds = xa_load(&ctx->personalities, personality);
+		if (!req->work.creds)
+			return -EINVAL;
+		get_cred(req->work.creds);
+	}
 	state = &ctx->submit_state;
 
 	/*
@@ -6587,7 +6651,8 @@ static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
 		if (!list_empty(&ctx->iopoll_list))
 			io_do_iopoll(ctx, &nr_events, 0);
 
-		if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)))
+		if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)) &&
+		    !(ctx->flags & IORING_SETUP_R_DISABLED))
 			ret = io_submit_sqes(ctx, to_submit);
 		mutex_unlock(&ctx->uring_lock);
 	}
@@ -6611,58 +6676,6 @@ static void io_sqd_update_thread_idle(struct io_sq_data *sqd)
 	sqd->sq_thread_idle = sq_thread_idle;
 }
 
-static void io_sqd_init_new(struct io_sq_data *sqd)
-{
-	struct io_ring_ctx *ctx;
-
-	while (!list_empty(&sqd->ctx_new_list)) {
-		ctx = list_first_entry(&sqd->ctx_new_list, struct io_ring_ctx, sqd_list);
-		list_move_tail(&ctx->sqd_list, &sqd->ctx_list);
-		complete(&ctx->sq_thread_comp);
-	}
-
-	io_sqd_update_thread_idle(sqd);
-}
-
-static bool io_sq_thread_should_stop(struct io_sq_data *sqd)
-{
-	return test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-}
-
-static bool io_sq_thread_should_park(struct io_sq_data *sqd)
-{
-	return test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-}
-
-static void io_sq_thread_parkme(struct io_sq_data *sqd)
-{
-	for (;;) {
-		/*
-		 * TASK_PARKED is a special state; we must serialize against
-		 * possible pending wakeups to avoid store-store collisions on
-		 * task->state.
-		 *
-		 * Such a collision might possibly result in the task state
-		 * changin from TASK_PARKED and us failing the
-		 * wait_task_inactive() in kthread_park().
-		 */
-		set_special_state(TASK_PARKED);
-		if (!test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state))
-			break;
-
-		/*
-		 * Thread is going to call schedule(), do not preempt it,
-		 * or the caller of kthread_park() may spend more time in
-		 * wait_task_inactive().
-		 */
-		preempt_disable();
-		complete(&sqd->parked);
-		schedule_preempt_disabled();
-		preempt_enable();
-	}
-	__set_current_state(TASK_RUNNING);
-}
-
 static int io_sq_thread(void *data)
 {
 	struct io_sq_data *sqd = data;
@@ -6681,31 +6694,32 @@ static int io_sq_thread(void *data)
 		set_cpus_allowed_ptr(current, cpu_online_mask);
 	current->flags |= PF_NO_SETAFFINITY;
 
-	wait_for_completion(&sqd->startup);
+	down_read(&sqd->rw_lock);
 
-	while (!io_sq_thread_should_stop(sqd)) {
+	while (!test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state)) {
 		int ret;
 		bool cap_entries, sqt_spin, needs_sched;
 
-		/*
-		 * Any changes to the sqd lists are synchronized through the
-		 * thread parking. This synchronizes the thread vs users,
-		 * the users are synchronized on the sqd->ctx_lock.
-		 */
-		if (io_sq_thread_should_park(sqd)) {
-			io_sq_thread_parkme(sqd);
-			continue;
-		}
-		if (unlikely(!list_empty(&sqd->ctx_new_list))) {
-			io_sqd_init_new(sqd);
+		if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state)) {
+			up_read(&sqd->rw_lock);
+			cond_resched();
+			down_read(&sqd->rw_lock);
+			io_run_task_work();
 			timeout = jiffies + sqd->sq_thread_idle;
+			continue;
 		}
 		if (fatal_signal_pending(current))
 			break;
 		sqt_spin = false;
 		cap_entries = !list_is_singular(&sqd->ctx_list);
 		list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+			const struct cred *creds = NULL;
+
+			if (ctx->sq_creds != current_cred())
+				creds = override_creds(ctx->sq_creds);
 			ret = __io_sq_thread(ctx, cap_entries);
+			if (creds)
+				revert_creds(creds);
 			if (!sqt_spin && (ret > 0 || !list_empty(&ctx->iopoll_list)))
 				sqt_spin = true;
 		}
@@ -6732,12 +6746,13 @@ static int io_sq_thread(void *data)
 			}
 		}
 
-		if (needs_sched && !io_sq_thread_should_park(sqd)) {
+		if (needs_sched && !test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state)) {
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 				io_ring_set_wakeup_flag(ctx);
 
+			up_read(&sqd->rw_lock);
 			schedule();
-			try_to_freeze();
+			down_read(&sqd->rw_lock);
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 				io_ring_clear_wakeup_flag(ctx);
 		}
@@ -6745,32 +6760,23 @@ static int io_sq_thread(void *data)
 		finish_wait(&sqd->wait, &wait);
 		timeout = jiffies + sqd->sq_thread_idle;
 	}
-
-	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-		io_uring_cancel_sqpoll(ctx);
-
-	io_run_task_work();
-
+	up_read(&sqd->rw_lock);
+	down_write(&sqd->rw_lock);
 	/*
-	 * Ensure that we park properly if racing with someone trying to park
-	 * while we're exiting. If we fail to grab the lock, check park and
-	 * park if necessary. The ordering with the park bit and the lock
-	 * ensures that we catch this reliably.
+	 * someone may have parked and added a cancellation task_work, run
+	 * it first because we don't want it in io_uring_cancel_sqpoll()
 	 */
-	if (!mutex_trylock(&sqd->lock)) {
-		if (io_sq_thread_should_park(sqd))
-			io_sq_thread_parkme(sqd);
-		mutex_lock(&sqd->lock);
-	}
+	io_run_task_work();
 
+	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+		io_uring_cancel_sqpoll(ctx);
 	sqd->thread = NULL;
-	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
-		ctx->sqo_exec = 1;
+	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 		io_ring_set_wakeup_flag(ctx);
-	}
+	up_write(&sqd->rw_lock);
 
+	io_run_task_work();
 	complete(&sqd->exited);
-	mutex_unlock(&sqd->lock);
 	do_exit(0);
 }
 
@@ -7069,44 +7075,37 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 }
 
 static void io_sq_thread_unpark(struct io_sq_data *sqd)
-	__releases(&sqd->lock)
+	__releases(&sqd->rw_lock)
 {
-	if (sqd->thread == current)
-		return;
+	WARN_ON_ONCE(sqd->thread == current);
+
 	clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	if (sqd->thread)
-		wake_up_state(sqd->thread, TASK_PARKED);
-	mutex_unlock(&sqd->lock);
+	up_write(&sqd->rw_lock);
 }
 
 static void io_sq_thread_park(struct io_sq_data *sqd)
-	__acquires(&sqd->lock)
+	__acquires(&sqd->rw_lock)
 {
-	if (sqd->thread == current)
-		return;
+	WARN_ON_ONCE(sqd->thread == current);
+
 	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	mutex_lock(&sqd->lock);
-	if (sqd->thread) {
+	down_write(&sqd->rw_lock);
+	/* set again for consistency, in case concurrent parks are happening */
+	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+	if (sqd->thread)
 		wake_up_process(sqd->thread);
-		wait_for_completion(&sqd->parked);
-	}
 }
 
 static void io_sq_thread_stop(struct io_sq_data *sqd)
 {
-	if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
-		return;
-	mutex_lock(&sqd->lock);
-	if (sqd->thread) {
-		set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-		WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state));
+	WARN_ON_ONCE(sqd->thread == current);
+
+	down_write(&sqd->rw_lock);
+	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+	if (sqd->thread)
 		wake_up_process(sqd->thread);
-		mutex_unlock(&sqd->lock);
-		wait_for_completion(&sqd->exited);
-		WARN_ON_ONCE(sqd->thread);
-	} else {
-		mutex_unlock(&sqd->lock);
-	}
+	up_write(&sqd->rw_lock);
+	wait_for_completion(&sqd->exited);
 }
 
 static void io_put_sq_data(struct io_sq_data *sqd)
@@ -7122,22 +7121,15 @@ static void io_sq_thread_finish(struct io_ring_ctx *ctx)
 	struct io_sq_data *sqd = ctx->sq_data;
 
 	if (sqd) {
-		complete(&sqd->startup);
-		if (sqd->thread) {
-			wait_for_completion(&ctx->sq_thread_comp);
-			io_sq_thread_park(sqd);
-		}
-
-		mutex_lock(&sqd->ctx_lock);
-		list_del(&ctx->sqd_list);
+		io_sq_thread_park(sqd);
+		list_del_init(&ctx->sqd_list);
 		io_sqd_update_thread_idle(sqd);
-		mutex_unlock(&sqd->ctx_lock);
-
-		if (sqd->thread)
-			io_sq_thread_unpark(sqd);
+		io_sq_thread_unpark(sqd);
 
 		io_put_sq_data(sqd);
 		ctx->sq_data = NULL;
+		if (ctx->sq_creds)
+			put_cred(ctx->sq_creds);
 	}
 }
 
@@ -7161,18 +7153,32 @@ static struct io_sq_data *io_attach_sq_data(struct io_uring_params *p)
 		fdput(f);
 		return ERR_PTR(-EINVAL);
 	}
+	if (sqd->task_tgid != current->tgid) {
+		fdput(f);
+		return ERR_PTR(-EPERM);
+	}
 
 	refcount_inc(&sqd->refs);
 	fdput(f);
 	return sqd;
 }
 
-static struct io_sq_data *io_get_sq_data(struct io_uring_params *p)
+static struct io_sq_data *io_get_sq_data(struct io_uring_params *p,
+					 bool *attached)
 {
 	struct io_sq_data *sqd;
 
-	if (p->flags & IORING_SETUP_ATTACH_WQ)
-		return io_attach_sq_data(p);
+	*attached = false;
+	if (p->flags & IORING_SETUP_ATTACH_WQ) {
+		sqd = io_attach_sq_data(p);
+		if (!IS_ERR(sqd)) {
+			*attached = true;
+			return sqd;
+		}
+		/* fall through for EPERM case, setup new sqd/task */
+		if (PTR_ERR(sqd) != -EPERM)
+			return sqd;
+	}
 
 	sqd = kzalloc(sizeof(*sqd), GFP_KERNEL);
 	if (!sqd)
@@ -7180,12 +7186,8 @@ static struct io_sq_data *io_get_sq_data(struct io_uring_params *p)
 
 	refcount_set(&sqd->refs, 1);
 	INIT_LIST_HEAD(&sqd->ctx_list);
-	INIT_LIST_HEAD(&sqd->ctx_new_list);
-	mutex_init(&sqd->ctx_lock);
-	mutex_init(&sqd->lock);
+	init_rwsem(&sqd->rw_lock);
 	init_waitqueue_head(&sqd->wait);
-	init_completion(&sqd->startup);
-	init_completion(&sqd->parked);
 	init_completion(&sqd->exited);
 	return sqd;
 }
@@ -7802,7 +7804,6 @@ static int io_uring_alloc_task_context(struct task_struct *task,
 	init_waitqueue_head(&tctx->wait);
 	tctx->last = NULL;
 	atomic_set(&tctx->in_idle, 0);
-	tctx->sqpoll = false;
 	task->io_uring = tctx;
 	spin_lock_init(&tctx->task_lock);
 	INIT_WQ_LIST(&tctx->task_list);
@@ -7823,26 +7824,6 @@ void __io_uring_free(struct task_struct *tsk)
 	tsk->io_uring = NULL;
 }
 
-static int io_sq_thread_fork(struct io_sq_data *sqd, struct io_ring_ctx *ctx)
-{
-	struct task_struct *tsk;
-	int ret;
-
-	clear_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-	reinit_completion(&sqd->parked);
-	ctx->sqo_exec = 0;
-	sqd->task_pid = current->pid;
-	tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
-	if (IS_ERR(tsk))
-		return PTR_ERR(tsk);
-	ret = io_uring_alloc_task_context(tsk, ctx);
-	if (ret)
-		set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-	sqd->thread = tsk;
-	wake_up_new_task(tsk);
-	return ret;
-}
-
 static int io_sq_offload_create(struct io_ring_ctx *ctx,
 				struct io_uring_params *p)
 {
@@ -7865,39 +7846,51 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 	if (ctx->flags & IORING_SETUP_SQPOLL) {
 		struct task_struct *tsk;
 		struct io_sq_data *sqd;
+		bool attached;
 
 		ret = -EPERM;
 		if (!capable(CAP_SYS_ADMIN) && !capable(CAP_SYS_NICE))
 			goto err;
 
-		sqd = io_get_sq_data(p);
+		sqd = io_get_sq_data(p, &attached);
 		if (IS_ERR(sqd)) {
 			ret = PTR_ERR(sqd);
 			goto err;
 		}
 
+		ctx->sq_creds = get_current_cred();
 		ctx->sq_data = sqd;
-		io_sq_thread_park(sqd);
-		mutex_lock(&sqd->ctx_lock);
-		list_add(&ctx->sqd_list, &sqd->ctx_new_list);
-		mutex_unlock(&sqd->ctx_lock);
-		io_sq_thread_unpark(sqd);
-
 		ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
 		if (!ctx->sq_thread_idle)
 			ctx->sq_thread_idle = HZ;
 
-		if (sqd->thread)
+		ret = 0;
+		io_sq_thread_park(sqd);
+		/* don't attach to a dying SQPOLL thread, would be racy */
+		if (attached && !sqd->thread) {
+			ret = -ENXIO;
+		} else {
+			list_add(&ctx->sqd_list, &sqd->ctx_list);
+			io_sqd_update_thread_idle(sqd);
+		}
+		io_sq_thread_unpark(sqd);
+
+		if (ret < 0) {
+			io_put_sq_data(sqd);
+			ctx->sq_data = NULL;
+			return ret;
+		} else if (attached) {
 			return 0;
+		}
 
 		if (p->flags & IORING_SETUP_SQ_AFF) {
 			int cpu = p->sq_thread_cpu;
 
 			ret = -EINVAL;
 			if (cpu >= nr_cpu_ids)
-				goto err;
+				goto err_sqpoll;
 			if (!cpu_online(cpu))
-				goto err;
+				goto err_sqpoll;
 
 			sqd->sq_cpu = cpu;
 		} else {
@@ -7905,15 +7898,15 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 		}
 
 		sqd->task_pid = current->pid;
+		sqd->task_tgid = current->tgid;
 		tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
 		if (IS_ERR(tsk)) {
 			ret = PTR_ERR(tsk);
-			goto err;
+			goto err_sqpoll;
 		}
-		ret = io_uring_alloc_task_context(tsk, ctx);
-		if (ret)
-			set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+
 		sqd->thread = tsk;
+		ret = io_uring_alloc_task_context(tsk, ctx);
 		wake_up_new_task(tsk);
 		if (ret)
 			goto err;
@@ -7927,15 +7920,9 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 err:
 	io_sq_thread_finish(ctx);
 	return ret;
-}
-
-static void io_sq_offload_start(struct io_ring_ctx *ctx)
-{
-	struct io_sq_data *sqd = ctx->sq_data;
-
-	ctx->flags &= ~IORING_SETUP_R_DISABLED;
-	if (ctx->flags & IORING_SETUP_SQPOLL)
-		complete(&sqd->startup);
+err_sqpoll:
+	complete(&ctx->sq_data->exited);
+	goto err;
 }
 
 static inline void __io_unaccount_mem(struct user_struct *user,
@@ -8418,7 +8405,6 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 	mutex_unlock(&ctx->uring_lock);
 	io_eventfd_unregister(ctx);
 	io_destroy_buffers(ctx);
-	idr_destroy(&ctx->personality_idr);
 
 #if defined(CONFIG_UNIX)
 	if (ctx->ring_sock) {
@@ -8483,7 +8469,7 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 {
 	const struct cred *creds;
 
-	creds = idr_remove(&ctx->personality_idr, id);
+	creds = xa_erase(&ctx->personalities, id);
 	if (creds) {
 		put_cred(creds);
 		return 0;
@@ -8492,14 +8478,6 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 	return -EINVAL;
 }
 
-static int io_remove_personalities(int id, void *p, void *data)
-{
-	struct io_ring_ctx *ctx = data;
-
-	io_unregister_personality(ctx, id);
-	return 0;
-}
-
 static bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
 {
 	struct callback_head *work, *next;
@@ -8522,10 +8500,34 @@ static bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
 	return executed;
 }
 
+struct io_tctx_exit {
+	struct callback_head		task_work;
+	struct completion		completion;
+	struct io_ring_ctx		*ctx;
+};
+
+static void io_tctx_exit_cb(struct callback_head *cb)
+{
+	struct io_uring_task *tctx = current->io_uring;
+	struct io_tctx_exit *work;
+
+	work = container_of(cb, struct io_tctx_exit, task_work);
+	/*
+	 * When @in_idle, we're in cancellation and it's racy to remove the
+	 * node. It'll be removed by the end of cancellation, just ignore it.
+	 */
+	if (!atomic_read(&tctx->in_idle))
+		io_uring_del_task_file((unsigned long)work->ctx);
+	complete(&work->completion);
+}
+
 static void io_ring_exit_work(struct work_struct *work)
 {
-	struct io_ring_ctx *ctx = container_of(work, struct io_ring_ctx,
-					       exit_work);
+	struct io_ring_ctx *ctx = container_of(work, struct io_ring_ctx, exit_work);
+	unsigned long timeout = jiffies + HZ * 60 * 5;
+	struct io_tctx_exit exit;
+	struct io_tctx_node *node;
+	int ret;
 
 	/*
 	 * If we're doing polled IO and end up having requests being
@@ -8535,19 +8537,47 @@ static void io_ring_exit_work(struct work_struct *work)
 	 */
 	do {
 		io_uring_try_cancel_requests(ctx, NULL, NULL);
+
+		WARN_ON_ONCE(time_after(jiffies, timeout));
 	} while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20));
+
+	mutex_lock(&ctx->uring_lock);
+	while (!list_empty(&ctx->tctx_list)) {
+		WARN_ON_ONCE(time_after(jiffies, timeout));
+
+		node = list_first_entry(&ctx->tctx_list, struct io_tctx_node,
+					ctx_node);
+		exit.ctx = ctx;
+		init_completion(&exit.completion);
+		init_task_work(&exit.task_work, io_tctx_exit_cb);
+		ret = task_work_add(node->task, &exit.task_work, TWA_SIGNAL);
+		if (WARN_ON_ONCE(ret))
+			continue;
+		wake_up_process(node->task);
+
+		mutex_unlock(&ctx->uring_lock);
+		wait_for_completion(&exit.completion);
+		cond_resched();
+		mutex_lock(&ctx->uring_lock);
+	}
+	mutex_unlock(&ctx->uring_lock);
+
 	io_ring_ctx_free(ctx);
 }
 
 static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
 {
+	unsigned long index;
+	struct creds *creds;
+
 	mutex_lock(&ctx->uring_lock);
 	percpu_ref_kill(&ctx->refs);
 	/* if force is set, the ring is going away. always drop after that */
 	ctx->cq_overflow_flushed = 1;
 	if (ctx->rings)
 		__io_cqring_overflow_flush(ctx, true, NULL, NULL);
-	idr_for_each(&ctx->personality_idr, io_remove_personalities, ctx);
+	xa_for_each(&ctx->personalities, index, creds)
+		io_unregister_personality(ctx, index);
 	mutex_unlock(&ctx->uring_lock);
 
 	io_kill_timeouts(ctx, NULL, NULL);
@@ -8600,11 +8630,11 @@ static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
 	return ret;
 }
 
-static void io_cancel_defer_files(struct io_ring_ctx *ctx,
+static bool io_cancel_defer_files(struct io_ring_ctx *ctx,
 				  struct task_struct *task,
 				  struct files_struct *files)
 {
-	struct io_defer_entry *de = NULL;
+	struct io_defer_entry *de;
 	LIST_HEAD(list);
 
 	spin_lock_irq(&ctx->completion_lock);
@@ -8615,6 +8645,8 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
 		}
 	}
 	spin_unlock_irq(&ctx->completion_lock);
+	if (list_empty(&list))
+		return false;
 
 	while (!list_empty(&list)) {
 		de = list_first_entry(&list, struct io_defer_entry, list);
@@ -8624,6 +8656,38 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
 		io_req_complete(de->req, -ECANCELED);
 		kfree(de);
 	}
+	return true;
+}
+
+static bool io_cancel_ctx_cb(struct io_wq_work *work, void *data)
+{
+	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+
+	return req->ctx == data;
+}
+
+static bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx)
+{
+	struct io_tctx_node *node;
+	enum io_wq_cancel cret;
+	bool ret = false;
+
+	mutex_lock(&ctx->uring_lock);
+	list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
+		struct io_uring_task *tctx = node->task->io_uring;
+
+		/*
+		 * io_wq will stay alive while we hold uring_lock, because it's
+		 * killed after ctx nodes, which requires to take the lock.
+		 */
+		if (!tctx || !tctx->io_wq)
+			continue;
+		cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_ctx_cb, ctx, true);
+		ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
+	}
+	mutex_unlock(&ctx->uring_lock);
+
+	return ret;
 }
 
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
@@ -8631,27 +8695,34 @@ static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 					 struct files_struct *files)
 {
 	struct io_task_cancel cancel = { .task = task, .files = files, };
-	struct task_struct *tctx_task = task ?: current;
-	struct io_uring_task *tctx = tctx_task->io_uring;
+	struct io_uring_task *tctx = task ? task->io_uring : NULL;
 
 	while (1) {
 		enum io_wq_cancel cret;
 		bool ret = false;
 
-		if (tctx && tctx->io_wq) {
+		if (!task) {
+			ret |= io_uring_try_cancel_iowq(ctx);
+		} else if (tctx && tctx->io_wq) {
+			/*
+			 * Cancels requests of all rings, not only @ctx, but
+			 * it's fine as the task is in exit/exec.
+			 */
 			cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_task_cb,
 					       &cancel, true);
 			ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
 		}
 
 		/* SQPOLL thread does its own polling */
-		if (!(ctx->flags & IORING_SETUP_SQPOLL) && !files) {
+		if ((!(ctx->flags & IORING_SETUP_SQPOLL) && !files) ||
+		    (ctx->sq_data && ctx->sq_data->thread == current)) {
 			while (!list_empty_careful(&ctx->iopoll_list)) {
 				io_iopoll_try_reap_events(ctx);
 				ret = true;
 			}
 		}
 
+		ret |= io_cancel_defer_files(ctx, task, files);
 		ret |= io_poll_remove_all(ctx, task, files);
 		ret |= io_kill_timeouts(ctx, task, files);
 		ret |= io_run_task_work();
@@ -8691,58 +8762,21 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
 
 		io_uring_try_cancel_requests(ctx, task, files);
 
-		if (ctx->sq_data)
-			io_sq_thread_unpark(ctx->sq_data);
 		prepare_to_wait(&task->io_uring->wait, &wait,
 				TASK_UNINTERRUPTIBLE);
 		if (inflight == io_uring_count_inflight(ctx, task, files))
 			schedule();
 		finish_wait(&task->io_uring->wait, &wait);
-		if (ctx->sq_data)
-			io_sq_thread_park(ctx->sq_data);
-	}
-}
-
-/*
- * We need to iteratively cancel requests, in case a request has dependent
- * hard links. These persist even for failure of cancelations, hence keep
- * looping until none are found.
- */
-static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
-					  struct files_struct *files)
-{
-	struct task_struct *task = current;
-
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
-		/* never started, nothing to cancel */
-		if (ctx->flags & IORING_SETUP_R_DISABLED) {
-			io_sq_offload_start(ctx);
-			return;
-		}
-		io_sq_thread_park(ctx->sq_data);
-		task = ctx->sq_data->thread;
-		if (task)
-			atomic_inc(&task->io_uring->in_idle);
 	}
-
-	io_cancel_defer_files(ctx, task, files);
-
-	io_uring_cancel_files(ctx, task, files);
-	if (!files)
-		io_uring_try_cancel_requests(ctx, task, NULL);
-
-	if (task)
-		atomic_dec(&task->io_uring->in_idle);
-	if (ctx->sq_data)
-		io_sq_thread_unpark(ctx->sq_data);
 }
 
 /*
  * Note that this task has used io_uring. We use it for cancelation purposes.
  */
-static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
+static int io_uring_add_task_file(struct io_ring_ctx *ctx)
 {
 	struct io_uring_task *tctx = current->io_uring;
+	struct io_tctx_node *node;
 	int ret;
 
 	if (unlikely(!tctx)) {
@@ -8751,102 +8785,151 @@ static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
 			return ret;
 		tctx = current->io_uring;
 	}
-	if (tctx->last != file) {
-		void *old = xa_load(&tctx->xa, (unsigned long)file);
+	if (tctx->last != ctx) {
+		void *old = xa_load(&tctx->xa, (unsigned long)ctx);
 
 		if (!old) {
-			get_file(file);
-			ret = xa_err(xa_store(&tctx->xa, (unsigned long)file,
-						file, GFP_KERNEL));
+			node = kmalloc(sizeof(*node), GFP_KERNEL);
+			if (!node)
+				return -ENOMEM;
+			node->ctx = ctx;
+			node->task = current;
+
+			ret = xa_err(xa_store(&tctx->xa, (unsigned long)ctx,
+						node, GFP_KERNEL));
 			if (ret) {
-				fput(file);
+				kfree(node);
 				return ret;
 			}
+
+			mutex_lock(&ctx->uring_lock);
+			list_add(&node->ctx_node, &ctx->tctx_list);
+			mutex_unlock(&ctx->uring_lock);
 		}
-		tctx->last = file;
+		tctx->last = ctx;
 	}
-
-	/*
-	 * This is race safe in that the task itself is doing this, hence it
-	 * cannot be going through the exit/cancel paths at the same time.
-	 * This cannot be modified while exit/cancel is running.
-	 */
-	if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
-		tctx->sqpoll = true;
-
 	return 0;
 }
 
 /*
  * Remove this io_uring_file -> task mapping.
  */
-static void io_uring_del_task_file(struct file *file)
+static void io_uring_del_task_file(unsigned long index)
 {
 	struct io_uring_task *tctx = current->io_uring;
+	struct io_tctx_node *node;
 
-	if (tctx->last == file)
+	if (!tctx)
+		return;
+	node = xa_erase(&tctx->xa, index);
+	if (!node)
+		return;
+
+	WARN_ON_ONCE(current != node->task);
+	WARN_ON_ONCE(list_empty(&node->ctx_node));
+
+	mutex_lock(&node->ctx->uring_lock);
+	list_del(&node->ctx_node);
+	mutex_unlock(&node->ctx->uring_lock);
+
+	if (tctx->last == node->ctx)
 		tctx->last = NULL;
-	file = xa_erase(&tctx->xa, (unsigned long)file);
-	if (file)
-		fput(file);
+	kfree(node);
 }
 
 static void io_uring_clean_tctx(struct io_uring_task *tctx)
 {
-	struct file *file;
+	struct io_tctx_node *node;
 	unsigned long index;
 
-	xa_for_each(&tctx->xa, index, file)
-		io_uring_del_task_file(file);
+	xa_for_each(&tctx->xa, index, node)
+		io_uring_del_task_file(index);
 	if (tctx->io_wq) {
 		io_wq_put_and_exit(tctx->io_wq);
 		tctx->io_wq = NULL;
 	}
 }
 
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+	return percpu_counter_sum(&tctx->inflight);
+}
+
+static void io_sqpoll_cancel_cb(struct callback_head *cb)
+{
+	struct io_tctx_exit *work = container_of(cb, struct io_tctx_exit, task_work);
+	struct io_ring_ctx *ctx = work->ctx;
+	struct io_sq_data *sqd = ctx->sq_data;
+
+	if (sqd->thread)
+		io_uring_cancel_sqpoll(ctx);
+	complete(&work->completion);
+}
+
+static void io_sqpoll_cancel_sync(struct io_ring_ctx *ctx)
+{
+	struct io_sq_data *sqd = ctx->sq_data;
+	struct io_tctx_exit work = { .ctx = ctx, };
+	struct task_struct *task;
+
+	io_sq_thread_park(sqd);
+	list_del_init(&ctx->sqd_list);
+	io_sqd_update_thread_idle(sqd);
+	task = sqd->thread;
+	if (task) {
+		init_completion(&work.completion);
+		init_task_work(&work.task_work, io_sqpoll_cancel_cb);
+		WARN_ON_ONCE(task_work_add(task, &work.task_work, TWA_SIGNAL));
+		wake_up_process(task);
+	}
+	io_sq_thread_unpark(sqd);
+
+	if (task)
+		wait_for_completion(&work.completion);
+}
+
 void __io_uring_files_cancel(struct files_struct *files)
 {
 	struct io_uring_task *tctx = current->io_uring;
-	struct file *file;
+	struct io_tctx_node *node;
 	unsigned long index;
 
 	/* make sure overflow events are dropped */
 	atomic_inc(&tctx->in_idle);
-	xa_for_each(&tctx->xa, index, file)
-		io_uring_cancel_task_requests(file->private_data, files);
+	xa_for_each(&tctx->xa, index, node) {
+		struct io_ring_ctx *ctx = node->ctx;
+
+		if (ctx->sq_data) {
+			io_sqpoll_cancel_sync(ctx);
+			continue;
+		}
+		io_uring_cancel_files(ctx, current, files);
+		if (!files)
+			io_uring_try_cancel_requests(ctx, current, NULL);
+	}
 	atomic_dec(&tctx->in_idle);
 
 	if (files)
 		io_uring_clean_tctx(tctx);
 }
 
-static s64 tctx_inflight(struct io_uring_task *tctx)
-{
-	return percpu_counter_sum(&tctx->inflight);
-}
-
+/* should only be called by SQPOLL task */
 static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 {
 	struct io_sq_data *sqd = ctx->sq_data;
-	struct io_uring_task *tctx;
+	struct io_uring_task *tctx = current->io_uring;
 	s64 inflight;
 	DEFINE_WAIT(wait);
 
-	if (!sqd)
-		return;
-	io_sq_thread_park(sqd);
-	if (!sqd->thread || !sqd->thread->io_uring) {
-		io_sq_thread_unpark(sqd);
-		return;
-	}
-	tctx = ctx->sq_data->thread->io_uring;
+	WARN_ON_ONCE(!sqd || ctx->sq_data->thread != current);
+
 	atomic_inc(&tctx->in_idle);
 	do {
 		/* read completions before cancelations */
 		inflight = tctx_inflight(tctx);
 		if (!inflight)
 			break;
-		io_uring_cancel_task_requests(ctx, NULL);
+		io_uring_try_cancel_requests(ctx, current, NULL);
 
 		prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
 		/*
@@ -8859,7 +8942,6 @@ static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 		finish_wait(&tctx->wait, &wait);
 	} while (1);
 	atomic_dec(&tctx->in_idle);
-	io_sq_thread_unpark(sqd);
 }
 
 /*
@@ -8874,15 +8956,6 @@ void __io_uring_task_cancel(void)
 
 	/* make sure overflow events are dropped */
 	atomic_inc(&tctx->in_idle);
-
-	if (tctx->sqpoll) {
-		struct file *file;
-		unsigned long index;
-
-		xa_for_each(&tctx->xa, index, file)
-			io_uring_cancel_sqpoll(file->private_data);
-	}
-
 	do {
 		/* read completions before cancelations */
 		inflight = tctx_inflight(tctx);
@@ -8981,7 +9054,6 @@ static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
 
 static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 {
-	int ret = 0;
 	DEFINE_WAIT(wait);
 
 	do {
@@ -8995,7 +9067,7 @@ static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 	} while (!signal_pending(current));
 
 	finish_wait(&ctx->sqo_sq_wait, &wait);
-	return ret;
+	return 0;
 }
 
 static int io_get_ext_arg(unsigned flags, const void __user *argp, size_t *argsz,
@@ -9069,13 +9141,10 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 	if (ctx->flags & IORING_SETUP_SQPOLL) {
 		io_cqring_overflow_flush(ctx, false, NULL, NULL);
 
-		if (unlikely(ctx->sqo_exec)) {
-			ret = io_sq_thread_fork(ctx->sq_data, ctx);
-			if (ret)
-				goto out;
-			ctx->sqo_exec = 0;
-		}
 		ret = -EOWNERDEAD;
+		if (unlikely(ctx->sq_data->thread == NULL)) {
+			goto out;
+		}
 		if (flags & IORING_ENTER_SQ_WAKEUP)
 			wake_up(&ctx->sq_data->wait);
 		if (flags & IORING_ENTER_SQ_WAIT) {
@@ -9085,7 +9154,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 		}
 		submitted = to_submit;
 	} else if (to_submit) {
-		ret = io_uring_add_task_file(ctx, f.file);
+		ret = io_uring_add_task_file(ctx);
 		if (unlikely(ret))
 			goto out;
 		mutex_lock(&ctx->uring_lock);
@@ -9127,10 +9196,9 @@ out_fput:
 }
 
 #ifdef CONFIG_PROC_FS
-static int io_uring_show_cred(int id, void *p, void *data)
+static int io_uring_show_cred(struct seq_file *m, unsigned int id,
+		const struct cred *cred)
 {
-	const struct cred *cred = p;
-	struct seq_file *m = data;
 	struct user_namespace *uns = seq_user_ns(m);
 	struct group_info *gi;
 	kernel_cap_t cap;
@@ -9198,9 +9266,13 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
 		seq_printf(m, "%5u: 0x%llx/%u\n", i, buf->ubuf,
 						(unsigned int) buf->len);
 	}
-	if (has_lock && !idr_is_empty(&ctx->personality_idr)) {
+	if (has_lock && !xa_empty(&ctx->personalities)) {
+		unsigned long index;
+		const struct cred *cred;
+
 		seq_printf(m, "Personalities:\n");
-		idr_for_each(&ctx->personality_idr, io_uring_show_cred, m);
+		xa_for_each(&ctx->personalities, index, cred)
+			io_uring_show_cred(m, index, cred);
 	}
 	seq_printf(m, "PollList:\n");
 	spin_lock_irq(&ctx->completion_lock);
@@ -9294,7 +9366,7 @@ static int io_uring_install_fd(struct io_ring_ctx *ctx, struct file *file)
 	if (fd < 0)
 		return fd;
 
-	ret = io_uring_add_task_file(ctx, file);
+	ret = io_uring_add_task_file(ctx);
 	if (ret) {
 		put_unused_fd(fd);
 		return ret;
@@ -9402,9 +9474,6 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
 	if (ret)
 		goto err;
 
-	if (!(p->flags & IORING_SETUP_R_DISABLED))
-		io_sq_offload_start(ctx);
-
 	memset(&p->sq_off, 0, sizeof(p->sq_off));
 	p->sq_off.head = offsetof(struct io_rings, sq.head);
 	p->sq_off.tail = offsetof(struct io_rings, sq.tail);
@@ -9532,14 +9601,16 @@ out:
 static int io_register_personality(struct io_ring_ctx *ctx)
 {
 	const struct cred *creds;
+	u32 id;
 	int ret;
 
 	creds = get_current_cred();
 
-	ret = idr_alloc_cyclic(&ctx->personality_idr, (void *) creds, 1,
-				USHRT_MAX, GFP_KERNEL);
-	if (ret < 0)
-		put_cred(creds);
+	ret = xa_alloc_cyclic(&ctx->personalities, &id, (void *)creds,
+			XA_LIMIT(0, USHRT_MAX), &ctx->pers_next, GFP_KERNEL);
+	if (!ret)
+		return id;
+	put_cred(creds);
 	return ret;
 }
 
@@ -9621,7 +9692,9 @@ static int io_register_enable_rings(struct io_ring_ctx *ctx)
 	if (ctx->restrictions.registered)
 		ctx->restricted = 1;
 
-	io_sq_offload_start(ctx);
+	ctx->flags &= ~IORING_SETUP_R_DISABLED;
+	if (ctx->sq_data && wq_has_sleeper(&ctx->sq_data->wait))
+		wake_up(&ctx->sq_data->wait);
 	return 0;
 }