summary refs log tree commit diff
path: root/fs/io_uring.c
diff options
context:
space:
mode:
Diffstat (limited to 'fs/io_uring.c')
-rw-r--r--fs/io_uring.c360
1 files changed, 248 insertions, 112 deletions
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 7e35283fc0b1..985a9e3f976d 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -262,6 +262,7 @@ struct io_ring_ctx {
 		unsigned int		drain_next: 1;
 		unsigned int		eventfd_async: 1;
 		unsigned int		restricted: 1;
+		unsigned int		sqo_dead: 1;
 
 		/*
 		 * Ring buffer of indices into array of io_uring_sqe, which is
@@ -353,6 +354,7 @@ struct io_ring_ctx {
 		unsigned		cq_entries;
 		unsigned		cq_mask;
 		atomic_t		cq_timeouts;
+		unsigned		cq_last_tm_flush;
 		unsigned long		cq_check_overflow;
 		struct wait_queue_head	cq_wait;
 		struct fasync_struct	*cq_fasync;
@@ -992,6 +994,13 @@ enum io_mem_account {
 	ACCT_PINNED,
 };
 
+static void __io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
+					    struct task_struct *task);
+
+static void destroy_fixed_file_ref_node(struct fixed_file_ref_node *ref_node);
+static struct fixed_file_ref_node *alloc_fixed_file_ref_node(
+			struct io_ring_ctx *ctx);
+
 static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
 			     struct io_comp_state *cs);
 static void io_cqring_fill_event(struct io_kiocb *req, long res);
@@ -1098,6 +1107,9 @@ static void io_sq_thread_drop_mm_files(void)
 
 static int __io_sq_thread_acquire_files(struct io_ring_ctx *ctx)
 {
+	if (current->flags & PF_EXITING)
+		return -EFAULT;
+
 	if (!current->files) {
 		struct files_struct *files;
 		struct nsproxy *nsproxy;
@@ -1125,6 +1137,8 @@ static int __io_sq_thread_acquire_mm(struct io_ring_ctx *ctx)
 {
 	struct mm_struct *mm;
 
+	if (current->flags & PF_EXITING)
+		return -EFAULT;
 	if (current->mm)
 		return 0;
 
@@ -1338,11 +1352,6 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
 
 	/* order cqe stores with ring update */
 	smp_store_release(&rings->cq.tail, ctx->cached_cq_tail);
-
-	if (wq_has_sleeper(&ctx->cq_wait)) {
-		wake_up_interruptible(&ctx->cq_wait);
-		kill_fasync(&ctx->cq_fasync, SIGIO, POLL_IN);
-	}
 }
 
 static void io_put_identity(struct io_uring_task *tctx, struct io_kiocb *req)
@@ -1501,6 +1510,13 @@ static bool io_grab_identity(struct io_kiocb *req)
 		spin_unlock_irq(&ctx->inflight_lock);
 		req->work.flags |= IO_WQ_WORK_FILES;
 	}
+	if (!(req->work.flags & IO_WQ_WORK_MM) &&
+	    (def->work_flags & IO_WQ_WORK_MM)) {
+		if (id->mm != current->mm)
+			return false;
+		mmgrab(id->mm);
+		req->work.flags |= IO_WQ_WORK_MM;
+	}
 
 	return true;
 }
@@ -1509,10 +1525,8 @@ static void io_prep_async_work(struct io_kiocb *req)
 {
 	const struct io_op_def *def = &io_op_defs[req->opcode];
 	struct io_ring_ctx *ctx = req->ctx;
-	struct io_identity *id;
 
 	io_req_init_async(req);
-	id = req->work.identity;
 
 	if (req->flags & REQ_F_FORCE_ASYNC)
 		req->work.flags |= IO_WQ_WORK_CONCURRENT;
@@ -1525,13 +1539,6 @@ static void io_prep_async_work(struct io_kiocb *req)
 			req->work.flags |= IO_WQ_WORK_UNBOUND;
 	}
 
-	/* ->mm can never change on us */
-	if (!(req->work.flags & IO_WQ_WORK_MM) &&
-	    (def->work_flags & IO_WQ_WORK_MM)) {
-		mmgrab(id->mm);
-		req->work.flags |= IO_WQ_WORK_MM;
-	}
-
 	/* if we fail grabbing identity, we must COW, regrab, and retry */
 	if (io_grab_identity(req))
 		return;
