summary refs log tree commit diff
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c98
1 files changed, 68 insertions, 30 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 596132a96cd5..062595ee1f83 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -166,11 +166,16 @@ static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
 			     void *key)
 {
 	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
+	struct vhost_work *work = &poll->work;
 
 	if (!(key_to_poll(key) & poll->mask))
 		return 0;
 
-	vhost_poll_queue(poll);
+	if (!poll->dev->use_worker)
+		work->fn(work);
+	else
+		vhost_poll_queue(poll);
+
 	return 0;
 }
 
@@ -454,6 +459,7 @@ static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
 void vhost_dev_init(struct vhost_dev *dev,
 		    struct vhost_virtqueue **vqs, int nvqs,
 		    int iov_limit, int weight, int byte_weight,
+		    bool use_worker,
 		    int (*msg_handler)(struct vhost_dev *dev,
 				       struct vhost_iotlb_msg *msg))
 {
@@ -471,6 +477,7 @@ void vhost_dev_init(struct vhost_dev *dev,
 	dev->iov_limit = iov_limit;
 	dev->weight = weight;
 	dev->byte_weight = byte_weight;
+	dev->use_worker = use_worker;
 	dev->msg_handler = msg_handler;
 	init_llist_head(&dev->work_list);
 	init_waitqueue_head(&dev->wait);
@@ -534,6 +541,36 @@ bool vhost_dev_has_owner(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
 
+static void vhost_attach_mm(struct vhost_dev *dev)
+{
+	/* No owner, become one */
+	if (dev->use_worker) {
+		dev->mm = get_task_mm(current);
+	} else {
+		/* vDPA device does not use worker thead, so there's
+		 * no need to hold the address space for mm. This help
+		 * to avoid deadlock in the case of mmap() which may
+		 * held the refcnt of the file and depends on release
+		 * method to remove vma.
+		 */
+		dev->mm = current->mm;
+		mmgrab(dev->mm);
+	}
+}
+
+static void vhost_detach_mm(struct vhost_dev *dev)
+{
+	if (!dev->mm)
+		return;
+
+	if (dev->use_worker)
+		mmput(dev->mm);
+	else
+		mmdrop(dev->mm);
+
+	dev->mm = NULL;
+}
+
 /* Caller should have device mutex */
 long vhost_dev_set_owner(struct vhost_dev *dev)
 {
@@ -546,21 +583,24 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 		goto err_mm;
 	}
 
-	/* No owner, become one */
-	dev->mm = get_task_mm(current);
+	vhost_attach_mm(dev);
+
 	dev->kcov_handle = kcov_common_handle();
-	worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
-	if (IS_ERR(worker)) {
-		err = PTR_ERR(worker);
-		goto err_worker;
-	}
+	if (dev->use_worker) {
+		worker = kthread_create(vhost_worker, dev,
+					"vhost-%d", current->pid);
+		if (IS_ERR(worker)) {
+			err = PTR_ERR(worker);
+			goto err_worker;
+		}
 
-	dev->worker = worker;
-	wake_up_process(worker);	/* avoid contributing to loadavg */
+		dev->worker = worker;
+		wake_up_process(worker); /* avoid contributing to loadavg */
 
-	err = vhost_attach_cgroups(dev);
-	if (err)
-		goto err_cgroup;
+		err = vhost_attach_cgroups(dev);
+		if (err)
+			goto err_cgroup;
+	}
 
 	err = vhost_dev_alloc_iovecs(dev);
 	if (err)
@@ -568,12 +608,12 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 
 	return 0;
 err_cgroup:
-	kthread_stop(worker);
-	dev->worker = NULL;
+	if (dev->worker) {
+		kthread_stop(dev->worker);
+		dev->worker = NULL;
+	}
 err_worker:
-	if (dev->mm)
-		mmput(dev->mm);
-	dev->mm = NULL;
+	vhost_detach_mm(dev);
 	dev->kcov_handle = 0;
 err_mm:
 	return err;
@@ -670,9 +710,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 		dev->worker = NULL;
 		dev->kcov_handle = 0;
 	}
-	if (dev->mm)
-		mmput(dev->mm);
-	dev->mm = NULL;
+	vhost_detach_mm(dev);
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
 
@@ -882,7 +920,7 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
 
 #define vhost_put_user(vq, x, ptr)		\
 ({ \
-	int ret = -EFAULT; \
+	int ret; \
 	if (!vq->iotlb) { \
 		ret = __put_user(x, ptr); \
 	} else { \
@@ -1244,9 +1282,9 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
 }
 
 static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
-			 struct vring_desc __user *desc,
-			 struct vring_avail __user *avail,
-			 struct vring_used __user *used)
+			 vring_desc_t __user *desc,
+			 vring_avail_t __user *avail,
+			 vring_used_t __user *used)
 
 {
 	return access_ok(desc, vhost_get_desc_size(vq, num)) &&
@@ -1574,7 +1612,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 			r = -EFAULT;
 			break;
 		}
-		eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
+		eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd);
 		if (IS_ERR(eventfp)) {
 			r = PTR_ERR(eventfp);
 			break;
@@ -1590,7 +1628,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 			r = -EFAULT;
 			break;
 		}
-		ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
+		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
 		if (IS_ERR(ctx)) {
 			r = PTR_ERR(ctx);
 			break;
@@ -1602,7 +1640,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
 			r = -EFAULT;
 			break;
 		}
-		ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
+		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
 		if (IS_ERR(ctx)) {
 			r = PTR_ERR(ctx);
 			break;
@@ -1727,7 +1765,7 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 		r = get_user(fd, (int __user *)argp);
 		if (r < 0)
 			break;
-		ctx = fd == -1 ? NULL : eventfd_ctx_fdget(fd);
+		ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
 		if (IS_ERR(ctx)) {
 			r = PTR_ERR(ctx);
 			break;
@@ -2300,7 +2338,7 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
 			    struct vring_used_elem *heads,
 			    unsigned count)
 {
-	struct vring_used_elem __user *used;
+	vring_used_elem_t __user *used;
 	u16 old, new;
 	int start;