summary refs log tree commit diff
path: root/fs/io_uring.c
diff options
context:
space:
mode:
authorLinus Torvalds <torvalds@linux-foundation.org>2022-01-12 10:20:35 -0800
committerLinus Torvalds <torvalds@linux-foundation.org>2022-01-12 10:20:35 -0800
commit42a7b4ed45e7667836fae4fb0e1ac6340588b1b0 (patch)
treebfb340853b854683e8889ccb9603219e75ca3b04 /fs/io_uring.c
parent7e7b69654724c72bd3219b71f58937845dca0b2b (diff)
parent3cc7fdb9f90a25ae92250bf9e6cf3b9556b230e9 (diff)
downloadlinux-42a7b4ed45e7667836fae4fb0e1ac6340588b1b0.tar.gz
Merge tag 'for-5.17/io_uring-2022-01-11' of git://git.kernel.dk/linux-block
Pull io_uring updates from Jens Axboe:

 - Support for prioritized work completions (Hao)

 - Simplification of reissue (Pavel)

 - Add support for CQE skip (Pavel)

 - Memory leak fix going to 5.15-stable (Pavel)

 - Re-write of internal poll. This both cleans up that code, and gets us
   ready to fix the POLLFREE issue (Pavel)

 - Various cleanups (GuoYong, Pavel, Hao)

* tag 'for-5.17/io_uring-2022-01-11' of git://git.kernel.dk/linux-block: (31 commits)
  io_uring: fix not released cached task refs
  io_uring: remove redundant tab space
  io_uring: remove unused function parameter
  io_uring: use completion batching for poll rem/upd
  io_uring: single shot poll removal optimisation
  io_uring: poll rework
  io_uring: kill poll linking optimisation
  io_uring: move common poll bits
  io_uring: refactor poll update
  io_uring: remove double poll on poll update
  io_uring: code clean for some ctx usage
  io_uring: batch completion in prior_task_list
  io_uring: split io_req_complete_post() and add a helper
  io_uring: add helper for task work execution code
  io_uring: add a priority tw list for irq completion work
  io-wq: add helper to merge two wq_lists
  io_uring: reuse io_req_task_complete for timeouts
  io_uring: tweak iopoll CQE_SKIP event counting
  io_uring: simplify selected buf handling
  io_uring: move up io_put_kbuf() and io_put_rw_kbuf()
  ...
Diffstat (limited to 'fs/io_uring.c')
-rw-r--r--fs/io_uring.c1124
1 files changed, 578 insertions, 546 deletions
diff --git a/fs/io_uring.c b/fs/io_uring.c
index fb2a0cb4aaf8..a4f20b8c74a4 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -108,7 +108,8 @@
 #define SQE_COMMON_FLAGS (IOSQE_FIXED_FILE | IOSQE_IO_LINK | \
 			  IOSQE_IO_HARDLINK | IOSQE_ASYNC)
 
-#define SQE_VALID_FLAGS	(SQE_COMMON_FLAGS|IOSQE_BUFFER_SELECT|IOSQE_IO_DRAIN)
+#define SQE_VALID_FLAGS	(SQE_COMMON_FLAGS | IOSQE_BUFFER_SELECT | \
+			IOSQE_IO_DRAIN | IOSQE_CQE_SKIP_SUCCESS)
 
 #define IO_REQ_CLEAN_FLAGS (REQ_F_BUFFER_SELECTED | REQ_F_NEED_CLEANUP | \
 				REQ_F_POLLED | REQ_F_INFLIGHT | REQ_F_CREDS | \
@@ -320,6 +321,7 @@ struct io_submit_state {
 
 	bool			plug_started;
 	bool			need_plug;
+	bool			flush_cqes;
 	unsigned short		submit_nr;
 	struct blk_plug		plug;
 };
@@ -337,6 +339,7 @@ struct io_ring_ctx {
 		unsigned int		restricted: 1;
 		unsigned int		off_timeout_used: 1;
 		unsigned int		drain_active: 1;
+		unsigned int		drain_disabled: 1;
 	} ____cacheline_aligned_in_smp;
 
 	/* submission data */
@@ -471,6 +474,7 @@ struct io_uring_task {
 
 	spinlock_t		task_lock;
 	struct io_wq_work_list	task_list;
+	struct io_wq_work_list	prior_task_list;
 	struct callback_head	task_work;
 	bool			task_running;
 };
@@ -483,8 +487,6 @@ struct io_poll_iocb {
 	struct file			*file;
 	struct wait_queue_head		*head;
 	__poll_t			events;
-	bool				done;
-	bool				canceled;
 	struct wait_queue_entry		wait;
 };
 
@@ -721,6 +723,7 @@ enum {
 	REQ_F_HARDLINK_BIT	= IOSQE_IO_HARDLINK_BIT,
 	REQ_F_FORCE_ASYNC_BIT	= IOSQE_ASYNC_BIT,
 	REQ_F_BUFFER_SELECT_BIT	= IOSQE_BUFFER_SELECT_BIT,
+	REQ_F_CQE_SKIP_BIT	= IOSQE_CQE_SKIP_SUCCESS_BIT,
 
 	/* first byte is taken by user flags, shift it to not overlap */
 	REQ_F_FAIL_BIT		= 8,
@@ -737,6 +740,7 @@ enum {
 	REQ_F_REFCOUNT_BIT,
 	REQ_F_ARM_LTIMEOUT_BIT,
 	REQ_F_ASYNC_DATA_BIT,
+	REQ_F_SKIP_LINK_CQES_BIT,
 	/* keep async read/write and isreg together and in order */
 	REQ_F_SUPPORT_NOWAIT_BIT,
 	REQ_F_ISREG_BIT,
@@ -758,6 +762,8 @@ enum {
 	REQ_F_FORCE_ASYNC	= BIT(REQ_F_FORCE_ASYNC_BIT),
 	/* IOSQE_BUFFER_SELECT */
 	REQ_F_BUFFER_SELECT	= BIT(REQ_F_BUFFER_SELECT_BIT),
+	/* IOSQE_CQE_SKIP_SUCCESS */
+	REQ_F_CQE_SKIP		= BIT(REQ_F_CQE_SKIP_BIT),
 
 	/* fail rest of links */
 	REQ_F_FAIL		= BIT(REQ_F_FAIL_BIT),
@@ -791,6 +797,8 @@ enum {
 	REQ_F_ARM_LTIMEOUT	= BIT(REQ_F_ARM_LTIMEOUT_BIT),
 	/* ->async_data allocated */
 	REQ_F_ASYNC_DATA	= BIT(REQ_F_ASYNC_DATA_BIT),
+	/* don't post CQEs while failing linked requests */
+	REQ_F_SKIP_LINK_CQES	= BIT(REQ_F_SKIP_LINK_CQES_BIT),
 };
 
 struct async_poll {
@@ -882,6 +890,7 @@ struct io_kiocb {
 	const struct cred		*creds;
 	/* stores selected buf, valid IFF REQ_F_BUFFER_SELECTED is set */
 	struct io_buffer		*kbuf;
+	atomic_t			poll_refs;
 };
 
 struct io_tctx_node {
@@ -1108,8 +1117,8 @@ static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 					 bool cancel_all);
 static void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 
-static bool io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
-				 s32 res, u32 cflags);
+static void io_fill_cqe_req(struct io_kiocb *req, s32 res, u32 cflags);
+
 static void io_put_req(struct io_kiocb *req);
 static void io_put_req_deferred(struct io_kiocb *req);
 static void io_dismantle_req(struct io_kiocb *req);
@@ -1264,6 +1273,26 @@ static inline void io_req_set_rsrc_node(struct io_kiocb *req,
 	}
 }
 
+static unsigned int __io_put_kbuf(struct io_kiocb *req)
+{
+	struct io_buffer *kbuf = req->kbuf;
+	unsigned int cflags;
+
+	cflags = kbuf->bid << IORING_CQE_BUFFER_SHIFT;
+	cflags |= IORING_CQE_F_BUFFER;
+	req->flags &= ~REQ_F_BUFFER_SELECTED;
+	kfree(kbuf);
+	req->kbuf = NULL;
+	return cflags;
+}
+
+static inline unsigned int io_put_kbuf(struct io_kiocb *req)
+{
+	if (likely(!(req->flags & REQ_F_BUFFER_SELECTED)))
+		return 0;
+	return __io_put_kbuf(req);
+}
+
 static void io_refs_resurrect(struct percpu_ref *ref, struct completion *compl)
 {
 	bool got = percpu_ref_tryget(ref);
@@ -1340,6 +1369,10 @@ static inline bool req_has_async_data(struct io_kiocb *req)
 static inline void req_set_fail(struct io_kiocb *req)
 {
 	req->flags |= REQ_F_FAIL;
+	if (req->flags & REQ_F_CQE_SKIP) {
+		req->flags &= ~REQ_F_CQE_SKIP;
+		req->flags |= REQ_F_SKIP_LINK_CQES;
+	}
 }
 
 static inline void req_fail_link_node(struct io_kiocb *req, int res)
@@ -1553,8 +1586,11 @@ static void io_prep_async_link(struct io_kiocb *req)
 
 static inline void io_req_add_compl_list(struct io_kiocb *req)
 {
-	struct io_submit_state *state = &req->ctx->submit_state;
+	struct io_ring_ctx *ctx = req->ctx;
+	struct io_submit_state *state = &ctx->submit_state;
 
+	if (!(req->flags & REQ_F_CQE_SKIP))
+		ctx->submit_state.flush_cqes = true;
 	wq_list_add_tail(&req->comp_list, &state->compl_reqs);
 }
 
@@ -1599,7 +1635,7 @@ static void io_kill_timeout(struct io_kiocb *req, int status)
 		atomic_set(&req->ctx->cq_timeouts,
 			atomic_read(&req->ctx->cq_timeouts) + 1);
 		list_del_init(&req->timeout.list);
-		io_cqring_fill_event(req->ctx, req->user_data, status, 0);
+		io_fill_cqe_req(req, status, 0);
 		io_put_req_deferred(req);
 	}
 }
@@ -1830,6 +1866,18 @@ static inline void io_get_task_refs(int nr)
 		io_task_refs_refill(tctx);
 }
 