@@ -1633,19 +1640,38 @@ static void __io_queue_deferred(struct io_ring_ctx *ctx)
 
 static void io_flush_timeouts(struct io_ring_ctx *ctx)
 {
-	while (!list_empty(&ctx->timeout_list)) {
+	u32 seq;
+
+	if (list_empty(&ctx->timeout_list))
+		return;
+
+	seq = ctx->cached_cq_tail - atomic_read(&ctx->cq_timeouts);
+
+	do {
+		u32 events_needed, events_got;
 		struct io_kiocb *req = list_first_entry(&ctx->timeout_list,
 						struct io_kiocb, timeout.list);
 
 		if (io_is_timeout_noseq(req))
 			break;
-		if (req->timeout.target_seq != ctx->cached_cq_tail
-					- atomic_read(&ctx->cq_timeouts))
+
+		/*
+		 * Since seq can easily wrap around over time, subtract
+		 * the last seq at which timeouts were flushed before comparing.
+		 * Assuming not more than 2^31-1 events have happened since,
+		 * these subtractions won't have wrapped, so we can check if
+		 * target is in [last_seq, current_seq] by comparing the two.
+		 */
+		events_needed = req->timeout.target_seq - ctx->cq_last_tm_flush;
+		events_got = seq - ctx->cq_last_tm_flush;
+		if (events_got < events_needed)
 			break;
 
 		list_del_init(&req->timeout.list);
 		io_kill_timeout(req);
-	}
+	} while (!list_empty(&ctx->timeout_list));
+
+	ctx->cq_last_tm_flush = seq;
 }
 
 static void io_commit_cqring(struct io_ring_ctx *ctx)
@@ -1700,18 +1726,42 @@ static inline unsigned __io_cqring_events(struct io_ring_ctx *ctx)
 
 static void io_cqring_ev_posted(struct io_ring_ctx *ctx)
 {
+	/* see waitqueue_active() comment */
+	smp_mb();
+
 	if (waitqueue_active(&ctx->wait))
 		wake_up(&ctx->wait);
 	if (ctx->sq_data && waitqueue_active(&ctx->sq_data->wait))
 		wake_up(&ctx->sq_data->wait);
 	if (io_should_trigger_evfd(ctx))
 		eventfd_signal(ctx->cq_ev_fd, 1);
+	if (waitqueue_active(&ctx->cq_wait)) {
+		wake_up_interruptible(&ctx->cq_wait);
+		kill_fasync(&ctx->cq_fasync, SIGIO, POLL_IN);
+	}
+}
+
+static void io_cqring_ev_posted_iopoll(struct io_ring_ctx *ctx)
+{
+	/* see waitqueue_active() comment */
+	smp_mb();
+
+	if (ctx->flags & IORING_SETUP_SQPOLL) {
+		if (waitqueue_active(&ctx->wait))
+			wake_up(&ctx->wait);
+	}
+	if (io_should_trigger_evfd(ctx))
+		eventfd_signal(ctx->cq_ev_fd, 1);
+	if (waitqueue_active(&ctx->cq_wait)) {
+		wake_up_interruptible(&ctx->cq_wait);
+		kill_fasync(&ctx->cq_fasync, SIGIO, POLL_IN);
+	}
 }
 
 /* Returns true if there are no backlogged entries after the flush */
-static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
-				     struct task_struct *tsk,
-				     struct files_struct *files)
+static bool __io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
+				       struct task_struct *tsk,
+				       struct files_struct *files)
 {
 	struct io_rings *rings = ctx->rings;
 	struct io_kiocb *req, *tmp;
@@ -1764,6 +1814,20 @@ static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
 	return all_flushed;
 }
 
