diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index c6448ff37768..b60955682474 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -273,7 +273,7 @@ static void __vhost_worker_flush(struct vhost_worker *worker) { struct vhost_flush_struct flush; - if (!worker->attachment_cnt) + if (!worker->attachment_cnt || worker->killed) return; init_completion(&flush.wait_event); @@ -388,7 +388,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, __vhost_vq_meta_reset(vq); } -static bool vhost_worker(void *data) +static bool vhost_run_work_list(void *data) { struct vhost_worker *worker = data; struct vhost_work *work, *work_next; @@ -413,6 +413,40 @@ static bool vhost_worker(void *data) return !!node; } +static void vhost_worker_killed(void *data) +{ + struct vhost_worker *worker = data; + struct vhost_dev *dev = worker->dev; + struct vhost_virtqueue *vq; + int i, attach_cnt = 0; + + mutex_lock(&worker->mutex); + worker->killed = true; + + for (i = 0; i < dev->nvqs; i++) { + vq = dev->vqs[i]; + + mutex_lock(&vq->mutex); + if (worker == + rcu_dereference_check(vq->worker, + lockdep_is_held(&vq->mutex))) { + rcu_assign_pointer(vq->worker, NULL); + attach_cnt++; + } + mutex_unlock(&vq->mutex); + } + + worker->attachment_cnt -= attach_cnt; + if (attach_cnt) + synchronize_rcu(); + /* + * Finish vhost_worker_flush calls and any other works that snuck in + * before the synchronize_rcu. + */ + vhost_run_work_list(worker); + mutex_unlock(&worker->mutex); +} + static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) { kfree(vq->indirect); @@ -627,9 +661,11 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev) if (!worker) return NULL; + worker->dev = dev; snprintf(name, sizeof(name), "vhost-%d", current->pid); - vtsk = vhost_task_create(vhost_worker, worker, name); + vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed, + worker, name); if (!vtsk) goto free_worker; @@ -661,6 +697,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq, struct vhost_worker *old_worker; mutex_lock(&worker->mutex); + if (worker->killed) { + mutex_unlock(&worker->mutex); + return; + } + mutex_lock(&vq->mutex); old_worker = rcu_dereference_check(vq->worker, @@ -681,6 +722,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq, * device wide flushes which doesn't use RCU for execution. */ mutex_lock(&old_worker->mutex); + if (old_worker->killed) { + mutex_unlock(&old_worker->mutex); + return; + } + /* * We don't want to call synchronize_rcu for every vq during setup * because it will slow down VM startup. If we haven't done @@ -758,7 +804,7 @@ static int vhost_free_worker(struct vhost_dev *dev, return -ENODEV; mutex_lock(&worker->mutex); - if (worker->attachment_cnt) { + if (worker->attachment_cnt || worker->killed) { mutex_unlock(&worker->mutex); return -EBUSY; } diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 91ade037f08e..bb75a292d50c 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -28,12 +28,14 @@ struct vhost_work { struct vhost_worker { struct vhost_task *vtsk; + struct vhost_dev *dev; /* Used to serialize device wide flushing with worker swapping. */ struct mutex mutex; struct llist_head work_list; u64 kcov_handle; u32 id; int attachment_cnt; + bool killed; }; /* Poll a file (eventfd or socket) */ diff --git a/include/linux/sched/vhost_task.h b/include/linux/sched/vhost_task.h index bc60243d43b3..25446c5d3508 100644 --- a/include/linux/sched/vhost_task.h +++ b/include/linux/sched/vhost_task.h @@ -4,7 +4,8 @@ struct vhost_task; -struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg, +struct vhost_task *vhost_task_create(bool (*fn)(void *), + void (*handle_kill)(void *), void *arg, const char *name); void vhost_task_start(struct vhost_task *vtsk); void vhost_task_stop(struct vhost_task *vtsk); diff --git a/kernel/vhost_task.c b/kernel/vhost_task.c index da35e5b7f047..8800f5acc007 100644 --- a/kernel/vhost_task.c +++ b/kernel/vhost_task.c @@ -10,38 +10,32 @@ enum vhost_task_flags { VHOST_TASK_FLAGS_STOP, + VHOST_TASK_FLAGS_KILLED, }; struct vhost_task { bool (*fn)(void *data); + void (*handle_sigkill)(void *data); void *data; struct completion exited; unsigned long flags; struct task_struct *task; + /* serialize SIGKILL and vhost_task_stop calls */ + struct mutex exit_mutex; }; static int vhost_task_fn(void *data) { struct vhost_task *vtsk = data; - bool dead = false; for (;;) { bool did_work; - if (!dead && signal_pending(current)) { + if (signal_pending(current)) { struct ksignal ksig; - /* - * Calling get_signal will block in SIGSTOP, - * or clear fatal_signal_pending, but remember - * what was set. - * - * This thread won't actually exit until all - * of the file descriptors are closed, and - * the release function is called. - */ - dead = get_signal(&ksig); - if (dead) - clear_thread_flag(TIF_SIGPENDING); + + if (get_signal(&ksig)) + break; } /* mb paired w/ vhost_task_stop */ @@ -57,7 +51,19 @@ static int vhost_task_fn(void *data) schedule(); } + mutex_lock(&vtsk->exit_mutex); + /* + * If a vhost_task_stop and SIGKILL race, we can ignore the SIGKILL. + * When the vhost layer has called vhost_task_stop it's already stopped + * new work and flushed. + */ + if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) { + set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags); + vtsk->handle_sigkill(vtsk->data); + } + mutex_unlock(&vtsk->exit_mutex); complete(&vtsk->exited); + do_exit(0); } @@ -78,12 +84,17 @@ EXPORT_SYMBOL_GPL(vhost_task_wake); * @vtsk: vhost_task to stop * * vhost_task_fn ensures the worker thread exits after - * VHOST_TASK_FLAGS_SOP becomes true. + * VHOST_TASK_FLAGS_STOP becomes true. */ void vhost_task_stop(struct vhost_task *vtsk) { - set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags); - vhost_task_wake(vtsk); + mutex_lock(&vtsk->exit_mutex); + if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) { + set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags); + vhost_task_wake(vtsk); + } + mutex_unlock(&vtsk->exit_mutex); + /* * Make sure vhost_task_fn is no longer accessing the vhost_task before * freeing it below. @@ -96,14 +107,16 @@ EXPORT_SYMBOL_GPL(vhost_task_stop); /** * vhost_task_create - create a copy of a task to be used by the kernel * @fn: vhost worker function - * @arg: data to be passed to fn + * @handle_sigkill: vhost function to handle when we are killed + * @arg: data to be passed to fn and handled_kill * @name: the thread's name * * This returns a specialized task for use by the vhost layer or NULL on * failure. The returned task is inactive, and the caller must fire it up * through vhost_task_start(). */ -struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg, +struct vhost_task *vhost_task_create(bool (*fn)(void *), + void (*handle_sigkill)(void *), void *arg, const char *name) { struct kernel_clone_args args = { @@ -122,8 +135,10 @@ struct vhost_task *vhost_task_create(bool (*fn)(void *), void *arg, if (!vtsk) return NULL; init_completion(&vtsk->exited); + mutex_init(&vtsk->exit_mutex); vtsk->data = arg; vtsk->fn = fn; + vtsk->handle_sigkill = handle_sigkill; args.fn_arg = vtsk;