+static __cold void io_uring_drop_tctx_refs(struct task_struct *task)
+{
+	struct io_uring_task *tctx = task->io_uring;
+	unsigned int refs = tctx->cached_refs;
+
+	if (refs) {
+		tctx->cached_refs = 0;
+		percpu_counter_sub(&tctx->inflight, refs);
+		put_task_struct_many(task, refs);
+	}
+}
+
 static bool io_cqring_event_overflow(struct io_ring_ctx *ctx, u64 user_data,
 				     s32 res, u32 cflags)
 {
@@ -1858,8 +1906,8 @@ static bool io_cqring_event_overflow(struct io_ring_ctx *ctx, u64 user_data,
 	return true;
 }
 
-static inline bool __io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
-					  s32 res, u32 cflags)
+static inline bool __io_fill_cqe(struct io_ring_ctx *ctx, u64 user_data,
+				 s32 res, u32 cflags)
 {
 	struct io_uring_cqe *cqe;
 
@@ -1880,20 +1928,26 @@ static inline bool __io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data
 	return io_cqring_event_overflow(ctx, user_data, res, cflags);
 }
 
-/* not as hot to bloat with inlining */
-static noinline bool io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
-					  s32 res, u32 cflags)
+static noinline void io_fill_cqe_req(struct io_kiocb *req, s32 res, u32 cflags)
 {
-	return __io_cqring_fill_event(ctx, user_data, res, cflags);
+	if (!(req->flags & REQ_F_CQE_SKIP))
+		__io_fill_cqe(req->ctx, req->user_data, res, cflags);
 }
 
-static void io_req_complete_post(struct io_kiocb *req, s32 res,
-				 u32 cflags)
+static noinline bool io_fill_cqe_aux(struct io_ring_ctx *ctx, u64 user_data,
+				     s32 res, u32 cflags)
+{
+	ctx->cq_extra++;
+	return __io_fill_cqe(ctx, user_data, res, cflags);
+}
+
+static void __io_req_complete_post(struct io_kiocb *req, s32 res,
+				   u32 cflags)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
-	spin_lock(&ctx->completion_lock);
-	__io_cqring_fill_event(ctx, req->user_data, res, cflags);
+	if (!(req->flags & REQ_F_CQE_SKIP))
+		__io_fill_cqe(ctx, req->user_data, res, cflags);
 	/*
 	 * If we're the last reference to this request, add to our locked
 	 * free_list cache.
@@ -1913,6 +1967,15 @@ static void io_req_complete_post(struct io_kiocb *req, s32 res,
 		wq_list_add_head(&req->comp_list, &ctx->locked_free_list);
 		ctx->locked_free_nr++;
 	}
+}
+
+static void io_req_complete_post(struct io_kiocb *req, s32 res,
+				 u32 cflags)
+{
+	struct io_ring_ctx *ctx = req->ctx;
+
+	spin_lock(&ctx->completion_lock);
+	__io_req_complete_post(req, res, cflags);
 	io_commit_cqring(ctx);
 	spin_unlock(&ctx->completion_lock);
 	io_cqring_ev_posted(ctx);
@@ -2101,8 +2164,8 @@ static bool io_kill_linked_timeout(struct io_kiocb *req)
 		link->timeout.head = NULL;
 		if (hrtimer_try_to_cancel(&io->timer) != -1) {
 			list_del(&link->timeout.list);
-			io_cqring_fill_event(link->ctx, link->user_data,
-					     -ECANCELED, 0);
+			/* leave REQ_F_CQE_SKIP to io_fill_cqe_req */
+			io_fill_cqe_req(link, -ECANCELED, 0);
 			io_put_req_deferred(link);
 			return true;
 		}
@@ -2114,6 +2177,7 @@ static void io_fail_links(struct io_kiocb *req)
 	__must_hold(&req->ctx->completion_lock)
 {
 	struct io_kiocb *nxt, *link = req->link;
+	bool ignore_cqes = req->flags & REQ_F_SKIP_LINK_CQES;
 
 	req->link = NULL;
 	while (link) {
@@ -2126,7 +2190,10 @@ static void io_fail_links(struct io_kiocb *req)
 		link->link = NULL;
 
 		trace_io_uring_fail_link(req, link);
-		io_cqring_fill_event(link->ctx, link->user_data, res, 0);
+		if (!ignore_cqes) {
+			link->flags &= ~REQ_F_CQE_SKIP;
+			io_fill_cqe_req(link, res, 0);
+		}
 		io_put_req_deferred(link);
 		link = nxt;
 	}
@@ -2143,8 +2210,8 @@ static bool io_disarm_next(struct io_kiocb *req)
 		req->flags &= ~REQ_F_ARM_LTIMEOUT;
 		if (link && link->opcode == IORING_OP_LINK_TIMEOUT) {
 			io_remove_next_linked(req);
-			io_cqring_fill_event(link->ctx, link->user_data,
-					     -ECANCELED, 0);
+			/* leave REQ_F_CQE_SKIP to io_fill_cqe_req */
+			io_fill_cqe_req(link, -ECANCELED, 0);
 			io_put_req_deferred(link);
 			posted = true;
 		}
@@ -2171,7 +2238,7 @@ static void __io_req_find_next_prep(struct io_kiocb *req)
 	spin_lock(&ctx->completion_lock);
 	posted = io_disarm_next(req);
 	if (posted)
-		io_commit_cqring(req->ctx);
+		io_commit_cqring(ctx);
 	spin_unlock(&ctx->completion_lock);
 	if (posted)
 		io_cqring_ev_posted(ctx);
@@ -2208,51 +2275,108 @@ static void ctx_flush_and_put(struct io_ring_ctx *ctx, bool *locked)
 	percpu_ref_put(&ctx->refs);
 }
 
+static inline void ctx_commit_and_unlock(struct io_ring_ctx *ctx)
+{
+	io_commit_cqring(ctx);
+	spin_unlock(&ctx->completion_lock);
+	io_cqring_ev_posted(ctx);
+}
+
+static void handle_prev_tw_list(struct io_wq_work_node *node,
+				struct io_ring_ctx **ctx, bool *uring_locked)
+{
+	if (*ctx && !*uring_locked)
+		spin_lock(&(*ctx)->completion_lock);
+
+	do {
+		struct io_wq_work_node *next = node->next;
+		struct io_kiocb *req = container_of(node, struct io_kiocb,
+						    io_task_work.node);
+
+		if (req->ctx != *ctx) {
+			if (unlikely(!*uring_locked && *ctx))
+				ctx_commit_and_unlock(*ctx);
+
+			ctx_flush_and_put(*ctx, uring_locked);
+			*ctx = req->ctx;
+			/* if not contended, grab and improve batching */
+			*uring_locked = mutex_trylock(&(*ctx)->uring_lock);
+			percpu_ref_get(&(*ctx)->refs);
+			if (unlikely(!*uring_locked))
+				spin_lock(&(*ctx)->completion_lock);
+		}
+		if (likely(*uring_locked))
+			req->io_task_work.func(req, uring_locked);
+		else
+			__io_req_complete_post(req, req->result, io_put_kbuf(req));
+		node = next;
+	} while (node);
+
+	if (unlikely(!*uring_locked))
+		ctx_commit_and_unlock(*ctx);
+}
+
+static void handle_tw_list(struct io_wq_work_node *node,
+			   struct io_ring_ctx **ctx, bool *locked)
+{
+	do {
+		struct io_wq_work_node *next = node->next;
+		struct io_kiocb *req = container_of(node, struct io_kiocb,
+						    io_task_work.node);
+
+		if (req->ctx != *ctx) {
+			ctx_flush_and_put(*ctx, locked);
+			*ctx = req->ctx;
+			/* if not contended, grab and improve batching */
+			*locked = mutex_trylock(&(*ctx)->uring_lock);
+			percpu_ref_get(&(*ctx)->refs);
+		}
+		req->io_task_work.func(req, locked);
+		node = next;
+	} while (node);
+}
+
 static void tctx_task_work(struct callback_head *cb)
 {
-	bool locked = false;
+	bool uring_locked = false;
 	struct io_ring_ctx *ctx = NULL;
 	struct io_uring_task *tctx = container_of(cb, struct io_uring_task,
 						  task_work);
 
 	while (1) {
-		struct io_wq_work_node *node;
+		struct io_wq_work_node *node1, *node2;
 
-		if (!tctx->task_list.first && locked)
+		if (!tctx->task_list.first &&
+		    !tctx->prior_task_list.first && uring_locked)
 			io_submit_flush_completions(ctx);
 
 		spin_lock_irq(&tctx->task_lock);
-		node = tctx->task_list.first;
+		node1 = tctx->prior_task_list.first;
+		node2 = tctx->task_list.first;
 		INIT_WQ_LIST(&tctx->task_list);
-		if (!node)
+		INIT_WQ_LIST(&tctx->prior_task_list);
+		if (!node2 && !node1)
 			tctx->task_running = false;
 		spin_unlock_irq(&tctx->task_lock);
-		if (!node)
+		if (!node2 && !node1)
 			break;
 
-		do {
-			struct io_wq_work_node *next = node->next;
-			struct io_kiocb *req = container_of(node, struct io_kiocb,
-							    io_task_work.node);
-
-			if (req->ctx != ctx) {
-				ctx_flush_and_put(ctx, &locked);
-				ctx = req->ctx;
-				/* if not contended, grab and improve batching */
-				locked = mutex_trylock(&ctx->uring_lock);
-				percpu_ref_get(&ctx->refs);
-			}
-			req->io_task_work.func(req, &locked);
-			node = next;
-		} while (node);
+		if (node1)
+			handle_prev_tw_list(node1, &ctx, &uring_locked);
 
+		if (node2)
+			handle_tw_list(node2, &ctx, &uring_locked);
 		cond_resched();
 	}
 
-	ctx_flush_and_put(ctx, &locked);
+	ctx_flush_and_put(ctx, &uring_locked);
+
+	/* relaxed read is enough as only the task itself sets ->in_idle */
+	if (unlikely(atomic_read(&tctx->in_idle)))
+		io_uring_drop_tctx_refs(current);
 }
 