+static void io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
+				     struct task_struct *tsk,
+				     struct files_struct *files)
+{
+	if (test_bit(0, &ctx->cq_check_overflow)) {
+		/* iopoll syncs against uring_lock, not completion_lock */
+		if (ctx->flags & IORING_SETUP_IOPOLL)
+			mutex_lock(&ctx->uring_lock);
+		__io_cqring_overflow_flush(ctx, force, tsk, files);
+		if (ctx->flags & IORING_SETUP_IOPOLL)
+			mutex_unlock(&ctx->uring_lock);
+	}
+}
+
 static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
 {
 	struct io_ring_ctx *ctx = req->ctx;
@@ -2123,14 +2187,14 @@ static void __io_req_task_submit(struct io_kiocb *req)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
-	if (!__io_sq_thread_acquire_mm(ctx) &&
-	    !__io_sq_thread_acquire_files(ctx)) {
-		mutex_lock(&ctx->uring_lock);
+	mutex_lock(&ctx->uring_lock);
+	if (!ctx->sqo_dead &&
+	    !__io_sq_thread_acquire_mm(ctx) &&
+	    !__io_sq_thread_acquire_files(ctx))
 		__io_queue_sqe(req, NULL);
-		mutex_unlock(&ctx->uring_lock);
-	} else {
+	else
 		__io_req_task_cancel(req, -EFAULT);
-	}
+	mutex_unlock(&ctx->uring_lock);
 }
 
 static void io_req_task_submit(struct callback_head *cb)
@@ -2309,20 +2373,8 @@ static void io_double_put_req(struct io_kiocb *req)
 		io_free_req(req);
 }
 
-static unsigned io_cqring_events(struct io_ring_ctx *ctx, bool noflush)
+static unsigned io_cqring_events(struct io_ring_ctx *ctx)
 {
-	if (test_bit(0, &ctx->cq_check_overflow)) {
-		/*
-		 * noflush == true is from the waitqueue handler, just ensure
-		 * we wake up the task, and the next invocation will flush the
-		 * entries. We cannot safely to it from here.
-		 */
-		if (noflush)
-			return -1U;
-
-		io_cqring_overflow_flush(ctx, false, NULL, NULL);
-	}
-
 	/* See comment at the top of this file */
 	smp_rmb();
 	return __io_cqring_events(ctx);
@@ -2420,8 +2472,7 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
 	}
 
 	io_commit_cqring(ctx);
-	if (ctx->flags & IORING_SETUP_SQPOLL)
-		io_cqring_ev_posted(ctx);
+	io_cqring_ev_posted_iopoll(ctx);
 	io_req_free_batch_finish(ctx, &rb);
 
 	if (!list_empty(&again))
@@ -2547,7 +2598,9 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
 		 * If we do, we can potentially be spinning for commands that
 		 * already triggered a CQE (eg in error).
 		 */
-		if (io_cqring_events(ctx, false))
+		if (test_bit(0, &ctx->cq_check_overflow))
+			__io_cqring_overflow_flush(ctx, false, NULL, NULL);
+		if (io_cqring_events(ctx))
 			break;
 
 		/*
@@ -2664,6 +2717,8 @@ static bool io_rw_reissue(struct io_kiocb *req, long res)
 	if ((res != -EAGAIN && res != -EOPNOTSUPP) || io_wq_current_is_worker())
 		return false;
 
+	lockdep_assert_held(&req->ctx->uring_lock);
+
 	ret = io_sq_thread_acquire_mm_files(req->ctx, req);
 
 	if (io_resubmit_prep(req, ret)) {
@@ -5802,6 +5857,12 @@ static int io_timeout(struct io_kiocb *req)
 	tail = ctx->cached_cq_tail - atomic_read(&ctx->cq_timeouts);
 	req->timeout.target_seq = tail + off;
 
+	/* Update the last seq here in case io_flush_timeouts() hasn't.
+	 * This is safe because ->completion_lock is held, and submissions
+	 * and completions are never mixed in the same ->completion_lock section.
+	 */
+	ctx->cq_last_tm_flush = tail;
+
 	/*
 	 * Insertion sort, ensuring the first entry in the list is always
 	 * the one we need first.
@@ -6822,7 +6883,7 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
 
 	/* if we have a backlog and couldn't flush it all, return BUSY */
 	if (test_bit(0, &ctx->sq_check_overflow)) {
-		if (!io_cqring_overflow_flush(ctx, false, NULL, NULL))
+		if (!__io_cqring_overflow_flush(ctx, false, NULL, NULL))
 			return -EBUSY;
 	}
 
@@ -6924,7 +6985,8 @@ static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
 		if (!list_empty(&ctx->iopoll_list))
 			io_do_iopoll(ctx, &nr_events, 0);
 
-		if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)))
+		if (to_submit && !ctx->sqo_dead &&
+		    likely(!percpu_ref_is_dying(&ctx->refs)))
 			ret = io_submit_sqes(ctx, to_submit);
 		mutex_unlock(&ctx->uring_lock);
 	}
