summary refs log tree commit diff
path: root/drivers/vfio
diff options
context:
space:
mode:
authorLinus Torvalds <torvalds@linux-foundation.org>2021-02-24 10:43:40 -0800
committerLinus Torvalds <torvalds@linux-foundation.org>2021-02-24 10:43:40 -0800
commit719bbd4a509f403f537adcaefd8ce17532be2e84 (patch)
tree660fb01513c219899d2a5034587dd8e6ac0bf0c6 /drivers/vfio
parentc4fbde84fedeaf513ec96f0c6ed3f352bdcd61d6 (diff)
parent4d83de6da265cd84e74c19d876055fa5f261cde4 (diff)
downloadlinux-719bbd4a509f403f537adcaefd8ce17532be2e84.tar.gz
Merge tag 'vfio-v5.12-rc1' of git://github.com/awilliam/linux-vfio
Pull VFIO updatesfrom Alex Williamson:

 - Virtual address update handling (Steve Sistare)

 - s390/zpci fixes and cleanups (Max Gurtovoy)

 - Fixes for dirty bitmap handling, non-mdev page pinning, and improved
   pinned dirty scope tracking (Keqian Zhu)

 - Batched page pinning enhancement (Daniel Jordan)

 - Page access permission fix (Alex Williamson)

* tag 'vfio-v5.12-rc1' of git://github.com/awilliam/linux-vfio: (21 commits)
  vfio/type1: Batch page pinning
  vfio/type1: Prepare for batched pinning with struct vfio_batch
  vfio/type1: Change success value of vaddr_get_pfn()
  vfio/type1: Use follow_pte()
  vfio/pci: remove CONFIG_VFIO_PCI_ZDEV from Kconfig
  vfio/iommu_type1: Fix duplicate included kthread.h
  vfio-pci/zdev: fix possible segmentation fault issue
  vfio-pci/zdev: remove unused vdev argument
  vfio/pci: Fix handling of pci use accessor return codes
  vfio/iommu_type1: Mantain a counter for non_pinned_groups
  vfio/iommu_type1: Fix some sanity checks in detach group
  vfio/iommu_type1: Populate full dirty when detach non-pinned group
  vfio/type1: block on invalid vaddr
  vfio/type1: implement notify callback
  vfio: iommu driver notify callback
  vfio/type1: implement interfaces to update vaddr
  vfio/type1: massage unmap iteration
  vfio: interfaces to update vaddr
  vfio/type1: implement unmap all
  vfio/type1: unmap cleanup
  ...
Diffstat (limited to 'drivers/vfio')
-rw-r--r--drivers/vfio/pci/Kconfig12
-rw-r--r--drivers/vfio/pci/Makefile2
-rw-r--r--drivers/vfio/pci/vfio_pci.c12
-rw-r--r--drivers/vfio/pci/vfio_pci_igd.c10
-rw-r--r--drivers/vfio/pci/vfio_pci_private.h2
-rw-r--r--drivers/vfio/pci/vfio_pci_zdev.c24
-rw-r--r--drivers/vfio/vfio.c5
-rw-r--r--drivers/vfio/vfio_iommu_type1.c564
8 files changed, 441 insertions, 190 deletions
diff --git a/drivers/vfio/pci/Kconfig b/drivers/vfio/pci/Kconfig
index 40a223381ab6..ac3c1dd3edef 100644
--- a/drivers/vfio/pci/Kconfig
+++ b/drivers/vfio/pci/Kconfig
@@ -45,15 +45,3 @@ config VFIO_PCI_NVLINK2
 	depends on VFIO_PCI && PPC_POWERNV
 	help
 	  VFIO PCI support for P9 Witherspoon machine with NVIDIA V100 GPUs
-
-config VFIO_PCI_ZDEV
-	bool "VFIO PCI ZPCI device CLP support"
-	depends on VFIO_PCI && S390
-	default y
-	help
-	  Enabling this option exposes VFIO capabilities containing hardware
-	  configuration for zPCI devices. This enables userspace (e.g. QEMU)
-	  to supply proper configuration values instead of hard-coded defaults
-	  for zPCI devices passed through via VFIO on s390.
-
-	  Say Y here.
diff --git a/drivers/vfio/pci/Makefile b/drivers/vfio/pci/Makefile
index 781e0809d6ee..eff97a7cd9f1 100644
--- a/drivers/vfio/pci/Makefile
+++ b/drivers/vfio/pci/Makefile
@@ -3,6 +3,6 @@
 vfio-pci-y := vfio_pci.o vfio_pci_intrs.o vfio_pci_rdwr.o vfio_pci_config.o
 vfio-pci-$(CONFIG_VFIO_PCI_IGD) += vfio_pci_igd.o
 vfio-pci-$(CONFIG_VFIO_PCI_NVLINK2) += vfio_pci_nvlink2.o
-vfio-pci-$(CONFIG_VFIO_PCI_ZDEV) += vfio_pci_zdev.o
+vfio-pci-$(CONFIG_S390) += vfio_pci_zdev.o
 
 obj-$(CONFIG_VFIO_PCI) += vfio-pci.o
diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 706de3ef94bb..65e7e6b44578 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -807,6 +807,7 @@ static long vfio_pci_ioctl(void *device_data,
 		struct vfio_device_info info;
 		struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
 		unsigned long capsz;
+		int ret;
 
 		minsz = offsetofend(struct vfio_device_info, num_irqs);
 
@@ -832,13 +833,10 @@ static long vfio_pci_ioctl(void *device_data,
 		info.num_regions = VFIO_PCI_NUM_REGIONS + vdev->num_regions;
 		info.num_irqs = VFIO_PCI_NUM_IRQS;
 
-		if (IS_ENABLED(CONFIG_VFIO_PCI_ZDEV)) {
-			int ret = vfio_pci_info_zdev_add_caps(vdev, &caps);
-
-			if (ret && ret != -ENODEV) {
-				pci_warn(vdev->pdev, "Failed to setup zPCI info capabilities\n");
-				return ret;
-			}
+		ret = vfio_pci_info_zdev_add_caps(vdev, &caps);
+		if (ret && ret != -ENODEV) {
+			pci_warn(vdev->pdev, "Failed to setup zPCI info capabilities\n");
+			return ret;
 		}
 
 		if (caps.size) {
diff --git a/drivers/vfio/pci/vfio_pci_igd.c b/drivers/vfio/pci/vfio_pci_igd.c
index 53d97f459252..e66dfb0178ed 100644
--- a/drivers/vfio/pci/vfio_pci_igd.c
+++ b/drivers/vfio/pci/vfio_pci_igd.c
@@ -127,7 +127,7 @@ static size_t vfio_pci_igd_cfg_rw(struct vfio_pci_device *vdev,
 
 		ret = pci_user_read_config_byte(pdev, pos, &val);
 		if (ret)
-			return pcibios_err_to_errno(ret);
+			return ret;
 
 		if (copy_to_user(buf + count - size, &val, 1))
 			return -EFAULT;
@@ -141,7 +141,7 @@ static size_t vfio_pci_igd_cfg_rw(struct vfio_pci_device *vdev,
 
 		ret = pci_user_read_config_word(pdev, pos, &val);
 		if (ret)
-			return pcibios_err_to_errno(ret);
+			return ret;
 
 		val = cpu_to_le16(val);
 		if (copy_to_user(buf + count - size, &val, 2))
@@ -156,7 +156,7 @@ static size_t vfio_pci_igd_cfg_rw(struct vfio_pci_device *vdev,
 
 		ret = pci_user_read_config_dword(pdev, pos, &val);
 		if (ret)
-			return pcibios_err_to_errno(ret);
+			return ret;
 
 		val = cpu_to_le32(val);
 		if (copy_to_user(buf + count - size, &val, 4))
@@ -171,7 +171,7 @@ static size_t vfio_pci_igd_cfg_rw(struct vfio_pci_device *vdev,
 
 		ret = pci_user_read_config_word(pdev, pos, &val);
 		if (ret)
-			return pcibios_err_to_errno(ret);
+			return ret;
 
 		val = cpu_to_le16(val);
 		if (copy_to_user(buf + count - size, &val, 2))
@@ -186,7 +186,7 @@ static size_t vfio_pci_igd_cfg_rw(struct vfio_pci_device *vdev,
 
 		ret = pci_user_read_config_byte(pdev, pos, &val);
 		if (ret)
-			return pcibios_err_to_errno(ret);
+			return ret;
 
 		if (copy_to_user(buf + count - size, &val, 1))
 			return -EFAULT;
diff --git a/drivers/vfio/pci/vfio_pci_private.h b/drivers/vfio/pci/vfio_pci_private.h
index 5c90e560c5c7..9cd1882a05af 100644
--- a/drivers/vfio/pci/vfio_pci_private.h
+++ b/drivers/vfio/pci/vfio_pci_private.h
@@ -214,7 +214,7 @@ static inline int vfio_pci_ibm_npu2_init(struct vfio_pci_device *vdev)
 }
 #endif
 
-#ifdef CONFIG_VFIO_PCI_ZDEV
+#ifdef CONFIG_S390
 extern int vfio_pci_info_zdev_add_caps(struct vfio_pci_device *vdev,
 				       struct vfio_info_cap *caps);
 #else
diff --git a/drivers/vfio/pci/vfio_pci_zdev.c b/drivers/vfio/pci/vfio_pci_zdev.c
index 229685634031..7b011b62c766 100644
--- a/drivers/vfio/pci/vfio_pci_zdev.c
+++ b/drivers/vfio/pci/vfio_pci_zdev.c
@@ -24,8 +24,7 @@
 /*
  * Add the Base PCI Function information to the device info region.
  */
-static int zpci_base_cap(struct zpci_dev *zdev, struct vfio_pci_device *vdev,
-			 struct vfio_info_cap *caps)
+static int zpci_base_cap(struct zpci_dev *zdev, struct vfio_info_cap *caps)
 {
 	struct vfio_device_info_cap_zpci_base cap = {
 		.header.id = VFIO_DEVICE_INFO_CAP_ZPCI_BASE,
@@ -45,8 +44,7 @@ static int zpci_base_cap(struct zpci_dev *zdev, struct vfio_pci_device *vdev,
 /*
  * Add the Base PCI Function Group information to the device info region.
  */
-static int zpci_group_cap(struct zpci_dev *zdev, struct vfio_pci_device *vdev,
-			  struct vfio_info_cap *caps)
+static int zpci_group_cap(struct zpci_dev *zdev, struct vfio_info_cap *caps)
 {
 	struct vfio_device_info_cap_zpci_group cap = {
 		.header.id = VFIO_DEVICE_INFO_CAP_ZPCI_GROUP,
@@ -66,14 +64,15 @@ static int zpci_group_cap(struct zpci_dev *zdev, struct vfio_pci_device *vdev,
 /*
  * Add the device utility string to the device info region.
  */
-static int zpci_util_cap(struct zpci_dev *zdev, struct vfio_pci_device *vdev,
-			 struct vfio_info_cap *caps)
+static int zpci_util_cap(struct zpci_dev *zdev, struct vfio_info_cap *caps)
 {
 	struct vfio_device_info_cap_zpci_util *cap;
 	int cap_size = sizeof(*cap) + CLP_UTIL_STR_LEN;
 	int ret;
 
 	cap = kmalloc(cap_size, GFP_KERNEL);
+	if (!cap)
+		return -ENOMEM;
 
 	cap->header.id = VFIO_DEVICE_INFO_CAP_ZPCI_UTIL;
 	cap->header.version = 1;
@@ -90,14 +89,15 @@ static int zpci_util_cap(struct zpci_dev *zdev, struct vfio_pci_device *vdev,
 /*
  * Add the function path string to the device info region.
  */
-static int zpci_pfip_cap(struct zpci_dev *zdev, struct vfio_pci_device *vdev,
-			 struct vfio_info_cap *caps)
+static int zpci_pfip_cap(struct zpci_dev *zdev, struct vfio_info_cap *caps)
 {
 	struct vfio_device_info_cap_zpci_pfip *cap;
 	int cap_size = sizeof(*cap) + CLP_PFIP_NR_SEGMENTS;
 	int ret;
 
 	cap = kmalloc(cap_size, GFP_KERNEL);
+	if (!cap)
+		return -ENOMEM;
 
 	cap->header.id = VFIO_DEVICE_INFO_CAP_ZPCI_PFIP;
 	cap->header.version = 1;
@@ -123,21 +123,21 @@ int vfio_pci_info_zdev_add_caps(struct vfio_pci_device *vdev,
 	if (!zdev)
 		return -ENODEV;
 
-	ret = zpci_base_cap(zdev, vdev, caps);
+	ret = zpci_base_cap(zdev, caps);
 	if (ret)
 		return ret;
 
-	ret = zpci_group_cap(zdev, vdev, caps);
+	ret = zpci_group_cap(zdev, caps);
 	if (ret)
 		return ret;
 
 	if (zdev->util_str_avail) {
-		ret = zpci_util_cap(zdev, vdev, caps);
+		ret = zpci_util_cap(zdev, caps);
 		if (ret)
 			return ret;
 	}
 
-	ret = zpci_pfip_cap(zdev, vdev, caps);
+	ret = zpci_pfip_cap(zdev, caps);
 
 	return ret;
 }
diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c
index 4ad8a35667a7..38779e6fd80c 100644
--- a/drivers/vfio/vfio.c
+++ b/drivers/vfio/vfio.c
@@ -1220,6 +1220,11 @@ static int vfio_fops_open(struct inode *inode, struct file *filep)
 static int vfio_fops_release(struct inode *inode, struct file *filep)
 {
 	struct vfio_container *container = filep->private_data;
+	struct vfio_iommu_driver *driver = container->iommu_driver;
+
+	if (driver && driver->ops->notify)
+		driver->ops->notify(container->iommu_data,
+				    VFIO_IOMMU_CONTAINER_CLOSE);
 
 	filep->private_data = NULL;
 
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 0b4dedaa9128..4bb162c1d649 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -24,6 +24,7 @@
 #include <linux/compat.h>
 #include <linux/device.h>
 #include <linux/fs.h>
+#include <linux/highmem.h>
 #include <linux/iommu.h>
 #include <linux/module.h>
 #include <linux/mm.h>
@@ -69,11 +70,15 @@ struct vfio_iommu {
 	struct rb_root		dma_list;
 	struct blocking_notifier_head notifier;
 	unsigned int		dma_avail;
+	unsigned int		vaddr_invalid_count;
 	uint64_t		pgsize_bitmap;
+	uint64_t		num_non_pinned_groups;
+	wait_queue_head_t	vaddr_wait;
 	bool			v2;
 	bool			nesting;
 	bool			dirty_page_tracking;
 	bool			pinned_page_dirty_scope;
+	bool			container_open;
 };
 
 struct vfio_domain {
@@ -92,11 +97,20 @@ struct vfio_dma {
 	int			prot;		/* IOMMU_READ/WRITE */
 	bool			iommu_mapped;
 	bool			lock_cap;	/* capable(CAP_IPC_LOCK) */
+	bool			vaddr_invalid;
 	struct task_struct	*task;
 	struct rb_root		pfn_list;	/* Ex-user pinned pfn list */
 	unsigned long		*bitmap;
 };
 
+struct vfio_batch {
+	struct page		**pages;	/* for pin_user_pages_remote */
+	struct page		*fallback_page; /* if pages alloc fails */
+	int			capacity;	/* length of pages array */
+	int			size;		/* of batch currently */
+	int			offset;		/* of next entry in pages */
+};
+
 struct vfio_group {
 	struct iommu_group	*iommu_group;
 	struct list_head	next;
@@ -143,12 +157,13 @@ struct vfio_regions {
 #define DIRTY_BITMAP_PAGES_MAX	 ((u64)INT_MAX)
 #define DIRTY_BITMAP_SIZE_MAX	 DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
 
+#define WAITED 1
+
 static int put_pfn(unsigned long pfn, int prot);
 
 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
 					       struct iommu_group *iommu_group);
 
-static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu);
 /*
  * This code handles mapping and unmapping of user data buffers
  * into DMA'ble space using the IOMMU
@@ -173,6 +188,31 @@ static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 	return NULL;
 }
 
+static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
+						dma_addr_t start, size_t size)
+{
+	struct rb_node *res = NULL;
+	struct rb_node *node = iommu->dma_list.rb_node;
+	struct vfio_dma *dma_res = NULL;
+
+	while (node) {
+		struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
+
+		if (start < dma->iova + dma->size) {
+			res = node;
+			dma_res = dma;
+			if (start >= dma->iova)
+				break;
+			node = node->rb_left;
+		} else {
+			node = node->rb_right;
+		}
+	}
+	if (res && size && dma_res->iova >= start + size)
+		res = NULL;
+	return res;
+}
+
 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
 {
 	struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
@@ -236,6 +276,18 @@ static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
 	}
 }
 
+static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
+{
+	struct rb_node *n;
+	unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
+
+	for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
+		struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
+
+		bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
+	}
+}
+
 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
 {
 	struct rb_node *n;
@@ -415,13 +467,54 @@ static int put_pfn(unsigned long pfn, int prot)
 	return 0;
 }
 
+#define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
+
+static void vfio_batch_init(struct vfio_batch *batch)
+{
+	batch->size = 0;
+	batch->offset = 0;
+
+	if (unlikely(disable_hugepages))
+		goto fallback;
+
+	batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
+	if (!batch->pages)
+		goto fallback;
+
+	batch->capacity = VFIO_BATCH_MAX_CAPACITY;
+	return;
+
+fallback:
+	batch->pages = &batch->fallback_page;
+	batch->capacity = 1;
+}
+
+static void vfio_batch_unpin(struct vfio_batch *batch, struct vfio_dma *dma)
+{
+	while (batch->size) {
+		unsigned long pfn = page_to_pfn(batch->pages[batch->offset]);
+
+		put_pfn(pfn, dma->prot);
+		batch->offset++;
+		batch->size--;
+	}
+}
+
+static void vfio_batch_fini(struct vfio_batch *batch)
+{
+	if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
+		free_page((unsigned long)batch->pages);
+}
+
 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
 			    unsigned long vaddr, unsigned long *pfn,
 			    bool write_fault)
 {
+	pte_t *ptep;
+	spinlock_t *ptl;
 	int ret;
 
-	ret = follow_pfn(vma, vaddr, pfn);
+	ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
 	if (ret) {
 		bool unlocked = false;
 
@@ -435,16 +528,28 @@ static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
 		if (ret)
 			return ret;
 
-		ret = follow_pfn(vma, vaddr, pfn);
+		ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
+		if (ret)
+			return ret;
 	}
 
+	if (write_fault && !pte_write(*ptep))
+		ret = -EFAULT;
+	else
+		*pfn = pte_pfn(*ptep);
+
+	pte_unmap_unlock(ptep, ptl);
 	return ret;
 }
 
-static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
-			 int prot, unsigned long *pfn)
+/*
+ * Returns the positive number of pfns successfully obtained or a negative
+ * error code.
+ */
+static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
+			  long npages, int prot, unsigned long *pfn,
+			  struct page **pages)
 {
-	struct page *page[1];
 	struct vm_area_struct *vma;
 	unsigned int flags = 0;
 	int ret;
@@ -453,11 +558,10 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 		flags |= FOLL_WRITE;
 
 	mmap_read_lock(mm);
-	ret = pin_user_pages_remote(mm, vaddr, 1, flags | FOLL_LONGTERM,
-				    page, NULL, NULL);
-	if (ret == 1) {
-		*pfn = page_to_pfn(page[0]);
-		ret = 0;
+	ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
+				    pages, NULL, NULL);
+	if (ret > 0) {
+		*pfn = page_to_pfn(pages[0]);
 		goto done;
 	}
 
@@ -471,14 +575,73 @@ retry:
 		if (ret == -EAGAIN)
 			goto retry;
 
-		if (!ret && !is_invalid_reserved_pfn(*pfn))
-			ret = -EFAULT;
+		if (!ret) {
+			if (is_invalid_reserved_pfn(*pfn))
+				ret = 1;
+			else
+				ret = -EFAULT;
+		}
 	}
 done:
 	mmap_read_unlock(mm);
 	return ret;
 }
 
+static int vfio_wait(struct vfio_iommu *iommu)
+{
+	DEFINE_WAIT(wait);
+
+	prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
+	mutex_unlock(&iommu->lock);
+	schedule();
+	mutex_lock(&iommu->lock);
+	finish_wait(&iommu->vaddr_wait, &wait);
+	if (kthread_should_stop() || !iommu->container_open ||
+	    fatal_signal_pending(current)) {
+		return -EFAULT;
+	}
+	return WAITED;
+}
+
+/*
+ * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
+ * if the task waits, but is re-locked on return.  Return result in *dma_p.
+ * Return 0 on success with no waiting, WAITED on success if waited, and -errno
+ * on error.
+ */
+static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
+			       size_t size, struct vfio_dma **dma_p)
+{
+	int ret;
+
+	do {
+		*dma_p = vfio_find_dma(iommu, start, size);
+		if (!*dma_p)
+			ret = -EINVAL;
+		else if (!(*dma_p)->vaddr_invalid)
+			ret = 0;
+		else
+			ret = vfio_wait(iommu);
+	} while (ret > 0);
+
+	return ret;
+}
+
+/*
+ * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
+ * if the task waits, but is re-locked on return.  Return 0 on success with no
+ * waiting, WAITED on success if waited, and -errno on error.
+ */
+static int vfio_wait_all_valid(struct vfio_iommu *iommu)
+{
+	int ret = 0;
+
+	while (iommu->vaddr_invalid_count && ret >= 0)
+		ret = vfio_wait(iommu);
+
+	return ret;
+}
+
 /*
  * Attempt to pin pages.  We really don't want to track all the pfns and
  * the iommu can only map chunks of consecutive pfns anyway, so get the
@@ -486,76 +649,102 @@ done:
  */
 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 				  long npage, unsigned long *pfn_base,
-				  unsigned long limit)
+				  unsigned long limit, struct vfio_batch *batch)
 {
-	unsigned long pfn = 0;
+	unsigned long pfn;
+	struct mm_struct *mm = current->mm;
 	long ret, pinned = 0, lock_acct = 0;
 	bool rsvd;
 	dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
 
 	/* This code path is only user initiated */
-	if (!current->mm)
+	if (!mm)
 		return -ENODEV;
 
-	ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
-	if (ret)
-		return ret;
-
-	pinned++;
-	rsvd = is_invalid_reserved_pfn(*pfn_base);
-
-	/*
-	 * Reserved pages aren't counted against the user, externally pinned
-	 * pages are already counted against the user.
-	 */
-	if (!rsvd && !vfio_find_vpfn(dma, iova)) {
-		if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
-			put_pfn(*pfn_base, dma->prot);
-			pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
-					limit << PAGE_SHIFT);
-			return -ENOMEM;
-		}
-		lock_acct++;
+	if (batch->size) {
+		/* Leftover pages in batch from an earlier call. */
+		*pfn_base = page_to_pfn(batch->pages[batch->offset]);
+		pfn = *pfn_base;
+		rsvd = is_invalid_reserved_pfn(*pfn_base);
+	} else {
+		*pfn_base = 0;
 	}
 
-	if (unlikely(disable_hugepages))
-		goto out;
+	while (npage) {
+		if (!batch->size) {
+			/* Empty batch, so refill it. */
+			long req_pages = min_t(long, npage, batch->capacity);
 
-	/* Lock all the consecutive pages from pfn_base */
-	for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
-	     pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
-		ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
-		if (ret)
-			break;
+			ret = vaddr_get_pfns(mm, vaddr, req_pages, dma->prot,
+					     &pfn, batch->pages);
+			if (ret < 0)
+				goto unpin_out;
 
-		if (pfn != *pfn_base + pinned ||
-		    rsvd != is_invalid_reserved_pfn(pfn)) {
-			put_pfn(pfn, dma->prot);
-			break;
+			batch->size = ret;
+			batch->offset = 0;
+
+			if (!*pfn_base) {
+				*pfn_base = pfn;
+				rsvd = is_invalid_reserved_pfn(*pfn_base);
+			}
 		}
 
-		if (!rsvd && !vfio_find_vpfn(dma, iova)) {
-			if (!dma->lock_cap &&
-			    current->mm->locked_vm + lock_acct + 1 > limit) {
-				put_pfn(pfn, dma->prot);
-				pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
-					__func__, limit << PAGE_SHIFT);
-				ret = -ENOMEM;
-				goto unpin_out;
+		/*
+		 * pfn is preset for the first iteration of this inner loop and
+		 * updated at the end to handle a VM_PFNMAP pfn.  In that case,
+		 * batch->pages isn't valid (there's no struct page), so allow
+		 * batch->pages to be touched only when there's more than one
+		 * pfn to check, which guarantees the pfns are from a
+		 * !VM_PFNMAP vma.
+		 */
+		while (true) {
+			if (pfn != *pfn_base + pinned ||
+			    rsvd != is_invalid_reserved_pfn(pfn))
+				goto out;
+
+			/*
+			 * Reserved pages aren't counted against the user,
+			 * externally pinned pages are already counted against
+			 * the user.
+			 */
+			if (!rsvd && !vfio_find_vpfn(dma, iova)) {
+				if (!dma->lock_cap &&
+				    mm->locked_vm + lock_acct + 1 > limit) {
+					pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
+						__func__, limit << PAGE_SHIFT);
+					ret = -ENOMEM;
+					goto unpin_out;
+				}
+				lock_acct++;
 			}
-			lock_acct++;
+
+			pinned++;
+			npage--;
+			vaddr += PAGE_SIZE;
+			iova += PAGE_SIZE;
+			batch->offset++;
+			batch->size--;
+
+			if (!batch->size)
+				break;
+
+			pfn = page_to_pfn(batch->pages[batch->offset]);
 		}
+
+		if (unlikely(disable_hugepages))
+			break;
 	}
 
 out:
 	ret = vfio_lock_acct(dma, lock_acct, false);
 
 unpin_out:
-	if (ret) {
-		if (!rsvd) {
+	if (ret < 0) {
+		if (pinned && !rsvd) {
 			for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
 				put_pfn(pfn, dma->prot);
 		}
+		vfio_batch_unpin(batch, dma);
 
 		return ret;
 	}
@@ -587,6 +776,7 @@ static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 				  unsigned long *pfn_base, bool do_accounting)
 {
+	struct page *pages[1];
 	struct mm_struct *mm;
 	int ret;
 
@@ -594,8 +784,8 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 	if (!mm)
 		return -ENODEV;
 
-	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
-	if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
+	ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
+	if (ret == 1 && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
 		ret = vfio_lock_acct(dma, 1, true);
 		if (ret) {
 			put_pfn(*pfn_base, dma->prot);
@@ -640,6 +830,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 	unsigned long remote_vaddr;
 	struct vfio_dma *dma;
 	bool do_accounting;
+	dma_addr_t iova;
 
 	if (!iommu || !user_pfn || !phys_pfn)
 		return -EINVAL;
@@ -650,6 +841,22 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 
 	mutex_lock(&iommu->lock);
 
+	/*
+	 * Wait for all necessary vaddr's to be valid so they can be used in
+	 * the main loop without dropping the lock, to avoid racing vs unmap.
+	 */
+again:
+	if (iommu->vaddr_invalid_count) {
+		for (i = 0; i < npage; i++) {
+			iova = user_pfn[i] << PAGE_SHIFT;
+			ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
+			if (ret < 0)
+				goto pin_done;
+			if (ret == WAITED)
+				goto again;
+		}
+	}
+
 	/* Fail if notifier list is empty */
 	if (!iommu->notifier.head) {
 		ret = -EINVAL;
@@ -664,7 +871,6 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 	do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
 
 	for (i = 0; i < npage; i++) {
-		dma_addr_t iova;
 		struct vfio_pfn *vpfn;
 
 		iova = user_pfn[i] << PAGE_SHIFT;
@@ -714,7 +920,7 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 	group = vfio_iommu_find_iommu_group(iommu, iommu_group);
 	if (!group->pinned_page_dirty_scope) {
 		group->pinned_page_dirty_scope = true;
-		update_pinned_page_dirty_scope(iommu);
+		iommu->num_non_pinned_groups--;
 	}
 
 	goto pin_done;
@@ -945,10 +1151,15 @@ static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 
 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
 {
+	WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
 	vfio_unmap_unpin(iommu, dma, true);
 	vfio_unlink_dma(iommu, dma);
 	put_task_struct(dma->task);
 	vfio_dma_bitmap_free(dma);
+	if (dma->vaddr_invalid) {
+		iommu->vaddr_invalid_count--;
+		wake_up_all(&iommu->vaddr_wait);
+	}
 	kfree(dma);
 	iommu->dma_avail++;
 }
@@ -991,7 +1202,7 @@ static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
 	 * mark all pages dirty if any IOMMU capable device is not able
 	 * to report dirty pages and all pages are pinned and mapped.
 	 */
-	if (!iommu->pinned_page_dirty_scope && dma->iommu_mapped)
+	if (iommu->num_non_pinned_groups && dma->iommu_mapped)
 		bitmap_set(dma->bitmap, 0, nbits);
 
 	if (shift) {
@@ -1074,34 +1285,36 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 {
 	struct vfio_dma *dma, *dma_last = NULL;
 	size_t unmapped = 0, pgsize;
-	int ret = 0, retries = 0;
+	int ret = -EINVAL, retries = 0;
 	unsigned long pgshift;
+	dma_addr_t iova = unmap->iova;
+	unsigned long size = unmap->size;
+	bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
+	bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
+	struct rb_node *n, *first_n;
 
 	mutex_lock(&iommu->lock);
 
 	pgshift = __ffs(iommu->pgsize_bitmap);
 	pgsize = (size_t)1 << pgshift;
 
-	if (unmap->iova & (pgsize - 1)) {
-		ret = -EINVAL;
+	if (iova & (pgsize - 1))
 		goto unlock;
-	}
 
-	if (!unmap->size || unmap->size & (pgsize - 1)) {
-		ret = -EINVAL;
+	if (unmap_all) {
+		if (iova || size)
+			goto unlock;
+		size = SIZE_MAX;
+	} else if (!size || size & (pgsize - 1)) {
 		goto unlock;
 	}
 
-	if (unmap->iova + unmap->size - 1 < unmap->iova ||
-	    unmap->size > SIZE_MAX) {
-		ret = -EINVAL;
+	if (iova + size - 1 < iova || size > SIZE_MAX)
 		goto unlock;
-	}
 
 	/* When dirty tracking is enabled, allow only min supported pgsize */
 	if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
 	    (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
-		ret = -EINVAL;
 		goto unlock;
 	}
 
@@ -1138,21 +1351,25 @@ again:
 	 * will only return success and a size of zero if there were no
 	 * mappings within the range.
 	 */
-	if (iommu->v2) {
-		dma = vfio_find_dma(iommu, unmap->iova, 1);
-		if (dma && dma->iova != unmap->iova) {
-			ret = -EINVAL;
+	if (iommu->v2 && !unmap_all) {
+		dma = vfio_find_dma(iommu, iova, 1);
+		if (dma && dma->iova != iova)
 			goto unlock;
-		}
-		dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
-		if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
-			ret = -EINVAL;
+
+		dma = vfio_find_dma(iommu, iova + size - 1, 0);
+		if (dma && dma->iova + dma->size != iova + size)
 			goto unlock;
-		}
 	}
 
-	while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
-		if (!iommu->v2 && unmap->iova > dma->iova)
+	ret = 0;
+	n = first_n = vfio_find_dma_first_node(iommu, iova, size);
+
+	while (n) {
+		dma = rb_entry(n, struct vfio_dma, node);
+		if (dma->iova >= iova + size)
+			break;
+
+		if (!iommu->v2 && iova > dma->iova)
 			break;
 		/*
 		 * Task with same address space who mapped this iova range is
@@ -1161,6 +1378,27 @@ again:
 		if (dma->task->mm != current->mm)
 			break;
 
+		if (invalidate_vaddr) {
+			if (dma->vaddr_invalid) {
+				struct rb_node *last_n = n;
+
+				for (n = first_n; n != last_n; n = rb_next(n)) {
+					dma = rb_entry(n,
+						       struct vfio_dma, node);
+					dma->vaddr_invalid = false;
+					iommu->vaddr_invalid_count--;
+				}
+				ret = -EINVAL;
+				unmapped = 0;
+				break;
+			}
+			dma->vaddr_invalid = true;
+			iommu->vaddr_invalid_count++;
+			unmapped += dma->size;
+			n = rb_next(n);
+			continue;
+		}
+
 		if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
 			struct vfio_iommu_type1_dma_unmap nb_unmap;
 
@@ -1190,12 +1428,13 @@ again:
 
 		if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
 			ret = update_user_bitmap(bitmap->data, iommu, dma,
-						 unmap->iova, pgsize);
+						 iova, pgsize);
 			if (ret)
 				break;
 		}
 
 		unmapped += dma->size;
+		n = rb_next(n);
 		vfio_remove_dma(iommu, dma);
 	}
 
@@ -1239,15 +1478,19 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
 {
 	dma_addr_t iova = dma->iova;
 	unsigned long vaddr = dma->vaddr;
+	struct vfio_batch batch;
 	size_t size = map_size;
 	long npage;
 	unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
 	int ret = 0;
 
+	vfio_batch_init(&batch);
+
 	while (size) {
 		/* Pin a contiguous chunk of memory */
 		npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
-					      size >> PAGE_SHIFT, &pfn, limit);
+					      size >> PAGE_SHIFT, &pfn, limit,
+					      &batch);
 		if (npage <= 0) {
 			WARN_ON(!npage);
 			ret = (int)npage;
@@ -1260,6 +1503,7 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
 		if (ret) {
 			vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
 						npage, true);
+			vfio_batch_unpin(&batch, dma);
 			break;
 		}
 
@@ -1267,6 +1511,7 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
 		dma->size += npage << PAGE_SHIFT;
 	}
 
+	vfio_batch_fini(&batch);
 	dma->iommu_mapped = true;
 
 	if (ret)
@@ -1299,6 +1544,7 @@ static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
 static int vfio_dma_do_map(struct vfio_iommu *iommu,
 			   struct vfio_iommu_type1_dma_map *map)
 {
+	bool set_vaddr = map->flags & VFIO_DMA_MAP_FLAG_VADDR;
 	dma_addr_t iova = map->iova;
 	unsigned long vaddr = map->vaddr;
 	size_t size = map->size;
@@ -1316,13 +1562,16 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
 	if (map->flags & VFIO_DMA_MAP_FLAG_READ)
 		prot |= IOMMU_READ;
 
+	if ((prot && set_vaddr) || (!prot && !set_vaddr))
+		return -EINVAL;
+
 	mutex_lock(&iommu->lock);
 
 	pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
 
 	WARN_ON((pgsize - 1) & PAGE_MASK);
 
-	if (!prot || !size || (size | iova | vaddr) & (pgsize - 1)) {
+	if (!size || (size | iova | vaddr) & (pgsize - 1)) {
 		ret = -EINVAL;
 		goto out_unlock;
 	}
@@ -1333,7 +1582,21 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
 		goto out_unlock;
 	}
 
-	if (vfio_find_dma(iommu, iova, size)) {
+	dma = vfio_find_dma(iommu, iova, size);
+	if (set_vaddr) {
+		if (!dma) {
+			ret = -ENOENT;
+		} else if (!dma->vaddr_invalid || dma->iova != iova ||
+			   dma->size != size) {
+			ret = -EINVAL;
+		} else {
+			dma->vaddr = vaddr;
+			dma->vaddr_invalid = false;
+			iommu->vaddr_invalid_count--;
+			wake_up_all(&iommu->vaddr_wait);
+		}
+		goto out_unlock;
+	} else if (dma) {
 		ret = -EEXIST;
 		goto out_unlock;
 	}
@@ -1425,16 +1688,23 @@ static int vfio_bus_type(struct device *dev, void *data)
 static int vfio_iommu_replay(struct vfio_iommu *iommu,
 			     struct vfio_domain *domain)
 {
+	struct vfio_batch batch;
 	struct vfio_domain *d = NULL;
 	struct rb_node *n;
 	unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
 	int ret;
 
+	ret = vfio_wait_all_valid(iommu);
+	if (ret < 0)
+		return ret;
+
 	/* Arbitrarily pick the first domain in the list for lookups */
 	if (!list_empty(&iommu->domain_list))
 		d = list_first_entry(&iommu->domain_list,
 				     struct vfio_domain, next);
 
+	vfio_batch_init(&batch);
+
 	n = rb_first(&iommu->dma_list);
 
 	for (; n; n = rb_next(n)) {
@@ -1482,7 +1752,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 
 				npage = vfio_pin_pages_remote(dma, vaddr,
 							      n >> PAGE_SHIFT,
-							      &pfn, limit);
+							      &pfn, limit,
+							      &batch);
 				if (npage <= 0) {
 					WARN_ON(!npage);
 					ret = (int)npage;
@@ -1496,11 +1767,13 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 			ret = iommu_map(domain->domain, iova, phys,
 					size, dma->prot | domain->prot);
 			if (ret) {
-				if (!dma->iommu_mapped)
+				if (!dma->iommu_mapped) {
 					vfio_unpin_pages_remote(dma, iova,
 							phys >> PAGE_SHIFT,
 							size >> PAGE_SHIFT,
 							true);
+					vfio_batch_unpin(&batch, dma);
+				}
 				goto unwind;
 			}
 
@@ -1515,6 +1788,7 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 		dma->iommu_mapped = true;
 	}
 
+	vfio_batch_fini(&batch);
 	return 0;
 
 unwind:
@@ -1555,6 +1829,7 @@ unwind:
 		}
 	}
 
+	vfio_batch_fini(&batch);
 	return ret;
 }
 
@@ -1622,33 +1897,6 @@ static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
 	return group;
 }
 
-static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu)
-{
-	struct vfio_domain *domain;
-	struct vfio_group *group;
-
-	list_for_each_entry(domain, &iommu->domain_list, next) {
-		list_for_each_entry(group, &domain->group_list, next) {
-			if (!group->pinned_page_dirty_scope) {
-				iommu->pinned_page_dirty_scope = false;
-				return;
-			}
-		}
-	}
-
-	if (iommu->external_domain) {
-		domain = iommu->external_domain;
-		list_for_each_entry(group, &domain->group_list, next) {
-			if (!group->pinned_page_dirty_scope) {
-				iommu->pinned_page_dirty_scope = false;
-				return;
-			}
-		}
-	}
-
-	iommu->pinned_page_dirty_scope = true;
-}
-
 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
 				  phys_addr_t *base)
 {
@@ -2057,8 +2305,6 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
 			 * addition of a dirty tracking group.
 			 */
 			group->pinned_page_dirty_scope = true;
-			if (!iommu->pinned_page_dirty_scope)
-				update_pinned_page_dirty_scope(iommu);
 			mutex_unlock(&iommu->lock);
 
 			return 0;
@@ -2188,7 +2434,7 @@ done:
 	 * demotes the iommu scope until it declares itself dirty tracking
 	 * capable via the page pinning interface.
 	 */
-	iommu->pinned_page_dirty_scope = false;
+	iommu->num_non_pinned_groups++;
 	mutex_unlock(&iommu->lock);
 	vfio_iommu_resv_free(&group_resv_regions);
 
@@ -2238,23 +2484,6 @@ static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
 	}
 }
 
-static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
-{
-	struct rb_node *n;
-
-	n = rb_first(&iommu->dma_list);
-	for (; n; n = rb_next(n)) {
-		struct vfio_dma *dma;
-
-		dma = rb_entry(n, struct vfio_dma, node);
-
-		if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
-			break;
-	}
-	/* mdev vendor driver must unregister notifier */
-	WARN_ON(iommu->notifier.head);
-}
-
 /*
  * Called when a domain is removed in detach. It is possible that
  * the removed domain decided the iova aperture window. Modify the
@@ -2354,10 +2583,10 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
 			kfree(group);
 
 			if (list_empty(&iommu->external_domain->group_list)) {
-				vfio_sanity_check_pfn_list(iommu);
-
-				if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
+				if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu)) {
+					WARN_ON(iommu->notifier.head);
 					vfio_iommu_unmap_unpin_all(iommu);
+				}
 
 				kfree(iommu->external_domain);
 				iommu->external_domain = NULL;
@@ -2391,10 +2620,12 @@ static void vfio_iommu_type1_detach_group(void *iommu_data,
 		 */
 		if (list_empty(&domain->group_list)) {
 			if (list_is_singular(&iommu->domain_list)) {
-				if (!iommu->external_domain)
+				if (!iommu->external_domain) {
+					WARN_ON(iommu->notifier.head);
 					vfio_iommu_unmap_unpin_all(iommu);
-				else
+				} else {
 					vfio_iommu_unmap_unpin_reaccount(iommu);
+				}
 			}
 			iommu_domain_free(domain->domain);
 			list_del(&domain->next);
@@ -2415,8 +2646,11 @@ detach_group_done:
 	 * Removal of a group without dirty tracking may allow the iommu scope
 	 * to be promoted.
 	 */
-	if (update_dirty_scope)
-		update_pinned_page_dirty_scope(iommu);
+	if (update_dirty_scope) {
+		iommu->num_non_pinned_groups--;
+		if (iommu->dirty_page_tracking)
+			vfio_iommu_populate_bitmap_full(iommu);
+	}
 	mutex_unlock(&iommu->lock);
 }
 
@@ -2446,8 +2680,10 @@ static void *vfio_iommu_type1_open(unsigned long arg)
 	INIT_LIST_HEAD(&iommu->iova_list);
 	iommu->dma_list = RB_ROOT;
 	iommu->dma_avail = dma_entry_limit;
+	iommu->container_open = true;
 	mutex_init(&iommu->lock);
 	BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
+	init_waitqueue_head(&iommu->vaddr_wait);
 
 	return iommu;
 }
@@ -2475,7 +2711,6 @@ static void vfio_iommu_type1_release(void *iommu_data)
 
 	if (iommu->external_domain) {
 		vfio_release_domain(iommu->external_domain, true);
-		vfio_sanity_check_pfn_list(iommu);
 		kfree(iommu->external_domain);
 	}
 
@@ -2517,6 +2752,8 @@ static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
 	case VFIO_TYPE1_IOMMU:
 	case VFIO_TYPE1v2_IOMMU:
 	case VFIO_TYPE1_NESTING_IOMMU:
+	case VFIO_UNMAP_ALL:
+	case VFIO_UPDATE_VADDR:
 		return 1;
 	case VFIO_DMA_CC_IOMMU:
 		if (!iommu)
@@ -2688,7 +2925,8 @@ static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
 {
 	struct vfio_iommu_type1_dma_map map;
 	unsigned long minsz;
-	uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE;
+	uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE |
+			VFIO_DMA_MAP_FLAG_VADDR;
 
 	minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
 
@@ -2706,6 +2944,9 @@ static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
 {
 	struct vfio_iommu_type1_dma_unmap unmap;
 	struct vfio_bitmap bitmap = { 0 };
+	uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
+			VFIO_DMA_UNMAP_FLAG_VADDR |
+			VFIO_DMA_UNMAP_FLAG_ALL;
 	unsigned long minsz;
 	int ret;
 
@@ -2714,8 +2955,12 @@ static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
 	if (copy_from_user(&unmap, (void __user *)arg, minsz))
 		return -EFAULT;
 
-	if (unmap.argsz < minsz ||
-	    unmap.flags & ~VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP)
+	if (unmap.argsz < minsz || unmap.flags & ~mask)
+		return -EINVAL;
+
+	if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
+	    (unmap.flags & (VFIO_DMA_UNMAP_FLAG_ALL |
+			    VFIO_DMA_UNMAP_FLAG_VADDR)))
 		return -EINVAL;
 
 	if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
@@ -2906,12 +3151,13 @@ static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
 	struct vfio_dma *dma;
 	bool kthread = current->mm == NULL;
 	size_t offset;
+	int ret;
 
 	*copied = 0;
 
-	dma = vfio_find_dma(iommu, user_iova, 1);
-	if (!dma)
-		return -EINVAL;
+	ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
+	if (ret < 0)
+		return ret;
 
 	if ((write && !(dma->prot & IOMMU_WRITE)) ||
 			!(dma->prot & IOMMU_READ))
@@ -3003,6 +3249,19 @@ vfio_iommu_type1_group_iommu_domain(void *iommu_data,
 	return domain;
 }
 
+static void vfio_iommu_type1_notify(void *iommu_data,
+				    enum vfio_iommu_notify_type event)
+{
+	struct vfio_iommu *iommu = iommu_data;
+
+	if (event != VFIO_IOMMU_CONTAINER_CLOSE)
+		return;
+	mutex_lock(&iommu->lock);
+	iommu->container_open = false;
+	mutex_unlock(&iommu->lock);
+	wake_up_all(&iommu->vaddr_wait);
+}
+
 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
 	.name			= "vfio-iommu-type1",
 	.owner			= THIS_MODULE,
@@ -3017,6 +3276,7 @@ static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
 	.unregister_notifier	= vfio_iommu_type1_unregister_notifier,
 	.dma_rw			= vfio_iommu_type1_dma_rw,
 	.group_iommu_domain	= vfio_iommu_type1_group_iommu_domain,
+	.notify			= vfio_iommu_type1_notify,
 };
 
 static int __init vfio_iommu_type1_init(void)