-static void io_req_task_work_add(struct io_kiocb *req)
+static void io_req_task_work_add(struct io_kiocb *req, bool priority)
 {
 	struct task_struct *tsk = req->task;
 	struct io_uring_task *tctx = tsk->io_uring;
@@ -2264,7 +2388,10 @@ static void io_req_task_work_add(struct io_kiocb *req)
 	WARN_ON_ONCE(!tctx);
 
 	spin_lock_irqsave(&tctx->task_lock, flags);
-	wq_list_add_tail(&req->io_task_work.node, &tctx->task_list);
+	if (priority)
+		wq_list_add_tail(&req->io_task_work.node, &tctx->prior_task_list);
+	else
+		wq_list_add_tail(&req->io_task_work.node, &tctx->task_list);
 	running = tctx->task_running;
 	if (!running)
 		tctx->task_running = true;
@@ -2289,8 +2416,7 @@ static void io_req_task_work_add(struct io_kiocb *req)
 
 	spin_lock_irqsave(&tctx->task_lock, flags);
 	tctx->task_running = false;
-	node = tctx->task_list.first;
-	INIT_WQ_LIST(&tctx->task_list);
+	node = wq_list_merge(&tctx->prior_task_list, &tctx->task_list);
 	spin_unlock_irqrestore(&tctx->task_lock, flags);
 
 	while (node) {
@@ -2327,19 +2453,19 @@ static void io_req_task_queue_fail(struct io_kiocb *req, int ret)
 {
 	req->result = ret;
 	req->io_task_work.func = io_req_task_cancel;
-	io_req_task_work_add(req);
+	io_req_task_work_add(req, false);
 }
 
 static void io_req_task_queue(struct io_kiocb *req)
 {
 	req->io_task_work.func = io_req_task_submit;
-	io_req_task_work_add(req);
+	io_req_task_work_add(req, false);
 }
 
 static void io_req_task_queue_reissue(struct io_kiocb *req)
 {
 	req->io_task_work.func = io_queue_async_work;
-	io_req_task_work_add(req);
+	io_req_task_work_add(req, false);
 }
 
 static inline void io_queue_next(struct io_kiocb *req)
@@ -2403,17 +2529,22 @@ static void __io_submit_flush_completions(struct io_ring_ctx *ctx)
 	struct io_wq_work_node *node, *prev;
 	struct io_submit_state *state = &ctx->submit_state;
 
-	spin_lock(&ctx->completion_lock);
-	wq_list_for_each(node, prev, &state->compl_reqs) {
-		struct io_kiocb *req = container_of(node, struct io_kiocb,
+	if (state->flush_cqes) {
+		spin_lock(&ctx->completion_lock);
+		wq_list_for_each(node, prev, &state->compl_reqs) {
+			struct io_kiocb *req = container_of(node, struct io_kiocb,
 						    comp_list);
 
-		__io_cqring_fill_event(ctx, req->user_data, req->result,
-					req->cflags);
+			if (!(req->flags & REQ_F_CQE_SKIP))
+				__io_fill_cqe(ctx, req->user_data, req->result,
+					      req->cflags);
+		}
+
+		io_commit_cqring(ctx);
+		spin_unlock(&ctx->completion_lock);
+		io_cqring_ev_posted(ctx);
+		state->flush_cqes = false;
 	}
-	io_commit_cqring(ctx);
-	spin_unlock(&ctx->completion_lock);
-	io_cqring_ev_posted(ctx);
 
 	io_free_batch_list(ctx, state->compl_reqs.first);
 	INIT_WQ_LIST(&state->compl_reqs);
@@ -2444,7 +2575,7 @@ static inline void io_put_req_deferred(struct io_kiocb *req)
 {
 	if (req_ref_put_and_test(req)) {
 		req->io_task_work.func = io_free_req_work;
-		io_req_task_work_add(req);
+		io_req_task_work_add(req, false);
 	}
 }
 
@@ -2463,24 +2594,6 @@ static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
 	return smp_load_acquire(&rings->sq.tail) - ctx->cached_sq_head;
 }
 
-static unsigned int io_put_kbuf(struct io_kiocb *req, struct io_buffer *kbuf)
-{
-	unsigned int cflags;
-
-	cflags = kbuf->bid << IORING_CQE_BUFFER_SHIFT;
-	cflags |= IORING_CQE_F_BUFFER;
-	req->flags &= ~REQ_F_BUFFER_SELECTED;
-	kfree(kbuf);
-	return cflags;
-}
-
-static inline unsigned int io_put_rw_kbuf(struct io_kiocb *req)
-{
-	if (likely(!(req->flags & REQ_F_BUFFER_SELECTED)))
-		return 0;
-	return io_put_kbuf(req, req->kbuf);
-}
-
 static inline bool io_run_task_work(void)
 {
 	if (test_thread_flag(TIF_NOTIFY_SIGNAL) || current->task_works) {
@@ -2543,8 +2656,10 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin)
 		/* order with io_complete_rw_iopoll(), e.g. ->result updates */
 		if (!smp_load_acquire(&req->iopoll_completed))
 			break;
-		__io_cqring_fill_event(ctx, req->user_data, req->result,
-					io_put_rw_kbuf(req));
+		if (unlikely(req->flags & REQ_F_CQE_SKIP))
+			continue;
+
+		__io_fill_cqe(ctx, req->user_data, req->result, io_put_kbuf(req));
 		nr_events++;
 	}
 
@@ -2718,9 +2833,9 @@ static bool __io_complete_rw_common(struct io_kiocb *req, long res)
 	return false;
 }
 
-static void io_req_task_complete(struct io_kiocb *req, bool *locked)
+static inline void io_req_task_complete(struct io_kiocb *req, bool *locked)
 {
-	unsigned int cflags = io_put_rw_kbuf(req);
+	unsigned int cflags = io_put_kbuf(req);
 	int res = req->result;
 
 	if (*locked) {
@@ -2731,12 +2846,12 @@ static void io_req_task_complete(struct io_kiocb *req, bool *locked)
 	}
 }
 
-static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
+static void __io_complete_rw(struct io_kiocb *req, long res,
 			     unsigned int issue_flags)
 {
 	if (__io_complete_rw_common(req, res))
 		return;
-	__io_req_complete(req, issue_flags, req->result, io_put_rw_kbuf(req));
+	__io_req_complete(req, issue_flags, req->result, io_put_kbuf(req));
 }
 
 static void io_complete_rw(struct kiocb *kiocb, long res)
@@ -2747,7 +2862,7 @@ static void io_complete_rw(struct kiocb *kiocb, long res)
 		return;
 	req->result = res;
 	req->io_task_work.func = io_req_task_complete;
-	io_req_task_work_add(req);
+	io_req_task_work_add(req, !!(req->ctx->flags & IORING_SETUP_SQPOLL));
 }
 
 static void io_complete_rw_iopoll(struct kiocb *kiocb, long res)
@@ -2965,10 +3080,9 @@ static inline void io_rw_done(struct kiocb *kiocb, ssize_t ret)
 	}
 }
 
-static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
+static void kiocb_done(struct io_kiocb *req, ssize_t ret,
 		       unsigned int issue_flags)
 {
-	struct io_kiocb *req = container_of(kiocb, struct io_kiocb, rw.kiocb);
 	struct io_async_rw *io = req->async_data;
 
 	/* add previously done IO, if any */
@@ -2980,28 +3094,21 @@ static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
 	}
 
 	if (req->flags & REQ_F_CUR_POS)
-		req->file->f_pos = kiocb->ki_pos;
-	if (ret >= 0 && (kiocb->ki_complete == io_complete_rw))
-		__io_complete_rw(req, ret, 0, issue_flags);
+		req->file->f_pos = req->rw.kiocb.ki_pos;
+	if (ret >= 0 && (req->rw.kiocb.ki_complete == io_complete_rw))
+		__io_complete_rw(req, ret, issue_flags);
 	else