@@ -7025,6 +7087,7 @@ static int io_sq_thread(void *data)
 
 		if (sqt_spin || !time_after(jiffies, timeout)) {
 			io_run_task_work();
+			io_sq_thread_drop_mm_files();
 			cond_resched();
 			if (sqt_spin)
 				timeout = jiffies + sqd->sq_thread_idle;
@@ -7062,6 +7125,7 @@ static int io_sq_thread(void *data)
 	}
 
 	io_run_task_work();
+	io_sq_thread_drop_mm_files();
 
 	if (cur_css)
 		io_sq_thread_unassociate_blkcg();
@@ -7085,7 +7149,7 @@ struct io_wait_queue {
 	unsigned nr_timeouts;
 };
 
-static inline bool io_should_wake(struct io_wait_queue *iowq, bool noflush)
+static inline bool io_should_wake(struct io_wait_queue *iowq)
 {
 	struct io_ring_ctx *ctx = iowq->ctx;
 
@@ -7094,7 +7158,7 @@ static inline bool io_should_wake(struct io_wait_queue *iowq, bool noflush)
 	 * started waiting. For timeouts, we always want to return to userspace,
 	 * regardless of event count.
 	 */
-	return io_cqring_events(ctx, noflush) >= iowq->to_wait ||
+	return io_cqring_events(ctx) >= iowq->to_wait ||
 			atomic_read(&ctx->cq_timeouts) != iowq->nr_timeouts;
 }
 
@@ -7104,11 +7168,13 @@ static int io_wake_function(struct wait_queue_entry *curr, unsigned int mode,
 	struct io_wait_queue *iowq = container_of(curr, struct io_wait_queue,
 							wq);
 
-	/* use noflush == true, as we can't safely rely on locking context */
-	if (!io_should_wake(iowq, true))
-		return -1;
-
-	return autoremove_wake_function(curr, mode, wake_flags, key);
+	/*
+	 * Cannot safely flush overflowed CQEs from here, ensure we wake up
+	 * the task, and the next invocation will do it.
+	 */
+	if (io_should_wake(iowq) || test_bit(0, &iowq->ctx->cq_check_overflow))
+		return autoremove_wake_function(curr, mode, wake_flags, key);
+	return -1;
 }
 
 static int io_run_task_work_sig(void)
