diff --git a/fs/io_uring.c b/fs/io_uring.c
index a7429c977eb3..b398394a919e 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -1668,7 +1668,8 @@ static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
 		WRITE_ONCE(cqe->user_data, req->user_data);
 		WRITE_ONCE(cqe->res, res);
 		WRITE_ONCE(cqe->flags, cflags);
-	} else if (ctx->cq_overflow_flushed || req->task->io_uring->in_idle) {
+	} else if (ctx->cq_overflow_flushed ||
+		   atomic_read(&req->task->io_uring->in_idle)) {
 		/*
 		 * If we're in ring overflow flush mode, or in task cancel mode,
 		 * then we cannot store the request for later flushing, we need
@@ -1838,7 +1839,7 @@ static void __io_free_req(struct io_kiocb *req)
 	io_dismantle_req(req);
 
 	percpu_counter_dec(&tctx->inflight);
-	if (tctx->in_idle)
+	if (atomic_read(&tctx->in_idle))
 		wake_up(&tctx->wait);
 	put_task_struct(req->task);
 
@@ -7695,7 +7696,8 @@ static int io_uring_alloc_task_context(struct task_struct *task)
 	xa_init(&tctx->xa);
 	init_waitqueue_head(&tctx->wait);
 	tctx->last = NULL;
-	tctx->in_idle = 0;
+	atomic_set(&tctx->in_idle, 0);
+	tctx->sqpoll = false;
 	io_init_identity(&tctx->__identity);
 	tctx->identity = &tctx->__identity;
 	task->io_uring = tctx;
@@ -8598,8 +8600,11 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 {
 	struct task_struct *task = current;
 
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data)
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
 		task = ctx->sq_data->thread;
+		atomic_inc(&task->io_uring->in_idle);
+		io_sq_thread_park(ctx->sq_data);
+	}
 
 	io_cqring_overflow_flush(ctx, true, task, files);
 
@@ -8607,12 +8612,23 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 		io_run_task_work();
 		cond_resched();
 	}
+
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
+		atomic_dec(&task->io_uring->in_idle);
+		/*
+		 * If the files that are going away are the ones in the thread
+		 * identity, clear them out.
+		 */
+		if (task->io_uring->identity->files == files)
+			task->io_uring->identity->files = NULL;
+		io_sq_thread_unpark(ctx->sq_data);
+	}
 }
 
 /*
  * Note that this task has used io_uring. We use it for cancelation purposes.
  */
-static int io_uring_add_task_file(struct file *file)
+static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
 {
 	struct io_uring_task *tctx = current->io_uring;
 
@@ -8634,6 +8650,14 @@ static int io_uring_add_task_file(struct file *file)
 		tctx->last = file;
 	}
 
+	/*
+	 * This is race safe in that the task itself is doing this, hence it
+	 * cannot be going through the exit/cancel paths at the same time.
+	 * This cannot be modified while exit/cancel is running.
+	 */
+	if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
+		tctx->sqpoll = true;
+
 	return 0;
 }
 
@@ -8675,7 +8699,7 @@ void __io_uring_files_cancel(struct files_struct *files)
 	unsigned long index;
 
 	/* make sure overflow events are dropped */
-	tctx->in_idle = true;
+	atomic_inc(&tctx->in_idle);
 
 	xa_for_each(&tctx->xa, index, file) {
 		struct io_ring_ctx *ctx = file->private_data;
@@ -8684,6 +8708,35 @@ void __io_uring_files_cancel(struct files_struct *files)
 		if (files)
 			io_uring_del_task_file(file);
 	}
+
+	atomic_dec(&tctx->in_idle);
+}
+
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+	unsigned long index;
+	struct file *file;
+	s64 inflight;
+
+	inflight = percpu_counter_sum(&tctx->inflight);
+	if (!tctx->sqpoll)
+		return inflight;
+
+	/*
+	 * If we have SQPOLL rings, then we need to iterate and find them, and
+	 * add the pending count for those.
+	 */
+	xa_for_each(&tctx->xa, index, file) {
+		struct io_ring_ctx *ctx = file->private_data;
+
+		if (ctx->flags & IORING_SETUP_SQPOLL) {
+			struct io_uring_task *__tctx = ctx->sqo_task->io_uring;
+
+			inflight += percpu_counter_sum(&__tctx->inflight);
+		}
+	}
+
+	return inflight;
 }
 
 /*
@@ -8697,11 +8750,11 @@ void __io_uring_task_cancel(void)
 	s64 inflight;
 
 	/* make sure overflow events are dropped */
-	tctx->in_idle = true;
+	atomic_inc(&tctx->in_idle);
 
 	do {
 		/* read completions before cancelations */
-		inflight = percpu_counter_sum(&tctx->inflight);
+		inflight = tctx_inflight(tctx);
 		if (!inflight)
 			break;
 		__io_uring_files_cancel(NULL);
@@ -8712,13 +8765,13 @@ void __io_uring_task_cancel(void)
 		 * If we've seen completions, retry. This avoids a race where
 		 * a completion comes in before we did prepare_to_wait().
 		 */
-		if (inflight != percpu_counter_sum(&tctx->inflight))
+		if (inflight != tctx_inflight(tctx))
 			continue;
 		schedule();
 	} while (1);
 
 	finish_wait(&tctx->wait, &wait);
-	tctx->in_idle = false;
+	atomic_dec(&tctx->in_idle);
 }
 
 static int io_uring_flush(struct file *file, void *data)
@@ -8863,7 +8916,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 			io_sqpoll_wait_sq(ctx);
 		submitted = to_submit;
 	} else if (to_submit) {
-		ret = io_uring_add_task_file(f.file);
+		ret = io_uring_add_task_file(ctx, f.file);
 		if (unlikely(ret))
 			goto out;
 		mutex_lock(&ctx->uring_lock);
@@ -9092,7 +9145,7 @@ err_fd:
 #if defined(CONFIG_UNIX)
 	ctx->ring_sock->file = file;
 #endif
-	if (unlikely(io_uring_add_task_file(file))) {
+	if (unlikely(io_uring_add_task_file(ctx, file))) {
 		file = ERR_PTR(-ENOMEM);
 		goto err_fd;
 	}
diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
index 868364cea3b7..35b2d845704d 100644
--- a/include/linux/io_uring.h
+++ b/include/linux/io_uring.h
@@ -30,7 +30,8 @@ struct io_uring_task {
 	struct percpu_counter	inflight;
 	struct io_identity	__identity;
 	struct io_identity	*identity;
-	bool			in_idle;
+	atomic_t		in_idle;
+	bool			sqpoll;
 };
 
 #if defined(CONFIG_IO_URING)