-		io_rw_done(kiocb, ret);
+		io_rw_done(&req->rw.kiocb, ret);
 
 	if (req->flags & REQ_F_REISSUE) {
 		req->flags &= ~REQ_F_REISSUE;
 		if (io_resubmit_prep(req)) {
 			io_req_task_queue_reissue(req);
 		} else {
-			unsigned int cflags = io_put_rw_kbuf(req);
-			struct io_ring_ctx *ctx = req->ctx;
-
 			req_set_fail(req);
-			if (issue_flags & IO_URING_F_UNLOCKED) {
-				mutex_lock(&ctx->uring_lock);
-				__io_req_complete(req, issue_flags, ret, cflags);
-				mutex_unlock(&ctx->uring_lock);
-			} else {
-				__io_req_complete(req, issue_flags, ret, cflags);
-			}
+			req->result = ret;
+			req->io_task_work.func = io_req_task_complete;
+			io_req_task_work_add(req, false);
 		}
 	}
 }
@@ -3229,10 +3336,12 @@ static struct iovec *__io_import_iovec(int rw, struct io_kiocb *req,
 	size_t sqe_len;
 	ssize_t ret;
 
-	BUILD_BUG_ON(ERR_PTR(0) != NULL);
-
-	if (opcode == IORING_OP_READ_FIXED || opcode == IORING_OP_WRITE_FIXED)
-		return ERR_PTR(io_import_fixed(req, rw, iter));
+	if (opcode == IORING_OP_READ_FIXED || opcode == IORING_OP_WRITE_FIXED) {
+		ret = io_import_fixed(req, rw, iter);
+		if (ret)
+			return ERR_PTR(ret);
+		return NULL;
+	}
 
 	/* buffer index only valid with fixed read/write, or buffer select  */
 	if (unlikely(req->buf_index && !(req->flags & REQ_F_BUFFER_SELECT)))
@@ -3250,15 +3359,18 @@ static struct iovec *__io_import_iovec(int rw, struct io_kiocb *req,
 		}
 
 		ret = import_single_range(rw, buf, sqe_len, s->fast_iov, iter);
-		return ERR_PTR(ret);
+		if (ret)
+			return ERR_PTR(ret);
+		return NULL;
 	}
 
 	iovec = s->fast_iov;
 	if (req->flags & REQ_F_BUFFER_SELECT) {
 		ret = io_iov_buffer_select(req, iovec, issue_flags);
-		if (!ret)
-			iov_iter_init(iter, rw, iovec, 1, iovec->iov_len);
-		return ERR_PTR(ret);
+		if (ret)
+			return ERR_PTR(ret);
+		iov_iter_init(iter, rw, iovec, 1, iovec->iov_len);
+		return NULL;
 	}
 
 	ret = __import_iovec(rw, buf, sqe_len, UIO_FASTIOV, &iovec, iter,
@@ -3629,7 +3741,7 @@ static int io_read(struct io_kiocb *req, unsigned int issue_flags)
 		iov_iter_restore(&s->iter, &s->iter_state);
 	} while (ret > 0);
 done:
-	kiocb_done(kiocb, ret, issue_flags);
+	kiocb_done(req, ret, issue_flags);
 out_free:
 	/* it's faster to check here then delegate to kfree */
 	if (iovec)
@@ -3726,7 +3838,7 @@ static int io_write(struct io_kiocb *req, unsigned int issue_flags)
 		if (ret2 == -EAGAIN && (req->ctx->flags & IORING_SETUP_IOPOLL))
 			goto copy_iov;
 done:
-		kiocb_done(kiocb, ret2, issue_flags);
+		kiocb_done(req, ret2, issue_flags);
 	} else {
 copy_iov:
 		iov_iter_restore(&s->iter, &s->iter_state);
@@ -4839,17 +4951,18 @@ static int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
 		min_ret = iov_iter_count(&kmsg->msg.msg_iter);
 
 	ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
-	if ((issue_flags & IO_URING_F_NONBLOCK) && ret == -EAGAIN)
-		return io_setup_async_msg(req, kmsg);
-	if (ret == -ERESTARTSYS)
-		ret = -EINTR;
 
+	if (ret < min_ret) {
+		if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
+			return io_setup_async_msg(req, kmsg);
+		if (ret == -ERESTARTSYS)
+			ret = -EINTR;
+		req_set_fail(req);
+	}
 	/* fast path, check for non-NULL to avoid function call */
 	if (kmsg->free_iov)
 		kfree(kmsg->free_iov);
 	req->flags &= ~REQ_F_NEED_CLEANUP;
-	if (ret < min_ret)
-		req_set_fail(req);
 	__io_req_complete(req, issue_flags, ret, 0);
 	return 0;
 }
@@ -4885,13 +4998,13 @@ static int io_send(struct io_kiocb *req, unsigned int issue_flags)
 
 	msg.msg_flags = flags;
 	ret = sock_sendmsg(sock, &msg);
-	if ((issue_flags & IO_URING_F_NONBLOCK) && ret == -EAGAIN)
-		return -EAGAIN;
-	if (ret == -ERESTARTSYS)
-		ret = -EINTR;
-
-	if (ret < min_ret)
+	if (ret < min_ret) {
+		if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
+			return -EAGAIN;
+		if (ret == -ERESTARTSYS)
+			ret = -EINTR;
 		req_set_fail(req);
+	}
 	__io_req_complete(req, issue_flags, ret, 0);
 	return 0;
 }
@@ -4991,11 +5104,6 @@ static struct io_buffer *io_recv_buffer_select(struct io_kiocb *req,
 	return io_buffer_select(req, &sr->len, sr->bgid, issue_flags);
 }
 
-static inline unsigned int io_put_recv_kbuf(struct io_kiocb *req)
-{
-	return io_put_kbuf(req, req->kbuf);
-}
-
 static int io_recvmsg_prep_async(struct io_kiocb *req)
 {
 	int ret;
@@ -5033,8 +5141,7 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 	struct socket *sock;
 	struct io_buffer *kbuf;
 	unsigned flags;
-	int min_ret = 0;
-	int ret, cflags = 0;
+	int ret, min_ret = 0;
 	bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
 
 	sock = sock_from_file(req->file);
@@ -5068,20 +5175,21 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 
 	ret = __sys_recvmsg_sock(sock, &kmsg->msg, req->sr_msg.umsg,
 					kmsg->uaddr, flags);
-	if (force_nonblock && ret == -EAGAIN)
-		return io_setup_async_msg(req, kmsg);
-	if (ret == -ERESTARTSYS)
-		ret = -EINTR;
+	if (ret < min_ret) {
+		if (ret == -EAGAIN && force_nonblock)
+			return io_setup_async_msg(req, kmsg);
+		if (ret == -ERESTARTSYS)
+			ret = -EINTR;
+		req_set_fail(req);
+	} else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
+		req_set_fail(req);
+	}
 
-	if (req->flags & REQ_F_BUFFER_SELECTED)
-		cflags = io_put_recv_kbuf(req);
 	/* fast path, check for non-NULL to avoid function call */
 	if (kmsg->free_iov)
 		kfree(kmsg->free_iov);
 	req->flags &= ~REQ_F_NEED_CLEANUP;
-	if (ret < min_ret || ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
-		req_set_fail(req);
-	__io_req_complete(req, issue_flags, ret, cflags);
+	__io_req_complete(req, issue_flags, ret, io_put_kbuf(req));
 	return 0;
 }
 
@@ -5094,8 +5202,7 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
 	struct socket *sock;
 	struct iovec iov;
 	unsigned flags;
-	int min_ret = 0;
-	int ret, cflags = 0;
+	int ret, min_ret = 0;
 	bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
 
 	sock = sock_from_file(req->file);
@@ -5127,16 +5234,18 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
 		min_ret = iov_iter_count(&msg.msg_iter);
 
 	ret = sock_recvmsg(sock, &msg, flags);
-	if (force_nonblock && ret == -EAGAIN)
-		return -EAGAIN;
-	if (ret == -ERESTARTSYS)
-		ret = -EINTR;
 out_free:
-	if (req->flags & REQ_F_BUFFER_SELECTED)
-		cflags = io_put_recv_kbuf(req);
-	if (ret < min_ret || ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))))
+	if (ret < min_ret) {
+		if (ret == -EAGAIN && force_nonblock)
+			return -EAGAIN;
+		if (ret == -ERESTARTSYS)
+			ret = -EINTR;
+		req_set_fail(req);
+	} else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
 		req_set_fail(req);
-	__io_req_complete(req, issue_flags, ret, cflags);
+	}
+
+	__io_req_complete(req, issue_flags, ret, io_put_kbuf(req));
 	return 0;
 }
 
@@ -5303,52 +5412,23 @@ struct io_poll_table {
 	int error;
 };
 