@@ -7145,7 +7211,8 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events,
 	int ret = 0;
 
 	do {
-		if (io_cqring_events(ctx, false) >= min_events)
+		io_cqring_overflow_flush(ctx, false, NULL, NULL);
+		if (io_cqring_events(ctx) >= min_events)
 			return 0;
 		if (!io_run_task_work())
 			break;
@@ -7173,6 +7240,7 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events,
 	iowq.nr_timeouts = atomic_read(&ctx->cq_timeouts);
 	trace_io_uring_cqring_wait(ctx, min_events);
 	do {
+		io_cqring_overflow_flush(ctx, false, NULL, NULL);
 		prepare_to_wait_exclusive(&ctx->wait, &iowq.wq,
 						TASK_INTERRUPTIBLE);
 		/* make sure we run task_work before checking for signals */
@@ -7181,8 +7249,10 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events,
 			continue;
 		else if (ret < 0)
 			break;
-		if (io_should_wake(&iowq, false))
+		if (io_should_wake(&iowq))
 			break;
+		if (test_bit(0, &ctx->cq_check_overflow))
+			continue;
 		if (uts) {
 			timeout = schedule_timeout(timeout);
 			if (timeout == 0) {
@@ -7231,14 +7301,28 @@ static void io_file_ref_kill(struct percpu_ref *ref)
 	complete(&data->done);
 }
 
+static void io_sqe_files_set_node(struct fixed_file_data *file_data,
+				  struct fixed_file_ref_node *ref_node)
+{
+	spin_lock_bh(&file_data->lock);
+	file_data->node = ref_node;
+	list_add_tail(&ref_node->node, &file_data->ref_list);
+	spin_unlock_bh(&file_data->lock);
+	percpu_ref_get(&file_data->refs);
+}
+
 static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 {
 	struct fixed_file_data *data = ctx->file_data;
-	struct fixed_file_ref_node *ref_node = NULL;
+	struct fixed_file_ref_node *backup_node, *ref_node = NULL;
 	unsigned nr_tables, i;
+	int ret;
 
 	if (!data)
 		return -ENXIO;
+	backup_node = alloc_fixed_file_ref_node(ctx);
+	if (!backup_node)
+		return -ENOMEM;
 
 	spin_lock_bh(&data->lock);
 	ref_node = data->node;
@@ -7250,7 +7334,18 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 
 	/* wait for all refs nodes to complete */
 	flush_delayed_work(&ctx->file_put_work);
-	wait_for_completion(&data->done);
+	do {
+		ret = wait_for_completion_interruptible(&data->done);
+		if (!ret)
+			break;
+		ret = io_run_task_work_sig();
+		if (ret < 0) {
+			percpu_ref_resurrect(&data->refs);
+			reinit_completion(&data->done);
+			io_sqe_files_set_node(data, backup_node);
+			return ret;
+		}
+	} while (1);
 
 	__io_sqe_files_unregister(ctx);
 	nr_tables = DIV_ROUND_UP(ctx->nr_user_files, IORING_MAX_FILES_TABLE);
@@ -7261,6 +7356,7 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 	kfree(data);
 	ctx->file_data = NULL;
 	ctx->nr_user_files = 0;
+	destroy_fixed_file_ref_node(backup_node);
 	return 0;
 }
 
@@ -7654,12 +7750,12 @@ static struct fixed_file_ref_node *alloc_fixed_file_ref_node(
 
 	ref_node = kzalloc(sizeof(*ref_node), GFP_KERNEL);
 	if (!ref_node)
-		return ERR_PTR(-ENOMEM);
+		return NULL;
 
 	if (percpu_ref_init(&ref_node->refs, io_file_data_ref_zero,
 			    0, GFP_KERNEL)) {
 		kfree(ref_node);
-		return ERR_PTR(-ENOMEM);
+		return NULL;
 	}
 	INIT_LIST_HEAD(&ref_node->node);
 	INIT_LIST_HEAD(&ref_node->file_list);
@@ -7753,16 +7849,12 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
 	}
 
 	ref_node = alloc_fixed_file_ref_node(ctx);
-	if (IS_ERR(ref_node)) {
+	if (!ref_node) {
 		io_sqe_files_unregister(ctx);
-		return PTR_ERR(ref_node);
+		return -ENOMEM;
 	}
 
-	file_data->node = ref_node;
-	spin_lock_bh(&file_data->lock);
-	list_add_tail(&ref_node->node, &file_data->ref_list);
-	spin_unlock_bh(&file_data->lock);
-	percpu_ref_get(&file_data->refs);
+	io_sqe_files_set_node(file_data, ref_node);
 	return ret;
 out_fput:
 	for (i = 0; i < ctx->nr_user_files; i++) {
@@ -7859,8 +7951,8 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
 		return -EINVAL;
 
 	ref_node = alloc_fixed_file_ref_node(ctx);
-	if (IS_ERR(ref_node))
-		return PTR_ERR(ref_node);
+	if (!ref_node)
+		return -ENOMEM;
 
 	done = 0;
 	fds = u64_to_user_ptr(up->fds);
@@ -7918,11 +8010,7 @@ static int __io_sqe_files_update(struct io_ring_ctx *ctx,
 
 	if (needs_switch) {
 		percpu_ref_kill(&data->node->refs);
-		spin_lock_bh(&data->lock);
-		list_add_tail(&ref_node->node, &data->ref_list);
-		data->node = ref_node;
-		spin_unlock_bh(&data->lock);
-		percpu_ref_get(&ctx->file_data->refs);
+		io_sqe_files_set_node(data, ref_node);
 	} else
 		destroy_fixed_file_ref_node(ref_node);
 
@@ -8602,7 +8690,8 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait)
 	smp_rmb();
 	if (!io_sqring_full(ctx))
 		mask |= EPOLLOUT | EPOLLWRNORM;
-	if (io_cqring_events(ctx, false))
+	io_cqring_overflow_flush(ctx, false, NULL, NULL);
+	if (io_cqring_events(ctx))
 		mask |= EPOLLIN | EPOLLRDNORM;
 
 	return mask;
@@ -8641,7 +8730,7 @@ static void io_ring_exit_work(struct work_struct *work)
 	 * as nobody else will be looking for them.
 	 */
 	do {
-		io_iopoll_try_reap_events(ctx);
+		__io_uring_cancel_task_requests(ctx, NULL);
 	} while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20));
 	io_ring_ctx_free(ctx);
 }
@@ -8657,10 +8746,14 @@ static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
 {
 	mutex_lock(&ctx->uring_lock);
 	percpu_ref_kill(&ctx->refs);
+
+	if (WARN_ON_ONCE((ctx->flags & IORING_SETUP_SQPOLL) && !ctx->sqo_dead))
+		ctx->sqo_dead = 1;
+
 	/* if force is set, the ring is going away. always drop after that */
 	ctx->cq_overflow_flushed = 1;
 	if (ctx->rings)
-		io_cqring_overflow_flush(ctx, true, NULL, NULL);
+		__io_cqring_overflow_flush(ctx, true, NULL, NULL);
 	mutex_unlock(&ctx->uring_lock);
 
 	io_kill_timeouts(ctx, NULL, NULL);
@@ -8796,9 +8889,11 @@ static void __io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 		enum io_wq_cancel cret;
 		bool ret = false;
 
-		cret = io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, &cancel, true);
-		if (cret != IO_WQ_CANCEL_NOTFOUND)
-			ret = true;
+		if (ctx->io_wq) {
+			cret = io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb,
+					       &cancel, true);
+			ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
+		}
 
 		/* SQPOLL thread does its own polling */
 		if (!(ctx->flags & IORING_SETUP_SQPOLL)) {
@@ -8817,6 +8912,19 @@ static void __io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 	}
 }
 