-static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
-			   __poll_t mask, io_req_tw_func_t func)
-{
-	/* for instances that support it check for an event match first: */
-	if (mask && !(mask & poll->events))
-		return 0;
-
-	trace_io_uring_task_add(req->ctx, req->opcode, req->user_data, mask);
-
-	list_del_init(&poll->wait.entry);
-
-	req->result = mask;
-	req->io_task_work.func = func;
+#define IO_POLL_CANCEL_FLAG	BIT(31)
+#define IO_POLL_REF_MASK	((1u << 20)-1)
 
-	/*
-	 * If this fails, then the task is exiting. When a task exits, the
-	 * work gets canceled, so just cancel this request as well instead
-	 * of executing it. We can't safely execute it anyway, as we may not
-	 * have the needed state needed for it anyway.
-	 */
-	io_req_task_work_add(req);
-	return 1;
+/*
+ * If refs part of ->poll_refs (see IO_POLL_REF_MASK) is 0, it's free. We can
+ * bump it and acquire ownership. It's disallowed to modify requests while not
+ * owning it, that prevents from races for enqueueing task_work's and b/w
+ * arming poll and wakeups.
+ */
+static inline bool io_poll_get_ownership(struct io_kiocb *req)
+{
+	return !(atomic_fetch_inc(&req->poll_refs) & IO_POLL_REF_MASK);
 }
 
-static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
-	__acquires(&req->ctx->completion_lock)
+static void io_poll_mark_cancelled(struct io_kiocb *req)
 {
-	struct io_ring_ctx *ctx = req->ctx;
-
-	/* req->task == current here, checking PF_EXITING is safe */
-	if (unlikely(req->task->flags & PF_EXITING))
-		WRITE_ONCE(poll->canceled, true);
-
-	if (!req->result && !READ_ONCE(poll->canceled)) {
-		struct poll_table_struct pt = { ._key = poll->events };
-
-		req->result = vfs_poll(req->file, &pt) & poll->events;
-	}
-
-	spin_lock(&ctx->completion_lock);
-	if (!req->result && !READ_ONCE(poll->canceled)) {
-		add_wait_queue(poll->head, &poll->wait);
-		return true;
-	}
-
-	return false;
+	atomic_or(IO_POLL_CANCEL_FLAG, &req->poll_refs);
 }
 
 static struct io_poll_iocb *io_poll_get_double(struct io_kiocb *req)
@@ -5366,133 +5446,199 @@ static struct io_poll_iocb *io_poll_get_single(struct io_kiocb *req)
 	return &req->apoll->poll;
 }
 
-static void io_poll_remove_double(struct io_kiocb *req)
-	__must_hold(&req->ctx->completion_lock)
+static void io_poll_req_insert(struct io_kiocb *req)
 {
-	struct io_poll_iocb *poll = io_poll_get_double(req);
+	struct io_ring_ctx *ctx = req->ctx;
+	struct hlist_head *list;
 
-	lockdep_assert_held(&req->ctx->completion_lock);
+	list = &ctx->cancel_hash[hash_long(req->user_data, ctx->cancel_hash_bits)];
+	hlist_add_head(&req->hash_node, list);
+}
 
-	if (poll && poll->head) {
-		struct wait_queue_head *head = poll->head;
+static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
+			      wait_queue_func_t wake_func)
+{
+	poll->head = NULL;
+#define IO_POLL_UNMASK	(EPOLLERR|EPOLLHUP|EPOLLNVAL|EPOLLRDHUP)
+	/* mask in events that we always want/need */
+	poll->events = events | IO_POLL_UNMASK;
+	INIT_LIST_HEAD(&poll->wait.entry);
+	init_waitqueue_func_entry(&poll->wait, wake_func);
+}
 
-		spin_lock_irq(&head->lock);
-		list_del_init(&poll->wait.entry);
-		if (poll->wait.private)
-			req_ref_put(req);
-		poll->head = NULL;
-		spin_unlock_irq(&head->lock);
-	}
+static inline void io_poll_remove_entry(struct io_poll_iocb *poll)
+{
+	struct wait_queue_head *head = poll->head;
+
+	spin_lock_irq(&head->lock);
+	list_del_init(&poll->wait.entry);
+	poll->head = NULL;
+	spin_unlock_irq(&head->lock);
 }
 
-static bool __io_poll_complete(struct io_kiocb *req, __poll_t mask)
-	__must_hold(&req->ctx->completion_lock)
+static void io_poll_remove_entries(struct io_kiocb *req)
+{
+	struct io_poll_iocb *poll = io_poll_get_single(req);
+	struct io_poll_iocb *poll_double = io_poll_get_double(req);
+
+	if (poll->head)
+		io_poll_remove_entry(poll);
+	if (poll_double && poll_double->head)
+		io_poll_remove_entry(poll_double);
+}
+
+/*
+ * All poll tw should go through this. Checks for poll events, manages
+ * references, does rewait, etc.
+ *
+ * Returns a negative error on failure. >0 when no action require, which is
+ * either spurious wakeup or multishot CQE is served. 0 when it's done with
+ * the request, then the mask is stored in req->result.
+ */
+static int io_poll_check_events(struct io_kiocb *req)
 {
 	struct io_ring_ctx *ctx = req->ctx;
-	unsigned flags = IORING_CQE_F_MORE;
-	int error;
+	struct io_poll_iocb *poll = io_poll_get_single(req);
+	int v;
 
-	if (READ_ONCE(req->poll.canceled)) {
-		error = -ECANCELED;
-		req->poll.events |= EPOLLONESHOT;
-	} else {
-		error = mangle_poll(mask);
-	}
-	if (req->poll.events & EPOLLONESHOT)
-		flags = 0;
-	if (!io_cqring_fill_event(ctx, req->user_data, error, flags)) {
-		req->poll.events |= EPOLLONESHOT;
-		flags = 0;
-	}
-	if (flags & IORING_CQE_F_MORE)
-		ctx->cq_extra++;
+	/* req->task == current here, checking PF_EXITING is safe */
+	if (unlikely(req->task->flags & PF_EXITING))
+		io_poll_mark_cancelled(req);
+
+	do {
+		v = atomic_read(&req->poll_refs);
+
+		/* tw handler should be the owner, and so have some references */
+		if (WARN_ON_ONCE(!(v & IO_POLL_REF_MASK)))
+			return 0;
+		if (v & IO_POLL_CANCEL_FLAG)
+			return -ECANCELED;
+
+		if (!req->result) {
+			struct poll_table_struct pt = { ._key = poll->events };
+
+			req->result = vfs_poll(req->file, &pt) & poll->events;
+		}
+
+		/* multishot, just fill an CQE and proceed */
+		if (req->result && !(poll->events & EPOLLONESHOT)) {
+			__poll_t mask = mangle_poll(req->result & poll->events);
+			bool filled;
+
+			spin_lock(&ctx->completion_lock);
+			filled = io_fill_cqe_aux(ctx, req->user_data, mask,
+						 IORING_CQE_F_MORE);
+			io_commit_cqring(ctx);
+			spin_unlock(&ctx->completion_lock);
+			if (unlikely(!filled))
+				return -ECANCELED;
+			io_cqring_ev_posted(ctx);
+		} else if (req->result) {
+			return 0;
+		}
 
-	return !(flags & IORING_CQE_F_MORE);
+		/*
+		 * Release all references, retry if someone tried to restart
+		 * task_work while we were executing it.
+		 */
+	} while (atomic_sub_return(v & IO_POLL_REF_MASK, &req->poll_refs));
+
+	return 1;
 }
 
 static void io_poll_task_func(struct io_kiocb *req, bool *locked)
 {
 	struct io_ring_ctx *ctx = req->ctx;
-	struct io_kiocb *nxt;
+	int ret;
 
-	if (io_poll_rewait(req, &req->poll)) {
-		spin_unlock(&ctx->completion_lock);
+	ret = io_poll_check_events(req);
+	if (ret > 0)
+		return;
+
+	if (!ret) {
+		req->result = mangle_poll(req->result & req->poll.events);
 	} else {
-		bool done;
+		req->result = ret;
+		req_set_fail(req);
+	}
 
-		if (req->poll.done) {
-			spin_unlock(&ctx->completion_lock);
-			return;
-		}
-		done = __io_poll_complete(req, req->result);
-		if (done) {
-			io_poll_remove_double(req);
-			hash_del(&req->hash_node);
-			req->poll.done = true;
-		} else {
-			req->result = 0;
-			add_wait_queue(req->poll.head, &req->poll.wait);
-		}
-		io_commit_cqring(ctx);
-		spin_unlock(&ctx->completion_lock);
-		io_cqring_ev_posted(ctx);
+	io_poll_remove_entries(req);
+	spin_lock(&ctx->completion_lock);
+	hash_del(&req->hash_node);
+	__io_req_complete_post(req, req->result, 0);
+	io_commit_cqring(ctx);
+	spin_unlock(&ctx->completion_lock);
+	io_cqring_ev_posted(ctx);
+}
 
-		if (done) {
-			nxt = io_put_req_find_next(req);
-			if (nxt)
-				io_req_task_submit(nxt, locked);
-		}
-	}
+static void io_apoll_task_func(struct io_kiocb *req, bool *locked)
+{
+	struct io_ring_ctx *ctx = req->ctx;
+	int ret;
+
+	ret = io_poll_check_events(req);
+	if (ret > 0)
+		return;
+
+	io_poll_remove_entries(req);
+	spin_lock(&ctx->completion_lock);
+	hash_del(&req->hash_node);
+	spin_unlock(&ctx->completion_lock);
+
+	if (!ret)
+		io_req_task_submit(req, locked);
+	else
+		io_req_complete_failed(req, ret);
 }
 
-static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
-			       int sync, void *key)
+static void __io_poll_execute(struct io_kiocb *req, int mask)
+{
+	req->result = mask;
+	if (req->opcode == IORING_OP_POLL_ADD)
+		req->io_task_work.func = io_poll_task_func;
+	else
+		req->io_task_work.func = io_apoll_task_func;
+
+	trace_io_uring_task_add(req->ctx, req->opcode, req->user_data, mask);
+	io_req_task_work_add(req, false);
+}
+
+static inline void io_poll_execute(struct io_kiocb *req, int res)
+{
+	if (io_poll_get_ownership(req))
+		__io_poll_execute(req, res);
+}
+
+static void io_poll_cancel_req(struct io_kiocb *req)
+{
+	io_poll_mark_cancelled(req);
+	/* kick tw, which should complete the request */
+	io_poll_execute(req, 0);
+}
+
+static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
+			void *key)
 {
 	struct io_kiocb *req = wait->private;
-	struct io_poll_iocb *poll = io_poll_get_single(req);
+	struct io_poll_iocb *poll = container_of(wait, struct io_poll_iocb,
+						 wait);
 	__poll_t mask = key_to_poll(key);
-	unsigned long flags;
 
-	/* for instances that support it check for an event match first: */
+	/* for instances that support it check for an event match first */
 	if (mask && !(mask & poll->events))
 		return 0;
-	if (!(poll->events & EPOLLONESHOT))
-		return poll->wait.func(&poll->wait, mode, sync, key);
-
-	list_del_init(&wait->entry);
-
-	if (poll->head) {
-		bool done;
 
-		spin_lock_irqsave(&poll->head->lock, flags);
-		done = list_empty(&poll->wait.entry);
-		if (!done)
+	if (io_poll_get_ownership(req)) {
+		/* optional, saves extra locking for removal in tw handler */
+		if (mask && poll->events & EPOLLONESHOT) {
 			list_del_init(&poll->wait.entry);
-		/* make sure double remove sees this as being gone */
-		wait->private = NULL;
-		spin_unlock_irqrestore(&poll->head->lock, flags);
-		if (!done) {
-			/* use wait func handler, so it matches the rq type */
-			poll->wait.func(&poll->wait, mode, sync, key);
+			poll->head = NULL;
 		}
+		__io_poll_execute(req, mask);
 	}
-	req_ref_put(req);
 	return 1;
 }
 
-static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
-			      wait_queue_func_t wake_func)
-{
-	poll->head = NULL;
-	poll->done = false;
-	poll->canceled = false;
-#define IO_POLL_UNMASK	(EPOLLERR|EPOLLHUP|EPOLLNVAL|EPOLLRDHUP)
-	/* mask in events that we always want/need */
-	poll->events = events | IO_POLL_UNMASK;
-	INIT_LIST_HEAD(&poll->wait.entry);
-	init_waitqueue_func_entry(&poll->wait, wake_func);
-}
-
 static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 			    struct wait_queue_head *head,
 			    struct io_poll_iocb **poll_ptr)
@@ -5505,10 +5651,10 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 	 * if this happens.
 	 */
 	if (unlikely(pt->nr_entries)) {
-		struct io_poll_iocb *poll_one = poll;
+		struct io_poll_iocb *first = poll;
 
 		/* double add on the same waitqueue head, ignore */
-		if (poll_one->head == head)
+		if (first->head == head)
 			return;
 		/* already have a 2nd entry, fail a third attempt */
 		if (*poll_ptr) {
@@ -5517,21 +5663,13 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 			pt->error = -EINVAL;
 			return;
 		}
-		/*
-		 * Can't handle multishot for double wait for now, turn it
-		 * into one-shot mode.
-		 */
-		if (!(poll_one->events & EPOLLONESHOT))
-			poll_one->events |= EPOLLONESHOT;
+
 		poll = kmalloc(sizeof(*poll), GFP_ATOMIC);
 		if (!poll) {
 			pt->error = -ENOMEM;
 			return;
 		}
-		io_init_poll_iocb(poll, poll_one->events, io_poll_double_wake);
-		req_ref_get(req);
-		poll->wait.private = req;
-
+		io_init_poll_iocb(poll, first->events, first->wait.func);
 		*poll_ptr = poll;
 		if (req->opcode == IORING_OP_POLL_ADD)
 			req->flags |= REQ_F_ASYNC_DATA;
@@ -5539,6 +5677,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 
 	pt->nr_entries++;
 	poll->head = head;
+	poll->wait.private = req;
 
 	if (poll->events & EPOLLEXCLUSIVE)
 		add_wait_queue_exclusive(head, &poll->wait);
@@ -5546,70 +5685,24 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 		add_wait_queue(head, &poll->wait);
 }
 
-static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
+static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
 			       struct poll_table_struct *p)
 {
 	struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
-	struct async_poll *apoll = pt->req->apoll;
 
-	__io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
+	__io_queue_proc(&pt->req->poll, pt, head,
+			(struct io_poll_iocb **) &pt->req->async_data);
 }
 
-static void io_async_task_func(struct io_kiocb *req, bool *locked)
+static int __io_arm_poll_handler(struct io_kiocb *req,
+				 struct io_poll_iocb *poll,
+				 struct io_poll_table *ipt, __poll_t mask)
 {
-	struct async_poll *apoll = req->apoll;
 	struct io_ring_ctx *ctx = req->ctx;
-
-	trace_io_uring_task_run(req->ctx, req, req->opcode, req->user_data);
-
-	if (io_poll_rewait(req, &apoll->poll)) {
-		spin_unlock(&ctx->completion_lock);
-		return;
-	}
-
-	hash_del(&req->hash_node);
-	io_poll_remove_double(req);
-	apoll->poll.done = true;
-	spin_unlock(&ctx->completion_lock);
-
-	if (!READ_ONCE(apoll->poll.canceled))
-		io_req_task_submit(req, locked);
-	else
-		io_req_complete_failed(req, -ECANCELED);
-}
-
-static int io_async_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
-			void *key)
-{
-	struct io_kiocb *req = wait->private;
-	struct io_poll_iocb *poll = &req->apoll->poll;
-
-	trace_io_uring_poll_wake(req->ctx, req->opcode, req->user_data,
-					key_to_poll(key));
-
-	return __io_async_wake(req, poll, key_to_poll(key), io_async_task_func);
-}
-
-static void io_poll_req_insert(struct io_kiocb *req)
-{
-	struct io_ring_ctx *ctx = req->ctx;
-	struct hlist_head *list;
-
-	list = &ctx->cancel_hash[hash_long(req->user_data, ctx->cancel_hash_bits)];
-	hlist_add_head(&req->hash_node, list);
-}
-
-static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
-				      struct io_poll_iocb *poll,
-				      struct io_poll_table *ipt, __poll_t mask,
-				      wait_queue_func_t wake_func)
-	__acquires(&ctx->completion_lock)
-{
-	struct io_ring_ctx *ctx = req->ctx;
-	bool cancel = false;
+	int v;
 
 	INIT_HLIST_NODE(&req->hash_node);
-	io_init_poll_iocb(poll, mask, wake_func);
+	io_init_poll_iocb(poll, mask, io_poll_wake);
 	poll->file = req->file;
 	poll->wait.private = req;
 
@@ -5618,31 +5711,54 @@ static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
 	ipt->error = 0;
 	ipt->nr_entries = 0;
 
+	/*
+	 * Take the ownership to delay any tw execution up until we're done
+	 * with poll arming. see io_poll_get_ownership().
+	 */
+	atomic_set(&req->poll_refs, 1);
 	mask = vfs_poll(req->file, &ipt->pt) & poll->events;
-	if (unlikely(!ipt->nr_entries) && !ipt->error)
-		ipt->error = -EINVAL;
+
+	if (mask && (poll->events & EPOLLONESHOT)) {
+		io_poll_remove_entries(req);
+		/* no one else has access to the req, forget about the ref */
+		return mask;
+	}
+	if (!mask && unlikely(ipt->error || !ipt->nr_entries)) {
+		io_poll_remove_entries(req);
+		if (!ipt->error)
+			ipt->error = -EINVAL;
+		return 0;
+	}
 
 	spin_lock(&ctx->completion_lock);
-	if (ipt->error || (mask && (poll->events & EPOLLONESHOT)))
-		io_poll_remove_double(req);
-	if (likely(poll->head)) {
-		spin_lock_irq(&poll->head->lock);
-		if (unlikely(list_empty(&poll->wait.entry))) {
-			if (ipt->error)
-				cancel = true;
-			ipt->error = 0;
-			mask = 0;
-		}
-		if ((mask && (poll->events & EPOLLONESHOT)) || ipt->error)
-			list_del_init(&poll->wait.entry);
-		else if (cancel)
-			WRITE_ONCE(poll->canceled, true);
-		else if (!poll->done) /* actually waiting for an event */
-			io_poll_req_insert(req);
-		spin_unlock_irq(&poll->head->lock);
+	io_poll_req_insert(req);
+	spin_unlock(&ctx->completion_lock);
+
+	if (mask) {
+		/* can't multishot if failed, just queue the event we've got */
+		if (unlikely(ipt->error || !ipt->nr_entries))
+			poll->events |= EPOLLONESHOT;
+		__io_poll_execute(req, mask);
+		return 0;
 	}
 
-	return mask;
+	/*
+	 * Release ownership. If someone tried to queue a tw while it was
+	 * locked, kick it off for them.
+	 */
+	v = atomic_dec_return(&req->poll_refs);
+	if (unlikely(v & IO_POLL_REF_MASK))
+		__io_poll_execute(req, 0);
+	return 0;
+}
+
+static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
+			       struct poll_table_struct *p)
+{
+	struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
+	struct async_poll *apoll = pt->req->apoll;
+
+	__io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
 }
 
 enum {
@@ -5657,7 +5773,8 @@ static int io_arm_poll_handler(struct io_kiocb *req)
 	struct io_ring_ctx *ctx = req->ctx;
 	struct async_poll *apoll;
 	struct io_poll_table ipt;
-	__poll_t ret, mask = EPOLLONESHOT | POLLERR | POLLPRI;
+	__poll_t mask = EPOLLONESHOT | POLLERR | POLLPRI;
+	int ret;
 
 	if (!def->pollin && !def->pollout)
 		return IO_APOLL_ABORTED;
@@ -5682,11 +5799,8 @@ static int io_arm_poll_handler(struct io_kiocb *req)
 	req->apoll = apoll;
 	req->flags |= REQ_F_POLLED;
 	ipt.pt._qproc = io_async_queue_proc;
-	io_req_set_refcount(req);
 
-	ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
-					io_async_wake);
-	spin_unlock(&ctx->completion_lock);
+	ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask);
 	if (ret || ipt.error)
 		return ret ? IO_APOLL_READY : IO_APOLL_ABORTED;
 