+static void io_disable_sqo_submit(struct io_ring_ctx *ctx)
+{
+	WARN_ON_ONCE(ctx->sqo_task != current);
+
+	mutex_lock(&ctx->uring_lock);
+	ctx->sqo_dead = 1;
+	mutex_unlock(&ctx->uring_lock);
+
+	/* make sure callers enter the ring to get error */
+	if (ctx->rings)
+		io_ring_set_wakeup_flag(ctx);
+}
+
 /*
  * We need to iteratively cancel requests, in case a request has dependent
  * hard links. These persist even for failure of cancelations, hence keep
@@ -8828,15 +8936,15 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 	struct task_struct *task = current;
 
 	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
+		/* for SQPOLL only sqo_task has task notes */
+		io_disable_sqo_submit(ctx);
 		task = ctx->sq_data->thread;
 		atomic_inc(&task->io_uring->in_idle);
 		io_sq_thread_park(ctx->sq_data);
 	}
 
 	io_cancel_defer_files(ctx, task, files);
-	io_ring_submit_lock(ctx, (ctx->flags & IORING_SETUP_IOPOLL));
 	io_cqring_overflow_flush(ctx, true, task, files);
-	io_ring_submit_unlock(ctx, (ctx->flags & IORING_SETUP_IOPOLL));
 
 	if (!files)
 		__io_uring_cancel_task_requests(ctx, task);
@@ -8909,20 +9017,12 @@ static void io_uring_del_task_file(struct file *file)
 		fput(file);
 }
 
-/*
- * Drop task note for this file if we're the only ones that hold it after
- * pending fput()
- */
-static void io_uring_attempt_task_drop(struct file *file)
+static void io_uring_remove_task_files(struct io_uring_task *tctx)
 {
-	if (!current->io_uring)
-		return;
-	/*
-	 * fput() is pending, will be 2 if the only other ref is our potential
-	 * task file note. If the task is exiting, drop regardless of count.
-	 */
-	if (fatal_signal_pending(current) || (current->flags & PF_EXITING) ||
-	    atomic_long_read(&file->f_count) == 2)
+	struct file *file;
+	unsigned long index;
+
+	xa_for_each(&tctx->xa, index, file)
 		io_uring_del_task_file(file);
 }
 
@@ -8934,16 +9034,12 @@ void __io_uring_files_cancel(struct files_struct *files)
 
 	/* make sure overflow events are dropped */
 	atomic_inc(&tctx->in_idle);
-
-	xa_for_each(&tctx->xa, index, file) {
-		struct io_ring_ctx *ctx = file->private_data;
-
-		io_uring_cancel_task_requests(ctx, files);
-		if (files)
-			io_uring_del_task_file(file);
-	}
-
+	xa_for_each(&tctx->xa, index, file)
+		io_uring_cancel_task_requests(file->private_data, files);
 	atomic_dec(&tctx->in_idle);