@@ -5695,43 +5809,6 @@ static int io_arm_poll_handler(struct io_kiocb *req)
 	return IO_APOLL_OK;
 }
 
-static bool __io_poll_remove_one(struct io_kiocb *req,
-				 struct io_poll_iocb *poll, bool do_cancel)
-	__must_hold(&req->ctx->completion_lock)
-{
-	bool do_complete = false;
-
-	if (!poll->head)
-		return false;
-	spin_lock_irq(&poll->head->lock);
-	if (do_cancel)
-		WRITE_ONCE(poll->canceled, true);
-	if (!list_empty(&poll->wait.entry)) {
-		list_del_init(&poll->wait.entry);
-		do_complete = true;
-	}
-	spin_unlock_irq(&poll->head->lock);
-	hash_del(&req->hash_node);
-	return do_complete;
-}
-
-static bool io_poll_remove_one(struct io_kiocb *req)
-	__must_hold(&req->ctx->completion_lock)
-{
-	bool do_complete;
-
-	io_poll_remove_double(req);
-	do_complete = __io_poll_remove_one(req, io_poll_get_single(req), true);
-
-	if (do_complete) {
-		io_cqring_fill_event(req->ctx, req->user_data, -ECANCELED, 0);
-		io_commit_cqring(req->ctx);
-		req_set_fail(req);
-		io_put_req_deferred(req);
-	}
-	return do_complete;
-}
-
 /*
  * Returns true if we found and killed one or more poll requests
  */
@@ -5740,7 +5817,8 @@ static __cold bool io_poll_remove_all(struct io_ring_ctx *ctx,
 {
 	struct hlist_node *tmp;
 	struct io_kiocb *req;
-	int posted = 0, i;
+	bool found = false;
+	int i;
 
 	spin_lock(&ctx->completion_lock);
 	for (i = 0; i < (1U << ctx->cancel_hash_bits); i++) {
@@ -5748,16 +5826,14 @@ static __cold bool io_poll_remove_all(struct io_ring_ctx *ctx,
 
 		list = &ctx->cancel_hash[i];
 		hlist_for_each_entry_safe(req, tmp, list, hash_node) {
-			if (io_match_task_safe(req, tsk, cancel_all))
-				posted += io_poll_remove_one(req);
+			if (io_match_task_safe(req, tsk, cancel_all)) {
+				io_poll_cancel_req(req);
+				found = true;
+			}
 		}
 	}
 	spin_unlock(&ctx->completion_lock);
-
-	if (posted)
-		io_cqring_ev_posted(ctx);
-
-	return posted != 0;
+	return found;
 }
 
 static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, __u64 sqe_addr,
@@ -5778,19 +5854,26 @@ static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, __u64 sqe_addr,
 	return NULL;
 }
 
+static bool io_poll_disarm(struct io_kiocb *req)
+	__must_hold(&ctx->completion_lock)
+{
+	if (!io_poll_get_ownership(req))
+		return false;
+	io_poll_remove_entries(req);
+	hash_del(&req->hash_node);
+	return true;
+}
+
 static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr,
 			  bool poll_only)
 	__must_hold(&ctx->completion_lock)
 {
-	struct io_kiocb *req;
+	struct io_kiocb *req = io_poll_find(ctx, sqe_addr, poll_only);
 
-	req = io_poll_find(ctx, sqe_addr, poll_only);
 	if (!req)
 		return -ENOENT;
-	if (io_poll_remove_one(req))
-		return 0;
-
-	return -EALREADY;
+	io_poll_cancel_req(req);
+	return 0;
 }
 
 static __poll_t io_poll_parse_events(const struct io_uring_sqe *sqe,
@@ -5840,23 +5923,6 @@ static int io_poll_update_prep(struct io_kiocb *req,
 	return 0;
 }
 
-static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
-			void *key)
-{
-	struct io_kiocb *req = wait->private;
-	struct io_poll_iocb *poll = &req->poll;
-
-	return __io_async_wake(req, poll, key_to_poll(key), io_poll_task_func);
-}
-
-static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
-			       struct poll_table_struct *p)
-{
-	struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
-
-	__io_queue_proc(&pt->req->poll, pt, head, (struct io_poll_iocb **) &pt->req->async_data);
-}
-
 static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
 	struct io_poll_iocb *poll = &req->poll;
@@ -5869,6 +5935,8 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
 	flags = READ_ONCE(sqe->len);
 	if (flags & ~IORING_POLL_ADD_MULTI)
 		return -EINVAL;
+	if ((flags & IORING_POLL_ADD_MULTI) && (req->flags & REQ_F_CQE_SKIP))
+		return -EINVAL;
 
 	io_req_set_refcount(req);
 	poll->events = io_poll_parse_events(sqe, flags);
@@ -5878,100 +5946,60 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
 static int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
 {
 	struct io_poll_iocb *poll = &req->poll;
-	struct io_ring_ctx *ctx = req->ctx;
 	struct io_poll_table ipt;
-	__poll_t mask;
-	bool done;
+	int ret;
 
 	ipt.pt._qproc = io_poll_queue_proc;
 
-	mask = __io_arm_poll_handler(req, &req->poll, &ipt, poll->events,
-					io_poll_wake);
-
-	if (mask) { /* no async, we'd stolen it */
-		ipt.error = 0;
-		done = __io_poll_complete(req, mask);
-		io_commit_cqring(req->ctx);
-	}
-	spin_unlock(&ctx->completion_lock);
-
-	if (mask) {
-		io_cqring_ev_posted(ctx);
-		if (done)
-			io_put_req(req);
-	}
-	return ipt.error;
+	ret = __io_arm_poll_handler(req, &req->poll, &ipt, poll->events);
+	ret = ret ?: ipt.error;
+	if (ret)
+		__io_req_complete(req, issue_flags, ret, 0);
+	return 0;
 }
 
 static int io_poll_update(struct io_kiocb *req, unsigned int issue_flags)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 	struct io_kiocb *preq;
-	bool completing;
-	int ret;
+	int ret2, ret = 0;
+	bool locked;
 
 	spin_lock(&ctx->completion_lock);
 	preq = io_poll_find(ctx, req->poll_update.old_user_data, true);
-	if (!preq) {
-		ret = -ENOENT;
-		goto err;
-	}
-
-	if (!req->poll_update.update_events && !req->poll_update.update_user_data) {
-		completing = true;
-		ret = io_poll_remove_one(preq) ? 0 : -EALREADY;
-		goto err;
-	}
-
-	/*
-	 * Don't allow racy completion with singleshot, as we cannot safely
-	 * update those. For multishot, if we're racing with completion, just
-	 * let completion re-add it.
-	 */
-	completing = !__io_poll_remove_one(preq, &preq->poll, false);
-	if (completing && (preq->poll.events & EPOLLONESHOT)) {
-		ret = -EALREADY;
-		goto err;
-	}
-	/* we now have a detached poll request. reissue. */
-	ret = 0;
-err:
-	if (ret < 0) {
+	if (!preq || !io_poll_disarm(preq)) {
 		spin_unlock(&ctx->completion_lock);
-		req_set_fail(req);
-		io_req_complete(req, ret);
-		return 0;
-	}
-	/* only mask one event flags, keep behavior flags */
-	if (req->poll_update.update_events) {
-		preq->poll.events &= ~0xffff;
-		preq->poll.events |= req->poll_update.events & 0xffff;
-		preq->poll.events |= IO_POLL_UNMASK;
+		ret = preq ? -EALREADY : -ENOENT;
+		goto out;
 	}
-	if (req->poll_update.update_user_data)
-		preq->user_data = req->poll_update.new_user_data;
 	spin_unlock(&ctx->completion_lock);
 
-	/* complete update request, we're done with it */
-	io_req_complete(req, ret);
-
-	if (!completing) {
-		ret = io_poll_add(preq, issue_flags);
-		if (ret < 0) {
-			req_set_fail(preq);
-			io_req_complete(preq, ret);
+	if (req->poll_update.update_events || req->poll_update.update_user_data) {
+		/* only mask one event flags, keep behavior flags */
+		if (req->poll_update.update_events) {
+			preq->poll.events &= ~0xffff;
+			preq->poll.events |= req->poll_update.events & 0xffff;
+			preq->poll.events |= IO_POLL_UNMASK;
 		}
-	}
-	return 0;
-}
+		if (req->poll_update.update_user_data)
+			preq->user_data = req->poll_update.new_user_data;
 
-static void io_req_task_timeout(struct io_kiocb *req, bool *locked)
-{
-	struct io_timeout_data *data = req->async_data;
+		ret2 = io_poll_add(preq, issue_flags);
+		/* successfully updated, don't complete poll request */
+		if (!ret2)
+			goto out;
+	}
 
-	if (!(data->flags & IORING_TIMEOUT_ETIME_SUCCESS))
+	req_set_fail(preq);
+	preq->result = -ECANCELED;
+	locked = !(issue_flags & IO_URING_F_UNLOCKED);
+	io_req_task_complete(preq, &locked);
+out:
+	if (ret < 0)
 		req_set_fail(req);
-	io_req_complete_post(req, -ETIME, 0);
+	/* complete update request, we're done with it */
+	__io_req_complete(req, issue_flags, ret, 0);
+	return 0;
 }
 
 static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
@@ -5988,8 +6016,12 @@ static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
 		atomic_read(&req->ctx->cq_timeouts) + 1);
 	spin_unlock_irqrestore(&ctx->timeout_lock, flags);
 
-	req->io_task_work.func = io_req_task_timeout;
-	io_req_task_work_add(req);
+	if (!(data->flags & IORING_TIMEOUT_ETIME_SUCCESS))
+		req_set_fail(req);
+
+	req->result = -ETIME;
+	req->io_task_work.func = io_req_task_complete;
+	io_req_task_work_add(req, false);
 	return HRTIMER_NORESTART;
 }
 
@@ -6026,7 +6058,7 @@ static int io_timeout_cancel(struct io_ring_ctx *ctx, __u64 user_data)
 		return PTR_ERR(req);
 
 	req_set_fail(req);
-	io_cqring_fill_event(ctx, req->user_data, -ECANCELED, 0);
+	io_fill_cqe_req(req, -ECANCELED, 0);
 	io_put_req_deferred(req);
 	return 0;
 }
@@ -6115,6 +6147,8 @@ static int io_timeout_remove_prep(struct io_kiocb *req,
 			return -EINVAL;
 		if (get_timespec64(&tr->ts, u64_to_user_ptr(sqe->addr2)))
 			return -EFAULT;
+		if (tr->ts.tv_sec < 0 || tr->ts.tv_nsec < 0)
+			return -EINVAL;
 	} else if (tr->flags) {
 		/* timeout removal doesn't support flags */
 		return -EINVAL;
@@ -6544,12 +6578,15 @@ static __cold void io_drain_req(struct io_kiocb *req)
 	u32 seq = io_get_sequence(req);
 
 	/* Still need defer if there is pending req in defer list. */
+	spin_lock(&ctx->completion_lock);
 	if (!req_need_defer(req, seq) && list_empty_careful(&ctx->defer_list)) {
+		spin_unlock(&ctx->completion_lock);
 queue:
 		ctx->drain_active = false;
 		io_req_task_queue(req);
 		return;
 	}
+	spin_unlock(&ctx->completion_lock);
 
 	ret = io_req_prep_async(req);
 	if (ret) {
@@ -6580,10 +6617,8 @@ fail:
 
 static void io_clean_op(struct io_kiocb *req)
 {
-	if (req->flags & REQ_F_BUFFER_SELECTED) {
-		kfree(req->kbuf);
-		req->kbuf = NULL;
-	}
+	if (req->flags & REQ_F_BUFFER_SELECTED)
+		io_put_kbuf(req);
 
 	if (req->flags & REQ_F_NEED_CLEANUP) {
 		switch (req->opcode) {
@@ -6965,7 +7000,7 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
 	spin_unlock_irqrestore(&ctx->timeout_lock, flags);
 
 	req->io_task_work.func = io_req_task_link_timeout;
-	io_req_task_work_add(req);
+	io_req_task_work_add(req, false);
 	return HRTIMER_NORESTART;
 }
 
@@ -7100,10 +7135,10 @@ static void io_init_req_drain(struct io_kiocb *req)
 		 * If we need to drain a request in the middle of a link, drain
 		 * the head request and the next request/link after the current
 		 * link. Considering sequential execution of links,
-		 * IOSQE_IO_DRAIN will be maintained for every request of our
+		 * REQ_F_IO_DRAIN will be maintained for every request of our
 		 * link.
 		 */
-		head->flags |= IOSQE_IO_DRAIN | REQ_F_FORCE_ASYNC;
+		head->flags |= REQ_F_IO_DRAIN | REQ_F_FORCE_ASYNC;
 		ctx->drain_next = true;
 	}
 }
@@ -7136,8 +7171,13 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 		if ((sqe_flags & IOSQE_BUFFER_SELECT) &&
 		    !io_op_defs[opcode].buffer_select)
 			return -EOPNOTSUPP;
-		if (sqe_flags & IOSQE_IO_DRAIN)
+		if (sqe_flags & IOSQE_CQE_SKIP_SUCCESS)
+			ctx->drain_disabled = true;
+		if (sqe_flags & IOSQE_IO_DRAIN) {
+			if (ctx->drain_disabled)
+				return -EOPNOTSUPP;
 			io_init_req_drain(req);
+		}
 	}
 	if (unlikely(ctx->restricted || ctx->drain_active || ctx->drain_next)) {
 		if (ctx->restricted && !io_check_restriction(ctx, req, sqe_flags))
@@ -7149,7 +7189,7 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 		if (unlikely(ctx->drain_next) && !ctx->submit_state.link.head) {
 			ctx->drain_next = false;
 			ctx->drain_active = true;
-			req->flags |= IOSQE_IO_DRAIN | REQ_F_FORCE_ASYNC;
+			req->flags |= REQ_F_IO_DRAIN | REQ_F_FORCE_ASYNC;
 		}
 	}
 
@@ -8263,8 +8303,7 @@ static void __io_rsrc_put_work(struct io_rsrc_node *ref_node)
 
 			io_ring_submit_lock(ctx, lock_ring);
 			spin_lock(&ctx->completion_lock);
-			io_cqring_fill_event(ctx, prsrc->tag, 0, 0);
-			ctx->cq_extra++;
+			io_fill_cqe_aux(ctx, prsrc->tag, 0, 0);
 			io_commit_cqring(ctx);
 			spin_unlock(&ctx->completion_lock);
 			io_cqring_ev_posted(ctx);
@@ -8676,6 +8715,7 @@ static __cold int io_uring_alloc_task_context(struct task_struct *task,
 	task->io_uring = tctx;
 	spin_lock_init(&tctx->task_lock);
 	INIT_WQ_LIST(&tctx->task_list);
+	INIT_WQ_LIST(&tctx->prior_task_list);
 	init_task_work(&tctx->task_work, tctx_task_work);
 	return 0;
 }
@@ -9814,18 +9854,6 @@ static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
 	return percpu_counter_sum(&tctx->inflight);
 }
 
-static __cold void io_uring_drop_tctx_refs(struct task_struct *task)
-{
-	struct io_uring_task *tctx = task->io_uring;
-	unsigned int refs = tctx->cached_refs;
-
-	if (refs) {
-		tctx->cached_refs = 0;
-		percpu_counter_sub(&tctx->inflight, refs);
-		put_task_struct_many(task, refs);
-	}
-}
-
 /*
  * Find any io_uring ctx that this task has registered or done IO on, and cancel
  * requests. @sqd should be not-null IFF it's an SQPOLL thread cancellation.
@@ -9883,10 +9911,14 @@ static __cold void io_uring_cancel_generic(bool cancel_all,
 			schedule();
 		finish_wait(&tctx->wait, &wait);
 	} while (1);
-	atomic_dec(&tctx->in_idle);
 
 	io_uring_clean_tctx(tctx);
 	if (cancel_all) {
+		/*
+		 * We shouldn't run task_works after cancel, so just leave
+		 * ->in_idle set for normal exit.
+		 */
+		atomic_dec(&tctx->in_idle);
 		/* for exec all current's requests should be gone, kill tctx */
 		__io_uring_free(current);
 	}
@@ -10164,7 +10196,7 @@ static __cold void __io_uring_show_fdinfo(struct io_ring_ctx *ctx,
 	 * and sq_tail and cq_head are changed by userspace. But it's ok since
 	 * we usually use these info when it is stuck.
 	 */
-	seq_printf(m, "SqMask:\t\t0x%x\n", sq_mask);
+	seq_printf(m, "SqMask:\t0x%x\n", sq_mask);
 	seq_printf(m, "SqHead:\t%u\n", sq_head);
 	seq_printf(m, "SqTail:\t%u\n", sq_tail);
 	seq_printf(m, "CachedSqHead:\t%u\n", ctx->cached_sq_head);
@@ -10473,7 +10505,7 @@ static __cold int io_uring_create(unsigned entries, struct io_uring_params *p,
 			IORING_FEAT_CUR_PERSONALITY | IORING_FEAT_FAST_POLL |
 			IORING_FEAT_POLL_32BITS | IORING_FEAT_SQPOLL_NONFIXED |
 			IORING_FEAT_EXT_ARG | IORING_FEAT_NATIVE_WORKERS |
-			IORING_FEAT_RSRC_TAGS;
+			IORING_FEAT_RSRC_TAGS | IORING_FEAT_CQE_SKIP;
 
 	if (copy_to_user(params, p, sizeof(*p))) {
 		ret = -EFAULT;