+
+	if (files)
+		io_uring_remove_task_files(tctx);
 }
 
 static s64 tctx_inflight(struct io_uring_task *tctx)
@@ -9005,12 +9101,41 @@ void __io_uring_task_cancel(void)
 		finish_wait(&tctx->wait, &wait);
 	} while (1);
 
+	finish_wait(&tctx->wait, &wait);
 	atomic_dec(&tctx->in_idle);
+
+	io_uring_remove_task_files(tctx);
 }
 
 static int io_uring_flush(struct file *file, void *data)
 {
-	io_uring_attempt_task_drop(file);
+	struct io_uring_task *tctx = current->io_uring;
+	struct io_ring_ctx *ctx = file->private_data;
+
+	if (!tctx)
+		return 0;
+
+	/* we should have cancelled and erased it before PF_EXITING */
+	WARN_ON_ONCE((current->flags & PF_EXITING) &&
+		     xa_load(&tctx->xa, (unsigned long)file));
+
+	/*
+	 * fput() is pending, will be 2 if the only other ref is our potential
+	 * task file note. If the task is exiting, drop regardless of count.
+	 */
+	if (atomic_long_read(&file->f_count) != 2)
+		return 0;
+
+	if (ctx->flags & IORING_SETUP_SQPOLL) {
+		/* there is only one file note, which is owned by sqo_task */
+		WARN_ON_ONCE((ctx->sqo_task == current) ==
+			     !xa_load(&tctx->xa, (unsigned long)file));
+
+		io_disable_sqo_submit(ctx);
+	}
+
+	if (!(ctx->flags & IORING_SETUP_SQPOLL) || ctx->sqo_task == current)
+		io_uring_del_task_file(file);
 	return 0;
 }
 
@@ -9084,8 +9209,9 @@ static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
 
 #endif /* !CONFIG_MMU */
 
-static void io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
+static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 {
+	int ret = 0;
 	DEFINE_WAIT(wait);
 
 	do {
@@ -9094,6 +9220,11 @@ static void io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 
 		prepare_to_wait(&ctx->sqo_sq_wait, &wait, TASK_INTERRUPTIBLE);
 
+		if (unlikely(ctx->sqo_dead)) {
+			ret = -EOWNERDEAD;
+			goto out;
+		}
+
 		if (!io_sqring_full(ctx))
 			break;
 
@@ -9101,6 +9232,8 @@ static void io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 	} while (!signal_pending(current));
 
 	finish_wait(&ctx->sqo_sq_wait, &wait);
+out:
+	return ret;
 }
 
 static int io_get_ext_arg(unsigned flags, const void __user *argp, size_t *argsz,
@@ -9172,17 +9305,18 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 	 */
 	ret = 0;
 	if (ctx->flags & IORING_SETUP_SQPOLL) {
-		if (!list_empty_careful(&ctx->cq_overflow_list)) {
-			bool needs_lock = ctx->flags & IORING_SETUP_IOPOLL;
+		io_cqring_overflow_flush(ctx, false, NULL, NULL);
 
-			io_ring_submit_lock(ctx, needs_lock);
-			io_cqring_overflow_flush(ctx, false, NULL, NULL);
-			io_ring_submit_unlock(ctx, needs_lock);
-		}
+		ret = -EOWNERDEAD;
+		if (unlikely(ctx->sqo_dead))
+			goto out;
 		if (flags & IORING_ENTER_SQ_WAKEUP)
 			wake_up(&ctx->sq_data->wait);
-		if (flags & IORING_ENTER_SQ_WAIT)
-			io_sqpoll_wait_sq(ctx);
+		if (flags & IORING_ENTER_SQ_WAIT) {
+			ret = io_sqpoll_wait_sq(ctx);
+			if (ret)
+				goto out;
+		}
 		submitted = to_submit;
 	} else if (to_submit) {
 		ret = io_uring_add_task_file(ctx, f.file);
@@ -9601,6 +9735,7 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
 	 */
 	ret = io_uring_install_fd(ctx, file);
 	if (ret < 0) {
+		io_disable_sqo_submit(ctx);
 		/* fput will clean it up */
 		fput(file);
 		return ret;
@@ -9609,6 +9744,7 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
 	trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags);
 	return ret;
 err:
+	io_disable_sqo_submit(ctx);
 	io_ring_ctx_wait_and_kill(ctx);
 	return ret;
 }