summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--Documentation/vm/hmm.rst73
-rw-r--r--arch/csky/include/asm/tlb.h8
-rw-r--r--arch/openrisc/kernel/dma.c23
-rw-r--r--arch/powerpc/mm/book3s64/subpage_prot.c12
-rw-r--r--arch/s390/mm/gmap.c35
-rw-r--r--drivers/gpu/drm/amd/amdgpu/Kconfig4
-rw-r--r--drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c2
-rw-r--r--drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c15
-rw-r--r--drivers/gpu/drm/amd/amdgpu/amdgpu_ttm.c31
-rw-r--r--drivers/gpu/drm/amd/amdkfd/kfd_priv.h3
-rw-r--r--drivers/gpu/drm/amd/amdkfd/kfd_process.c88
-rw-r--r--drivers/gpu/drm/nouveau/Kconfig5
-rw-r--r--drivers/gpu/drm/nouveau/nouveau_dmem.c456
-rw-r--r--drivers/gpu/drm/nouveau/nouveau_dmem.h11
-rw-r--r--drivers/gpu/drm/nouveau/nouveau_drm.c3
-rw-r--r--drivers/gpu/drm/nouveau/nouveau_svm.c23
-rw-r--r--drivers/gpu/drm/radeon/radeon.h3
-rw-r--r--drivers/gpu/drm/radeon/radeon_device.c2
-rw-r--r--drivers/gpu/drm/radeon/radeon_drv.c2
-rw-r--r--drivers/gpu/drm/radeon/radeon_mn.c156
-rw-r--r--drivers/infiniband/Kconfig1
-rw-r--r--drivers/infiniband/core/device.c1
-rw-r--r--drivers/infiniband/core/umem.c54
-rw-r--r--drivers/infiniband/core/umem_odp.c524
-rw-r--r--drivers/infiniband/core/uverbs_cmd.c5
-rw-r--r--drivers/infiniband/core/uverbs_main.c1
-rw-r--r--drivers/infiniband/hw/mlx5/main.c9
-rw-r--r--drivers/infiniband/hw/mlx5/mem.c13
-rw-r--r--drivers/infiniband/hw/mlx5/mr.c38
-rw-r--r--drivers/infiniband/hw/mlx5/odp.c88
-rw-r--r--drivers/misc/sgi-gru/grufile.c1
-rw-r--r--drivers/misc/sgi-gru/grutables.h2
-rw-r--r--drivers/misc/sgi-gru/grutlbpurge.c84
-rw-r--r--drivers/nvdimm/Kconfig12
-rw-r--r--drivers/nvdimm/Makefile4
-rw-r--r--fs/proc/task_mmu.c80
-rw-r--r--include/linux/hmm.h125
-rw-r--r--include/linux/ioport.h2
-rw-r--r--include/linux/kernel.h23
-rw-r--r--include/linux/memremap.h3
-rw-r--r--include/linux/migrate.h120
-rw-r--r--include/linux/mm.h46
-rw-r--r--include/linux/mm_types.h6
-rw-r--r--include/linux/mmu_notifier.h59
-rw-r--r--include/linux/pagewalk.h66
-rw-r--r--include/linux/sched.h4
-rw-r--r--include/rdma/ib_umem.h2
-rw-r--r--include/rdma/ib_umem_odp.h58
-rw-r--r--include/rdma/ib_verbs.h7
-rw-r--r--kernel/fork.c1
-rw-r--r--kernel/resource.c45
-rw-r--r--kernel/sched/core.c19
-rw-r--r--mm/Kconfig20
-rw-r--r--mm/hmm.c490
-rw-r--r--mm/madvise.c42
-rw-r--r--mm/memcontrol.c25
-rw-r--r--mm/mempolicy.c17
-rw-r--r--mm/memremap.c105
-rw-r--r--mm/migrate.c276
-rw-r--r--mm/mincore.c17
-rw-r--r--mm/mmu_notifier.c263
-rw-r--r--mm/mprotect.c26
-rw-r--r--mm/page_alloc.c2
-rw-r--r--mm/pagewalk.c126
-rw-r--r--tools/testing/nvdimm/test/iomap.c1
65 files changed, 1684 insertions, 2184 deletions
diff --git a/Documentation/vm/hmm.rst b/Documentation/vm/hmm.rst
index 710ce1c701bf..0a5960beccf7 100644
--- a/Documentation/vm/hmm.rst
+++ b/Documentation/vm/hmm.rst
@@ -192,15 +192,14 @@ read only, or fully unmap, etc.). The device must complete the update before
 the driver callback returns.
 
 When the device driver wants to populate a range of virtual addresses, it can
-use either::
+use::
 
-  long hmm_range_snapshot(struct hmm_range *range);
-  long hmm_range_fault(struct hmm_range *range, bool block);
+  long hmm_range_fault(struct hmm_range *range, unsigned int flags);
 
-The first one (hmm_range_snapshot()) will only fetch present CPU page table
+With the HMM_RANGE_SNAPSHOT flag, it will only fetch present CPU page table
 entries and will not trigger a page fault on missing or non-present entries.
-The second one does trigger a page fault on missing or read-only entries if
-write access is requested (see below). Page faults use the generic mm page
+Without that flag, it does trigger a page fault on missing or read-only entries
+if write access is requested (see below). Page faults use the generic mm page
 fault code path just like a CPU page fault.
 
 Both functions copy CPU page table entries into their pfns array argument. Each
@@ -223,24 +222,24 @@ The usage pattern is::
       range.flags = ...;
       range.values = ...;
       range.pfn_shift = ...;
-      hmm_range_register(&range);
+      hmm_range_register(&range, mirror);
 
       /*
        * Just wait for range to be valid, safe to ignore return value as we
-       * will use the return value of hmm_range_snapshot() below under the
+       * will use the return value of hmm_range_fault() below under the
        * mmap_sem to ascertain the validity of the range.
        */
       hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
 
  again:
       down_read(&mm->mmap_sem);
-      ret = hmm_range_snapshot(&range);
+      ret = hmm_range_fault(&range, HMM_RANGE_SNAPSHOT);
       if (ret) {
           up_read(&mm->mmap_sem);
           if (ret == -EBUSY) {
             /*
              * No need to check hmm_range_wait_until_valid() return value
-             * on retry we will get proper error with hmm_range_snapshot()
+             * on retry we will get proper error with hmm_range_fault()
              */
             hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
             goto again;
@@ -340,58 +339,8 @@ Migration to and from device memory
 ===================================
 
 Because the CPU cannot access device memory, migration must use the device DMA
-engine to perform copy from and to device memory. For this we need a new
-migration helper::
-
- int migrate_vma(const struct migrate_vma_ops *ops,
-                 struct vm_area_struct *vma,
-                 unsigned long mentries,
-                 unsigned long start,
-                 unsigned long end,
-                 unsigned long *src,
-                 unsigned long *dst,
-                 void *private);
-
-Unlike other migration functions it works on a range of virtual address, there
-are two reasons for that. First, device DMA copy has a high setup overhead cost
-and thus batching multiple pages is needed as otherwise the migration overhead
-makes the whole exercise pointless. The second reason is because the
-migration might be for a range of addresses the device is actively accessing.
-
-The migrate_vma_ops struct defines two callbacks. First one (alloc_and_copy())
-controls destination memory allocation and copy operation. Second one is there
-to allow the device driver to perform cleanup operations after migration::
-
- struct migrate_vma_ops {
-     void (*alloc_and_copy)(struct vm_area_struct *vma,
-                            const unsigned long *src,
-                            unsigned long *dst,
-                            unsigned long start,
-                            unsigned long end,
-                            void *private);
-     void (*finalize_and_map)(struct vm_area_struct *vma,
-                              const unsigned long *src,
-                              const unsigned long *dst,
-                              unsigned long start,
-                              unsigned long end,
-                              void *private);
- };
-
-It is important to stress that these migration helpers allow for holes in the
-virtual address range. Some pages in the range might not be migrated for all
-the usual reasons (page is pinned, page is locked, ...). This helper does not
-fail but just skips over those pages.
-
-The alloc_and_copy() might decide to not migrate all pages in the
-range (for reasons under the callback control). For those, the callback just
-has to leave the corresponding dst entry empty.
-
-Finally, the migration of the struct page might fail (for file backed page) for
-various reasons (failure to freeze reference, or update page cache, ...). If
-that happens, then the finalize_and_map() can catch any pages that were not
-migrated. Note those pages were still copied to a new page and thus we wasted
-bandwidth but this is considered as a rare event and a price that we are
-willing to pay to keep all the code simpler.
+engine to perform copy from and to device memory. For this we need to use
+migrate_vma_setup(), migrate_vma_pages(), and migrate_vma_finalize() helpers.
 
 
 Memory cgroup (memcg) and rss accounting
diff --git a/arch/csky/include/asm/tlb.h b/arch/csky/include/asm/tlb.h
index 8c7cc097666f..fdff9b8d70c8 100644
--- a/arch/csky/include/asm/tlb.h
+++ b/arch/csky/include/asm/tlb.h
@@ -8,14 +8,14 @@
 
 #define tlb_start_vma(tlb, vma) \
 	do { \
-		if (!tlb->fullmm) \
-			flush_cache_range(vma, vma->vm_start, vma->vm_end); \
+		if (!(tlb)->fullmm) \
+			flush_cache_range(vma, (vma)->vm_start, (vma)->vm_end); \
 	}  while (0)
 
 #define tlb_end_vma(tlb, vma) \
 	do { \
-		if (!tlb->fullmm) \
-			flush_tlb_range(vma, vma->vm_start, vma->vm_end); \
+		if (!(tlb)->fullmm) \
+			flush_tlb_range(vma, (vma)->vm_start, (vma)->vm_end); \
 	}  while (0)
 
 #define tlb_flush(tlb) flush_tlb_mm((tlb)->mm)
diff --git a/arch/openrisc/kernel/dma.c b/arch/openrisc/kernel/dma.c
index b41a79fcdbd9..4d5b8bd1d795 100644
--- a/arch/openrisc/kernel/dma.c
+++ b/arch/openrisc/kernel/dma.c
@@ -16,6 +16,7 @@
  */
 
 #include <linux/dma-noncoherent.h>
+#include <linux/pagewalk.h>
 
 #include <asm/cpuinfo.h>
 #include <asm/spr_defs.h>
@@ -43,6 +44,10 @@ page_set_nocache(pte_t *pte, unsigned long addr,
 	return 0;
 }
 
+static const struct mm_walk_ops set_nocache_walk_ops = {
+	.pte_entry		= page_set_nocache,
+};
+
 static int
 page_clear_nocache(pte_t *pte, unsigned long addr,
 		   unsigned long next, struct mm_walk *walk)
@@ -58,6 +63,10 @@ page_clear_nocache(pte_t *pte, unsigned long addr,
 	return 0;
 }
 
+static const struct mm_walk_ops clear_nocache_walk_ops = {
+	.pte_entry		= page_clear_nocache,
+};
+
 /*
  * Alloc "coherent" memory, which for OpenRISC means simply uncached.
  *
@@ -80,10 +89,6 @@ arch_dma_alloc(struct device *dev, size_t size, dma_addr_t *dma_handle,
 {
 	unsigned long va;
 	void *page;
-	struct mm_walk walk = {
-		.pte_entry = page_set_nocache,
-		.mm = &init_mm
-	};
 
 	page = alloc_pages_exact(size, gfp | __GFP_ZERO);
 	if (!page)
@@ -98,7 +103,8 @@ arch_dma_alloc(struct device *dev, size_t size, dma_addr_t *dma_handle,
 	 * We need to iterate through the pages, clearing the dcache for
 	 * them and setting the cache-inhibit bit.
 	 */
-	if (walk_page_range(va, va + size, &walk)) {
+	if (walk_page_range(&init_mm, va, va + size, &set_nocache_walk_ops,
+			NULL)) {
 		free_pages_exact(page, size);
 		return NULL;
 	}
@@ -111,13 +117,10 @@ arch_dma_free(struct device *dev, size_t size, void *vaddr,
 		dma_addr_t dma_handle, unsigned long attrs)
 {
 	unsigned long va = (unsigned long)vaddr;
-	struct mm_walk walk = {
-		.pte_entry = page_clear_nocache,
-		.mm = &init_mm
-	};
 
 	/* walk_page_range shouldn't be able to fail here */
-	WARN_ON(walk_page_range(va, va + size, &walk));
+	WARN_ON(walk_page_range(&init_mm, va, va + size,
+			&clear_nocache_walk_ops, NULL));
 
 	free_pages_exact(vaddr, size);
 }
diff --git a/arch/powerpc/mm/book3s64/subpage_prot.c b/arch/powerpc/mm/book3s64/subpage_prot.c
index 9ba07e55c489..2ef24a53f4c9 100644
--- a/arch/powerpc/mm/book3s64/subpage_prot.c
+++ b/arch/powerpc/mm/book3s64/subpage_prot.c
@@ -7,7 +7,7 @@
 #include <linux/kernel.h>
 #include <linux/gfp.h>
 #include <linux/types.h>
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/hugetlb.h>
 #include <linux/syscalls.h>
 
@@ -139,14 +139,14 @@ static int subpage_walk_pmd_entry(pmd_t *pmd, unsigned long addr,
 	return 0;
 }
 
+static const struct mm_walk_ops subpage_walk_ops = {
+	.pmd_entry	= subpage_walk_pmd_entry,
+};
+
 static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
 				    unsigned long len)
 {
 	struct vm_area_struct *vma;
-	struct mm_walk subpage_proto_walk = {
-		.mm = mm,
-		.pmd_entry = subpage_walk_pmd_entry,
-	};
 
 	/*
 	 * We don't try too hard, we just mark all the vma in that range
@@ -163,7 +163,7 @@ static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
 		if (vma->vm_start >= (addr + len))
 			break;
 		vma->vm_flags |= VM_NOHUGEPAGE;
-		walk_page_vma(vma, &subpage_proto_walk);
+		walk_page_vma(vma, &subpage_walk_ops, NULL);
 		vma = vma->vm_next;
 	}
 }
diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
index cd8e03f04d6d..edcdca97e85e 100644
--- a/arch/s390/mm/gmap.c
+++ b/arch/s390/mm/gmap.c
@@ -9,7 +9,7 @@
  */
 
 #include <linux/kernel.h>
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/swap.h>
 #include <linux/smp.h>
 #include <linux/spinlock.h>
@@ -2521,13 +2521,9 @@ static int __zap_zero_pages(pmd_t *pmd, unsigned long start,
 	return 0;
 }
 
-static inline void zap_zero_pages(struct mm_struct *mm)
-{
-	struct mm_walk walk = { .pmd_entry = __zap_zero_pages };
-
-	walk.mm = mm;
-	walk_page_range(0, TASK_SIZE, &walk);
-}
+static const struct mm_walk_ops zap_zero_walk_ops = {
+	.pmd_entry	= __zap_zero_pages,
+};
 
 /*
  * switch on pgstes for its userspace process (for kvm)
@@ -2546,7 +2542,7 @@ int s390_enable_sie(void)
 	mm->context.has_pgste = 1;
 	/* split thp mappings and disable thp for future mappings */
 	thp_split_mm(mm);
-	zap_zero_pages(mm);
+	walk_page_range(mm, 0, TASK_SIZE, &zap_zero_walk_ops, NULL);
 	up_write(&mm->mmap_sem);
 	return 0;
 }
@@ -2589,12 +2585,13 @@ static int __s390_enable_skey_hugetlb(pte_t *pte, unsigned long addr,
 	return 0;
 }
 
+static const struct mm_walk_ops enable_skey_walk_ops = {
+	.hugetlb_entry		= __s390_enable_skey_hugetlb,
+	.pte_entry		= __s390_enable_skey_pte,
+};
+
 int s390_enable_skey(void)
 {
-	struct mm_walk walk = {
-		.hugetlb_entry = __s390_enable_skey_hugetlb,
-		.pte_entry = __s390_enable_skey_pte,
-	};
 	struct mm_struct *mm = current->mm;
 	struct vm_area_struct *vma;
 	int rc = 0;
@@ -2614,8 +2611,7 @@ int s390_enable_skey(void)
 	}
 	mm->def_flags &= ~VM_MERGEABLE;
 
-	walk.mm = mm;
-	walk_page_range(0, TASK_SIZE, &walk);
+	walk_page_range(mm, 0, TASK_SIZE, &enable_skey_walk_ops, NULL);
 
 out_up:
 	up_write(&mm->mmap_sem);
@@ -2633,13 +2629,14 @@ static int __s390_reset_cmma(pte_t *pte, unsigned long addr,
 	return 0;
 }
 
+static const struct mm_walk_ops reset_cmma_walk_ops = {
+	.pte_entry		= __s390_reset_cmma,
+};
+
 void s390_reset_cmma(struct mm_struct *mm)
 {
-	struct mm_walk walk = { .pte_entry = __s390_reset_cmma };
-
 	down_write(&mm->mmap_sem);
-	walk.mm = mm;
-	walk_page_range(0, TASK_SIZE, &walk);
+	walk_page_range(mm, 0, TASK_SIZE, &reset_cmma_walk_ops, NULL);
 	up_write(&mm->mmap_sem);
 }
 EXPORT_SYMBOL_GPL(s390_reset_cmma);
diff --git a/drivers/gpu/drm/amd/amdgpu/Kconfig b/drivers/gpu/drm/amd/amdgpu/Kconfig
index f6e5c0282fc1..2e98c016cb47 100644
--- a/drivers/gpu/drm/amd/amdgpu/Kconfig
+++ b/drivers/gpu/drm/amd/amdgpu/Kconfig
@@ -27,7 +27,9 @@ config DRM_AMDGPU_CIK
 config DRM_AMDGPU_USERPTR
 	bool "Always enable userptr write support"
 	depends on DRM_AMDGPU
-	depends on HMM_MIRROR
+	depends on MMU
+	select HMM_MIRROR
+	select MMU_NOTIFIER
 	help
 	  This option selects CONFIG_HMM and CONFIG_HMM_MIRROR if it
 	  isn't already selected to enabled full userptr support.
diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c
index 48a2070e72f2..bdf849da32e4 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_drv.c
@@ -35,6 +35,7 @@
 #include <linux/pm_runtime.h>
 #include <linux/vga_switcheroo.h>
 #include <drm/drm_probe_helper.h>
+#include <linux/mmu_notifier.h>
 
 #include "amdgpu.h"
 #include "amdgpu_irq.h"
@@ -1469,6 +1470,7 @@ static void __exit amdgpu_exit(void)
 	amdgpu_unregister_atpx_handler();
 	amdgpu_sync_fini();
 	amdgpu_fence_slab_fini();
+	mmu_notifier_synchronize();
 }
 
 module_init(amdgpu_init);
diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c
index f1f8cdd695d3..31d4deb5d294 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c
@@ -195,13 +195,14 @@ static void amdgpu_mn_invalidate_node(struct amdgpu_mn_node *node,
  * Block for operations on BOs to finish and mark pages as accessed and
  * potentially dirty.
  */
-static int amdgpu_mn_sync_pagetables_gfx(struct hmm_mirror *mirror,
-			const struct hmm_update *update)
+static int
+amdgpu_mn_sync_pagetables_gfx(struct hmm_mirror *mirror,
+			      const struct mmu_notifier_range *update)
 {
 	struct amdgpu_mn *amn = container_of(mirror, struct amdgpu_mn, mirror);
 	unsigned long start = update->start;
 	unsigned long end = update->end;
-	bool blockable = update->blockable;
+	bool blockable = mmu_notifier_range_blockable(update);
 	struct interval_tree_node *it;
 
 	/* notification is exclusive, but interval is inclusive */
@@ -243,13 +244,14 @@ static int amdgpu_mn_sync_pagetables_gfx(struct hmm_mirror *mirror,
  * necessitates evicting all user-mode queues of the process. The BOs
  * are restorted in amdgpu_mn_invalidate_range_end_hsa.
  */
-static int amdgpu_mn_sync_pagetables_hsa(struct hmm_mirror *mirror,
-			const struct hmm_update *update)
+static int
+amdgpu_mn_sync_pagetables_hsa(struct hmm_mirror *mirror,
+			      const struct mmu_notifier_range *update)
 {
 	struct amdgpu_mn *amn = container_of(mirror, struct amdgpu_mn, mirror);
 	unsigned long start = update->start;
 	unsigned long end = update->end;
-	bool blockable = update->blockable;
+	bool blockable = mmu_notifier_range_blockable(update);
 	struct interval_tree_node *it;
 
 	/* notification is exclusive, but interval is inclusive */
@@ -482,6 +484,5 @@ void amdgpu_hmm_init_range(struct hmm_range *range)
 		range->flags = hmm_range_flags;
 		range->values = hmm_range_values;
 		range->pfn_shift = PAGE_SHIFT;
-		INIT_LIST_HEAD(&range->list);
 	}
 }
diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_ttm.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_ttm.c
index 13b144c8f67d..dff41d0a85fe 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_ttm.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_ttm.c
@@ -794,7 +794,6 @@ int amdgpu_ttm_tt_get_user_pages(struct amdgpu_bo *bo, struct page **pages)
 	struct hmm_range *range;
 	unsigned long i;
 	uint64_t *pfns;
-	int retry = 0;
 	int r = 0;
 
 	if (!mm) /* Happens during process shutdown */
@@ -835,10 +834,11 @@ int amdgpu_ttm_tt_get_user_pages(struct amdgpu_bo *bo, struct page **pages)
 				0 : range->flags[HMM_PFN_WRITE];
 	range->pfn_flags_mask = 0;
 	range->pfns = pfns;
-	hmm_range_register(range, mirror, start,
-			   start + ttm->num_pages * PAGE_SIZE, PAGE_SHIFT);
+	range->start = start;
+	range->end = start + ttm->num_pages * PAGE_SIZE;
+
+	hmm_range_register(range, mirror);
 
-retry:
 	/*
 	 * Just wait for range to be valid, safe to ignore return value as we
 	 * will use the return value of hmm_range_fault() below under the
@@ -847,24 +847,12 @@ retry:
 	hmm_range_wait_until_valid(range, HMM_RANGE_DEFAULT_TIMEOUT);
 
 	down_read(&mm->mmap_sem);
-
-	r = hmm_range_fault(range, true);
-	if (unlikely(r < 0)) {
-		if (likely(r == -EAGAIN)) {
-			/*
-			 * return -EAGAIN, mmap_sem is dropped
-			 */
-			if (retry++ < MAX_RETRY_HMM_RANGE_FAULT)
-				goto retry;
-			else
-				pr_err("Retry hmm fault too many times\n");
-		}
-
-		goto out_up_read;
-	}
-
+	r = hmm_range_fault(range, 0);
 	up_read(&mm->mmap_sem);
 
+	if (unlikely(r < 0))
+		goto out_free_pfns;
+
 	for (i = 0; i < ttm->num_pages; i++) {
 		pages[i] = hmm_device_entry_to_page(range, pfns[i]);
 		if (unlikely(!pages[i])) {
@@ -880,9 +868,6 @@ retry:
 
 	return 0;
 
-out_up_read:
-	if (likely(r != -EAGAIN))
-		up_read(&mm->mmap_sem);
 out_free_pfns:
 	hmm_range_unregister(range);
 	kvfree(pfns);
diff --git a/drivers/gpu/drm/amd/amdkfd/kfd_priv.h b/drivers/gpu/drm/amd/amdkfd/kfd_priv.h
index 3bb75d11a662..c89326125d71 100644
--- a/drivers/gpu/drm/amd/amdkfd/kfd_priv.h
+++ b/drivers/gpu/drm/amd/amdkfd/kfd_priv.h
@@ -687,9 +687,6 @@ struct kfd_process {
 	/* We want to receive a notification when the mm_struct is destroyed */
 	struct mmu_notifier mmu_notifier;
 
-	/* Use for delayed freeing of kfd_process structure */
-	struct rcu_head	rcu;
-
 	unsigned int pasid;
 	unsigned int doorbell_index;
 
diff --git a/drivers/gpu/drm/amd/amdkfd/kfd_process.c b/drivers/gpu/drm/amd/amdkfd/kfd_process.c
index 0c6ac043ae3c..40e3fc0c6942 100644
--- a/drivers/gpu/drm/amd/amdkfd/kfd_process.c
+++ b/drivers/gpu/drm/amd/amdkfd/kfd_process.c
@@ -62,8 +62,8 @@ static struct workqueue_struct *kfd_restore_wq;
 
 static struct kfd_process *find_process(const struct task_struct *thread);
 static void kfd_process_ref_release(struct kref *ref);
-static struct kfd_process *create_process(const struct task_struct *thread,
-					struct file *filep);
+static struct kfd_process *create_process(const struct task_struct *thread);
+static int kfd_process_init_cwsr_apu(struct kfd_process *p, struct file *filep);
 
 static void evict_process_worker(struct work_struct *work);
 static void restore_process_worker(struct work_struct *work);
@@ -289,7 +289,15 @@ struct kfd_process *kfd_create_process(struct file *filep)
 	if (process) {
 		pr_debug("Process already found\n");
 	} else {
-		process = create_process(thread, filep);
+		process = create_process(thread);
+		if (IS_ERR(process))
+			goto out;
+
+		ret = kfd_process_init_cwsr_apu(process, filep);
+		if (ret) {
+			process = ERR_PTR(ret);
+			goto out;
+		}
 
 		if (!procfs.kobj)
 			goto out;
@@ -478,11 +486,9 @@ static void kfd_process_ref_release(struct kref *ref)
 	queue_work(kfd_process_wq, &p->release_work);
 }
 
-static void kfd_process_destroy_delayed(struct rcu_head *rcu)
+static void kfd_process_free_notifier(struct mmu_notifier *mn)
 {
-	struct kfd_process *p = container_of(rcu, struct kfd_process, rcu);
-
-	kfd_unref_process(p);
+	kfd_unref_process(container_of(mn, struct kfd_process, mmu_notifier));
 }
 
 static void kfd_process_notifier_release(struct mmu_notifier *mn,
@@ -534,12 +540,12 @@ static void kfd_process_notifier_release(struct mmu_notifier *mn,
 
 	mutex_unlock(&p->mutex);
 
-	mmu_notifier_unregister_no_release(&p->mmu_notifier, mm);
-	mmu_notifier_call_srcu(&p->rcu, &kfd_process_destroy_delayed);
+	mmu_notifier_put(&p->mmu_notifier);
 }
 
 static const struct mmu_notifier_ops kfd_process_mmu_notifier_ops = {
 	.release = kfd_process_notifier_release,
+	.free_notifier = kfd_process_free_notifier,
 };
 
 static int kfd_process_init_cwsr_apu(struct kfd_process *p, struct file *filep)
@@ -609,81 +615,69 @@ static int kfd_process_device_init_cwsr_dgpu(struct kfd_process_device *pdd)
 	return 0;
 }
 
-static struct kfd_process *create_process(const struct task_struct *thread,
-					struct file *filep)
+/*
+ * On return the kfd_process is fully operational and will be freed when the
+ * mm is released
+ */
+static struct kfd_process *create_process(const struct task_struct *thread)
 {
 	struct kfd_process *process;
 	int err = -ENOMEM;
 
 	process = kzalloc(sizeof(*process), GFP_KERNEL);
-
 	if (!process)
 		goto err_alloc_process;
 
-	process->pasid = kfd_pasid_alloc();
-	if (process->pasid == 0)
-		goto err_alloc_pasid;
-
-	if (kfd_alloc_process_doorbells(process) < 0)
-		goto err_alloc_doorbells;
-
 	kref_init(&process->ref);
-
 	mutex_init(&process->mutex);
-
 	process->mm = thread->mm;
-
-	/* register notifier */
-	process->mmu_notifier.ops = &kfd_process_mmu_notifier_ops;
-	err = mmu_notifier_register(&process->mmu_notifier, process->mm);
-	if (err)
-		goto err_mmu_notifier;
-
-	hash_add_rcu(kfd_processes_table, &process->kfd_processes,
-			(uintptr_t)process->mm);
-
 	process->lead_thread = thread->group_leader;
-	get_task_struct(process->lead_thread);
-
 	INIT_LIST_HEAD(&process->per_device_data);
-
+	INIT_DELAYED_WORK(&process->eviction_work, evict_process_worker);
+	INIT_DELAYED_WORK(&process->restore_work, restore_process_worker);
+	process->last_restore_timestamp = get_jiffies_64();
 	kfd_event_init_process(process);
+	process->is_32bit_user_mode = in_compat_syscall();
+
+	process->pasid = kfd_pasid_alloc();
+	if (process->pasid == 0)
+		goto err_alloc_pasid;
+
+	if (kfd_alloc_process_doorbells(process) < 0)
+		goto err_alloc_doorbells;
 
 	err = pqm_init(&process->pqm, process);
 	if (err != 0)
 		goto err_process_pqm_init;
 
 	/* init process apertures*/
-	process->is_32bit_user_mode = in_compat_syscall();
 	err = kfd_init_apertures(process);
 	if (err != 0)
 		goto err_init_apertures;
 
-	INIT_DELAYED_WORK(&process->eviction_work, evict_process_worker);
-	INIT_DELAYED_WORK(&process->restore_work, restore_process_worker);
-	process->last_restore_timestamp = get_jiffies_64();
-
-	err = kfd_process_init_cwsr_apu(process, filep);
+	/* Must be last, have to use release destruction after this */
+	process->mmu_notifier.ops = &kfd_process_mmu_notifier_ops;
+	err = mmu_notifier_register(&process->mmu_notifier, process->mm);
 	if (err)
-		goto err_init_cwsr;
+		goto err_register_notifier;
+
+	get_task_struct(process->lead_thread);
+	hash_add_rcu(kfd_processes_table, &process->kfd_processes,
+			(uintptr_t)process->mm);
 
 	return process;
 
-err_init_cwsr:
+err_register_notifier:
 	kfd_process_free_outstanding_kfd_bos(process);
 	kfd_process_destroy_pdds(process);
 err_init_apertures:
 	pqm_uninit(&process->pqm);
 err_process_pqm_init:
-	hash_del_rcu(&process->kfd_processes);
-	synchronize_rcu();
-	mmu_notifier_unregister_no_release(&process->mmu_notifier, process->mm);
-err_mmu_notifier:
-	mutex_destroy(&process->mutex);
 	kfd_free_process_doorbells(process);
 err_alloc_doorbells:
 	kfd_pasid_free(process->pasid);
 err_alloc_pasid:
+	mutex_destroy(&process->mutex);
 	kfree(process);
 err_alloc_process:
 	return ERR_PTR(err);
diff --git a/drivers/gpu/drm/nouveau/Kconfig b/drivers/gpu/drm/nouveau/Kconfig
index 96b9814e6d06..3558df043592 100644
--- a/drivers/gpu/drm/nouveau/Kconfig
+++ b/drivers/gpu/drm/nouveau/Kconfig
@@ -86,9 +86,10 @@ config DRM_NOUVEAU_SVM
 	bool "(EXPERIMENTAL) Enable SVM (Shared Virtual Memory) support"
 	depends on DEVICE_PRIVATE
 	depends on DRM_NOUVEAU
-	depends on HMM_MIRROR
+	depends on MMU
 	depends on STAGING
-	select MIGRATE_VMA_HELPER
+	select HMM_MIRROR
+	select MMU_NOTIFIER
 	default n
 	help
 	  Say Y here if you want to enable experimental support for
diff --git a/drivers/gpu/drm/nouveau/nouveau_dmem.c b/drivers/gpu/drm/nouveau/nouveau_dmem.c
index 1333220787a1..fa1439941596 100644
--- a/drivers/gpu/drm/nouveau/nouveau_dmem.c
+++ b/drivers/gpu/drm/nouveau/nouveau_dmem.c
@@ -44,8 +44,6 @@
 #define DMEM_CHUNK_SIZE (2UL << 20)
 #define DMEM_CHUNK_NPAGES (DMEM_CHUNK_SIZE >> PAGE_SHIFT)
 
-struct nouveau_migrate;
-
 enum nouveau_aper {
 	NOUVEAU_APER_VIRT,
 	NOUVEAU_APER_VRAM,
@@ -86,21 +84,13 @@ static inline struct nouveau_dmem *page_to_dmem(struct page *page)
 	return container_of(page->pgmap, struct nouveau_dmem, pagemap);
 }
 
-struct nouveau_dmem_fault {
-	struct nouveau_drm *drm;
-	struct nouveau_fence *fence;
-	dma_addr_t *dma;
-	unsigned long npages;
-};
+static unsigned long nouveau_dmem_page_addr(struct page *page)
+{
+	struct nouveau_dmem_chunk *chunk = page->zone_device_data;
+	unsigned long idx = page_to_pfn(page) - chunk->pfn_first;
 
-struct nouveau_migrate {
-	struct vm_area_struct *vma;
-	struct nouveau_drm *drm;
-	struct nouveau_fence *fence;
-	unsigned long npages;
-	dma_addr_t *dma;
-	unsigned long dma_nr;
-};
+	return (idx << PAGE_SHIFT) + chunk->bo->bo.offset;
+}
 
 static void nouveau_dmem_page_free(struct page *page)
 {
@@ -125,165 +115,90 @@ static void nouveau_dmem_page_free(struct page *page)
 	spin_unlock(&chunk->lock);
 }
 
-static void
-nouveau_dmem_fault_alloc_and_copy(struct vm_area_struct *vma,
-				  const unsigned long *src_pfns,
-				  unsigned long *dst_pfns,
-				  unsigned long start,
-				  unsigned long end,
-				  void *private)
+static void nouveau_dmem_fence_done(struct nouveau_fence **fence)
 {
-	struct nouveau_dmem_fault *fault = private;
-	struct nouveau_drm *drm = fault->drm;
-	struct device *dev = drm->dev->dev;
-	unsigned long addr, i, npages = 0;
-	nouveau_migrate_copy_t copy;
-	int ret;
-
-
-	/* First allocate new memory */
-	for (addr = start, i = 0; addr < end; addr += PAGE_SIZE, i++) {
-		struct page *dpage, *spage;
-
-		dst_pfns[i] = 0;
-		spage = migrate_pfn_to_page(src_pfns[i]);
-		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE))
-			continue;
-
-		dpage = alloc_page_vma(GFP_HIGHUSER, vma, addr);
-		if (!dpage) {
-			dst_pfns[i] = MIGRATE_PFN_ERROR;
-			continue;
-		}
-		lock_page(dpage);
-
-		dst_pfns[i] = migrate_pfn(page_to_pfn(dpage)) |
-			      MIGRATE_PFN_LOCKED;
-		npages++;
-	}
-
-	/* Allocate storage for DMA addresses, so we can unmap later. */
-	fault->dma = kmalloc(sizeof(*fault->dma) * npages, GFP_KERNEL);
-	if (!fault->dma)
-		goto error;
-
-	/* Copy things over */
-	copy = drm->dmem->migrate.copy_func;
-	for (addr = start, i = 0; addr < end; addr += PAGE_SIZE, i++) {
-		struct nouveau_dmem_chunk *chunk;
-		struct page *spage, *dpage;
-		u64 src_addr, dst_addr;
-
-		dpage = migrate_pfn_to_page(dst_pfns[i]);
-		if (!dpage || dst_pfns[i] == MIGRATE_PFN_ERROR)
-			continue;
-
-		spage = migrate_pfn_to_page(src_pfns[i]);
-		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE)) {
-			dst_pfns[i] = MIGRATE_PFN_ERROR;
-			__free_page(dpage);
-			continue;
-		}
-
-		fault->dma[fault->npages] =
-			dma_map_page_attrs(dev, dpage, 0, PAGE_SIZE,
-					   PCI_DMA_BIDIRECTIONAL,
-					   DMA_ATTR_SKIP_CPU_SYNC);
-		if (dma_mapping_error(dev, fault->dma[fault->npages])) {
-			dst_pfns[i] = MIGRATE_PFN_ERROR;
-			__free_page(dpage);
-			continue;
-		}
-
-		dst_addr = fault->dma[fault->npages++];
-
-		chunk = spage->zone_device_data;
-		src_addr = page_to_pfn(spage) - chunk->pfn_first;
-		src_addr = (src_addr << PAGE_SHIFT) + chunk->bo->bo.offset;
-
-		ret = copy(drm, 1, NOUVEAU_APER_HOST, dst_addr,
-				   NOUVEAU_APER_VRAM, src_addr);
-		if (ret) {
-			dst_pfns[i] = MIGRATE_PFN_ERROR;
-			__free_page(dpage);
-			continue;
-		}
+	if (fence) {
+		nouveau_fence_wait(*fence, true, false);
+		nouveau_fence_unref(fence);
+	} else {
+		/*
+		 * FIXME wait for channel to be IDLE before calling finalizing
+		 * the hmem object.
+		 */
 	}
+}
 
-	nouveau_fence_new(drm->dmem->migrate.chan, false, &fault->fence);
-
-	return;
-
-error:
-	for (addr = start, i = 0; addr < end; addr += PAGE_SIZE, ++i) {
-		struct page *page;
+static vm_fault_t nouveau_dmem_fault_copy_one(struct nouveau_drm *drm,
+		struct vm_fault *vmf, struct migrate_vma *args,
+		dma_addr_t *dma_addr)
+{
+	struct device *dev = drm->dev->dev;
+	struct page *dpage, *spage;
 
-		if (!dst_pfns[i] || dst_pfns[i] == MIGRATE_PFN_ERROR)
-			continue;
+	spage = migrate_pfn_to_page(args->src[0]);
+	if (!spage || !(args->src[0] & MIGRATE_PFN_MIGRATE))
+		return 0;
 
-		page = migrate_pfn_to_page(dst_pfns[i]);
-		dst_pfns[i] = MIGRATE_PFN_ERROR;
-		if (page == NULL)
-			continue;
+	dpage = alloc_page_vma(GFP_HIGHUSER, vmf->vma, vmf->address);
+	if (!dpage)
+		return VM_FAULT_SIGBUS;
+	lock_page(dpage);
 
-		__free_page(page);
-	}
-}
+	*dma_addr = dma_map_page(dev, dpage, 0, PAGE_SIZE, DMA_BIDIRECTIONAL);
+	if (dma_mapping_error(dev, *dma_addr))
+		goto error_free_page;
 
-void nouveau_dmem_fault_finalize_and_map(struct vm_area_struct *vma,
-					 const unsigned long *src_pfns,
-					 const unsigned long *dst_pfns,
-					 unsigned long start,
-					 unsigned long end,
-					 void *private)
-{
-	struct nouveau_dmem_fault *fault = private;
-	struct nouveau_drm *drm = fault->drm;
+	if (drm->dmem->migrate.copy_func(drm, 1, NOUVEAU_APER_HOST, *dma_addr,
+			NOUVEAU_APER_VRAM, nouveau_dmem_page_addr(spage)))
+		goto error_dma_unmap;
 
-	if (fault->fence) {
-		nouveau_fence_wait(fault->fence, true, false);
-		nouveau_fence_unref(&fault->fence);
-	} else {
-		/*
-		 * FIXME wait for channel to be IDLE before calling finalizing
-		 * the hmem object below (nouveau_migrate_hmem_fini()).
-		 */
-	}
+	args->dst[0] = migrate_pfn(page_to_pfn(dpage)) | MIGRATE_PFN_LOCKED;
+	return 0;
 
-	while (fault->npages--) {
-		dma_unmap_page(drm->dev->dev, fault->dma[fault->npages],
-			       PAGE_SIZE, PCI_DMA_BIDIRECTIONAL);
-	}
-	kfree(fault->dma);
+error_dma_unmap:
+	dma_unmap_page(dev, *dma_addr, PAGE_SIZE, DMA_BIDIRECTIONAL);
+error_free_page:
+	__free_page(dpage);
+	return VM_FAULT_SIGBUS;
 }
 
-static const struct migrate_vma_ops nouveau_dmem_fault_migrate_ops = {
-	.alloc_and_copy		= nouveau_dmem_fault_alloc_and_copy,
-	.finalize_and_map	= nouveau_dmem_fault_finalize_and_map,
-};
-
 static vm_fault_t nouveau_dmem_migrate_to_ram(struct vm_fault *vmf)
 {
 	struct nouveau_dmem *dmem = page_to_dmem(vmf->page);
-	unsigned long src[1] = {0}, dst[1] = {0};
-	struct nouveau_dmem_fault fault = { .drm = dmem->drm };
-	int ret;
+	struct nouveau_drm *drm = dmem->drm;
+	struct nouveau_fence *fence;
+	unsigned long src = 0, dst = 0;
+	dma_addr_t dma_addr = 0;
+	vm_fault_t ret;
+	struct migrate_vma args = {
+		.vma		= vmf->vma,
+		.start		= vmf->address,
+		.end		= vmf->address + PAGE_SIZE,
+		.src		= &src,
+		.dst		= &dst,
+	};
 
 	/*
 	 * FIXME what we really want is to find some heuristic to migrate more
 	 * than just one page on CPU fault. When such fault happens it is very
 	 * likely that more surrounding page will CPU fault too.
 	 */
-	ret = migrate_vma(&nouveau_dmem_fault_migrate_ops, vmf->vma,
-			vmf->address, vmf->address + PAGE_SIZE,
-			src, dst, &fault);
-	if (ret)
+	if (migrate_vma_setup(&args) < 0)
 		return VM_FAULT_SIGBUS;
+	if (!args.cpages)
+		return 0;
 
-	if (dst[0] == MIGRATE_PFN_ERROR)
-		return VM_FAULT_SIGBUS;
+	ret = nouveau_dmem_fault_copy_one(drm, vmf, &args, &dma_addr);
+	if (ret || dst == 0)
+		goto done;
 
-	return 0;
+	nouveau_fence_new(dmem->migrate.chan, false, &fence);
+	migrate_vma_pages(&args);
+	nouveau_dmem_fence_done(&fence);
+	dma_unmap_page(drm->dev->dev, dma_addr, PAGE_SIZE, DMA_BIDIRECTIONAL);
+done:
+	migrate_vma_finalize(&args);
+	return ret;
 }
 
 static const struct dev_pagemap_ops nouveau_dmem_pagemap_ops = {
@@ -642,188 +557,115 @@ out_free:
 	drm->dmem = NULL;
 }
 
-static void
-nouveau_dmem_migrate_alloc_and_copy(struct vm_area_struct *vma,
-				    const unsigned long *src_pfns,
-				    unsigned long *dst_pfns,
-				    unsigned long start,
-				    unsigned long end,
-				    void *private)
+static unsigned long nouveau_dmem_migrate_copy_one(struct nouveau_drm *drm,
+		unsigned long src, dma_addr_t *dma_addr)
 {
-	struct nouveau_migrate *migrate = private;
-	struct nouveau_drm *drm = migrate->drm;
 	struct device *dev = drm->dev->dev;
-	unsigned long addr, i, npages = 0;
-	nouveau_migrate_copy_t copy;
-	int ret;
-
-	/* First allocate new memory */
-	for (addr = start, i = 0; addr < end; addr += PAGE_SIZE, i++) {
-		struct page *dpage, *spage;
-
-		dst_pfns[i] = 0;
-		spage = migrate_pfn_to_page(src_pfns[i]);
-		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE))
-			continue;
-
-		dpage = nouveau_dmem_page_alloc_locked(drm);
-		if (!dpage)
-			continue;
-
-		dst_pfns[i] = migrate_pfn(page_to_pfn(dpage)) |
-			      MIGRATE_PFN_LOCKED |
-			      MIGRATE_PFN_DEVICE;
-		npages++;
-	}
-
-	if (!npages)
-		return;
-
-	/* Allocate storage for DMA addresses, so we can unmap later. */
-	migrate->dma = kmalloc(sizeof(*migrate->dma) * npages, GFP_KERNEL);
-	if (!migrate->dma)
-		goto error;
-
-	/* Copy things over */
-	copy = drm->dmem->migrate.copy_func;
-	for (addr = start, i = 0; addr < end; addr += PAGE_SIZE, i++) {
-		struct nouveau_dmem_chunk *chunk;
-		struct page *spage, *dpage;
-		u64 src_addr, dst_addr;
-
-		dpage = migrate_pfn_to_page(dst_pfns[i]);
-		if (!dpage || dst_pfns[i] == MIGRATE_PFN_ERROR)
-			continue;
-
-		chunk = dpage->zone_device_data;
-		dst_addr = page_to_pfn(dpage) - chunk->pfn_first;
-		dst_addr = (dst_addr << PAGE_SHIFT) + chunk->bo->bo.offset;
-
-		spage = migrate_pfn_to_page(src_pfns[i]);
-		if (!spage || !(src_pfns[i] & MIGRATE_PFN_MIGRATE)) {
-			nouveau_dmem_page_free_locked(drm, dpage);
-			dst_pfns[i] = 0;
-			continue;
-		}
-
-		migrate->dma[migrate->dma_nr] =
-			dma_map_page_attrs(dev, spage, 0, PAGE_SIZE,
-					   PCI_DMA_BIDIRECTIONAL,
-					   DMA_ATTR_SKIP_CPU_SYNC);
-		if (dma_mapping_error(dev, migrate->dma[migrate->dma_nr])) {
-			nouveau_dmem_page_free_locked(drm, dpage);
-			dst_pfns[i] = 0;
-			continue;
-		}
-
-		src_addr = migrate->dma[migrate->dma_nr++];
+	struct page *dpage, *spage;
 
-		ret = copy(drm, 1, NOUVEAU_APER_VRAM, dst_addr,
-				   NOUVEAU_APER_HOST, src_addr);
-		if (ret) {
-			nouveau_dmem_page_free_locked(drm, dpage);
-			dst_pfns[i] = 0;
-			continue;
-		}
-	}
-
-	nouveau_fence_new(drm->dmem->migrate.chan, false, &migrate->fence);
+	spage = migrate_pfn_to_page(src);
+	if (!spage || !(src & MIGRATE_PFN_MIGRATE))
+		goto out;
 
-	return;
+	dpage = nouveau_dmem_page_alloc_locked(drm);
+	if (!dpage)
+		return 0;
 
-error:
-	for (addr = start, i = 0; addr < end; addr += PAGE_SIZE, ++i) {
-		struct page *page;
+	*dma_addr = dma_map_page(dev, spage, 0, PAGE_SIZE, DMA_BIDIRECTIONAL);
+	if (dma_mapping_error(dev, *dma_addr))
+		goto out_free_page;
 
-		if (!dst_pfns[i] || dst_pfns[i] == MIGRATE_PFN_ERROR)
-			continue;
+	if (drm->dmem->migrate.copy_func(drm, 1, NOUVEAU_APER_VRAM,
+			nouveau_dmem_page_addr(dpage), NOUVEAU_APER_HOST,
+			*dma_addr))
+		goto out_dma_unmap;
 
-		page = migrate_pfn_to_page(dst_pfns[i]);
-		dst_pfns[i] = MIGRATE_PFN_ERROR;
-		if (page == NULL)
-			continue;
+	return migrate_pfn(page_to_pfn(dpage)) | MIGRATE_PFN_LOCKED;
 
-		__free_page(page);
-	}
+out_dma_unmap:
+	dma_unmap_page(dev, *dma_addr, PAGE_SIZE, DMA_BIDIRECTIONAL);
+out_free_page:
+	nouveau_dmem_page_free_locked(drm, dpage);
+out:
+	return 0;
 }
 
-void nouveau_dmem_migrate_finalize_and_map(struct vm_area_struct *vma,
-					   const unsigned long *src_pfns,
-					   const unsigned long *dst_pfns,
-					   unsigned long start,
-					   unsigned long end,
-					   void *private)
+static void nouveau_dmem_migrate_chunk(struct nouveau_drm *drm,
+		struct migrate_vma *args, dma_addr_t *dma_addrs)
 {
-	struct nouveau_migrate *migrate = private;
-	struct nouveau_drm *drm = migrate->drm;
-
-	if (migrate->fence) {
-		nouveau_fence_wait(migrate->fence, true, false);
-		nouveau_fence_unref(&migrate->fence);
-	} else {
-		/*
-		 * FIXME wait for channel to be IDLE before finalizing
-		 * the hmem object below (nouveau_migrate_hmem_fini()) ?
-		 */
+	struct nouveau_fence *fence;
+	unsigned long addr = args->start, nr_dma = 0, i;
+
+	for (i = 0; addr < args->end; i++) {
+		args->dst[i] = nouveau_dmem_migrate_copy_one(drm, args->src[i],
+				dma_addrs + nr_dma);
+		if (args->dst[i])
+			nr_dma++;
+		addr += PAGE_SIZE;
 	}
 
-	while (migrate->dma_nr--) {
-		dma_unmap_page(drm->dev->dev, migrate->dma[migrate->dma_nr],
-			       PAGE_SIZE, PCI_DMA_BIDIRECTIONAL);
-	}
-	kfree(migrate->dma);
+	nouveau_fence_new(drm->dmem->migrate.chan, false, &fence);
+	migrate_vma_pages(args);
+	nouveau_dmem_fence_done(&fence);
 
+	while (nr_dma--) {
+		dma_unmap_page(drm->dev->dev, dma_addrs[nr_dma], PAGE_SIZE,
+				DMA_BIDIRECTIONAL);
+	}
 	/*
-	 * FIXME optimization: update GPU page table to point to newly
-	 * migrated memory.
+	 * FIXME optimization: update GPU page table to point to newly migrated
+	 * memory.
 	 */
+	migrate_vma_finalize(args);
 }
 
-static const struct migrate_vma_ops nouveau_dmem_migrate_ops = {
-	.alloc_and_copy		= nouveau_dmem_migrate_alloc_and_copy,
-	.finalize_and_map	= nouveau_dmem_migrate_finalize_and_map,
-};
-
 int
 nouveau_dmem_migrate_vma(struct nouveau_drm *drm,
 			 struct vm_area_struct *vma,
 			 unsigned long start,
 			 unsigned long end)
 {
-	unsigned long *src_pfns, *dst_pfns, npages;
-	struct nouveau_migrate migrate = {0};
-	unsigned long i, c, max;
-	int ret = 0;
-
-	npages = (end - start) >> PAGE_SHIFT;
-	max = min(SG_MAX_SINGLE_ALLOC, npages);
-	src_pfns = kzalloc(sizeof(long) * max, GFP_KERNEL);
-	if (src_pfns == NULL)
-		return -ENOMEM;
-	dst_pfns = kzalloc(sizeof(long) * max, GFP_KERNEL);
-	if (dst_pfns == NULL) {
-		kfree(src_pfns);
-		return -ENOMEM;
-	}
+	unsigned long npages = (end - start) >> PAGE_SHIFT;
+	unsigned long max = min(SG_MAX_SINGLE_ALLOC, npages);
+	dma_addr_t *dma_addrs;
+	struct migrate_vma args = {
+		.vma		= vma,
+		.start		= start,
+	};
+	unsigned long c, i;
+	int ret = -ENOMEM;
+
+	args.src = kcalloc(max, sizeof(args.src), GFP_KERNEL);
+	if (!args.src)
+		goto out;
+	args.dst = kcalloc(max, sizeof(args.dst), GFP_KERNEL);
+	if (!args.dst)
+		goto out_free_src;
 
-	migrate.drm = drm;
-	migrate.vma = vma;
-	migrate.npages = npages;
-	for (i = 0; i < npages; i += c) {
-		unsigned long next;
+	dma_addrs = kmalloc_array(max, sizeof(*dma_addrs), GFP_KERNEL);
+	if (!dma_addrs)
+		goto out_free_dst;
 
+	for (i = 0; i < npages; i += c) {
 		c = min(SG_MAX_SINGLE_ALLOC, npages);
-		next = start + (c << PAGE_SHIFT);
-		ret = migrate_vma(&nouveau_dmem_migrate_ops, vma, start,
-				  next, src_pfns, dst_pfns, &migrate);
+		args.end = start + (c << PAGE_SHIFT);
+		ret = migrate_vma_setup(&args);
 		if (ret)
-			goto out;
-		start = next;
+			goto out_free_dma;
+
+		if (args.cpages)
+			nouveau_dmem_migrate_chunk(drm, &args, dma_addrs);
+		args.start = args.end;
 	}
 
+	ret = 0;
+out_free_dma:
+	kfree(dma_addrs);
+out_free_dst:
+	kfree(args.dst);
+out_free_src:
+	kfree(args.src);
 out:
-	kfree(dst_pfns);
-	kfree(src_pfns);
 	return ret;
 }
 
@@ -841,11 +683,10 @@ nouveau_dmem_convert_pfn(struct nouveau_drm *drm,
 
 	npages = (range->end - range->start) >> PAGE_SHIFT;
 	for (i = 0; i < npages; ++i) {
-		struct nouveau_dmem_chunk *chunk;
 		struct page *page;
 		uint64_t addr;
 
-		page = hmm_pfn_to_page(range, range->pfns[i]);
+		page = hmm_device_entry_to_page(range, range->pfns[i]);
 		if (page == NULL)
 			continue;
 
@@ -859,10 +700,7 @@ nouveau_dmem_convert_pfn(struct nouveau_drm *drm,
 			continue;
 		}
 
-		chunk = page->zone_device_data;
-		addr = page_to_pfn(page) - chunk->pfn_first;
-		addr = (addr + chunk->bo->bo.mem.start) << PAGE_SHIFT;
-
+		addr = nouveau_dmem_page_addr(page);
 		range->pfns[i] &= ((1UL << range->pfn_shift) - 1);
 		range->pfns[i] |= (addr >> PAGE_SHIFT) << range->pfn_shift;
 	}
diff --git a/drivers/gpu/drm/nouveau/nouveau_dmem.h b/drivers/gpu/drm/nouveau/nouveau_dmem.h
index 9d97d756fb7d..92394be5d649 100644
--- a/drivers/gpu/drm/nouveau/nouveau_dmem.h
+++ b/drivers/gpu/drm/nouveau/nouveau_dmem.h
@@ -45,16 +45,5 @@ static inline void nouveau_dmem_init(struct nouveau_drm *drm) {}
 static inline void nouveau_dmem_fini(struct nouveau_drm *drm) {}
 static inline void nouveau_dmem_suspend(struct nouveau_drm *drm) {}
 static inline void nouveau_dmem_resume(struct nouveau_drm *drm) {}
-
-static inline int nouveau_dmem_migrate_vma(struct nouveau_drm *drm,
-					   struct vm_area_struct *vma,
-					   unsigned long start,
-					   unsigned long end)
-{
-	return 0;
-}
-
-static inline void nouveau_dmem_convert_pfn(struct nouveau_drm *drm,
-					    struct hmm_range *range) {}
 #endif /* IS_ENABLED(CONFIG_DRM_NOUVEAU_SVM) */
 #endif
diff --git a/drivers/gpu/drm/nouveau/nouveau_drm.c b/drivers/gpu/drm/nouveau/nouveau_drm.c
index bdc948352467..2cd83849600f 100644
--- a/drivers/gpu/drm/nouveau/nouveau_drm.c
+++ b/drivers/gpu/drm/nouveau/nouveau_drm.c
@@ -28,6 +28,7 @@
 #include <linux/pci.h>
 #include <linux/pm_runtime.h>
 #include <linux/vga_switcheroo.h>
+#include <linux/mmu_notifier.h>
 
 #include <drm/drm_crtc_helper.h>
 #include <drm/drm_ioctl.h>
@@ -1290,6 +1291,8 @@ nouveau_drm_exit(void)
 #ifdef CONFIG_NOUVEAU_PLATFORM_DRIVER
 	platform_driver_unregister(&nouveau_platform_driver);
 #endif
+	if (IS_ENABLED(CONFIG_DRM_NOUVEAU_SVM))
+		mmu_notifier_synchronize();
 }
 
 module_init(nouveau_drm_init);
diff --git a/drivers/gpu/drm/nouveau/nouveau_svm.c b/drivers/gpu/drm/nouveau/nouveau_svm.c
index a835cebb6d90..668d4bd0c118 100644
--- a/drivers/gpu/drm/nouveau/nouveau_svm.c
+++ b/drivers/gpu/drm/nouveau/nouveau_svm.c
@@ -252,13 +252,13 @@ nouveau_svmm_invalidate(struct nouveau_svmm *svmm, u64 start, u64 limit)
 
 static int
 nouveau_svmm_sync_cpu_device_pagetables(struct hmm_mirror *mirror,
-					const struct hmm_update *update)
+					const struct mmu_notifier_range *update)
 {
 	struct nouveau_svmm *svmm = container_of(mirror, typeof(*svmm), mirror);
 	unsigned long start = update->start;
 	unsigned long limit = update->end;
 
-	if (!update->blockable)
+	if (!mmu_notifier_range_blockable(update))
 		return -EAGAIN;
 
 	SVMM_DBG(svmm, "invalidate %016lx-%016lx", start, limit);
@@ -485,31 +485,29 @@ nouveau_range_done(struct hmm_range *range)
 }
 
 static int
-nouveau_range_fault(struct hmm_mirror *mirror, struct hmm_range *range)
+nouveau_range_fault(struct nouveau_svmm *svmm, struct hmm_range *range)
 {
 	long ret;
 
 	range->default_flags = 0;
 	range->pfn_flags_mask = -1UL;
 
-	ret = hmm_range_register(range, mirror,
-				 range->start, range->end,
-				 PAGE_SHIFT);
+	ret = hmm_range_register(range, &svmm->mirror);
 	if (ret) {
-		up_read(&range->vma->vm_mm->mmap_sem);
+		up_read(&svmm->mm->mmap_sem);
 		return (int)ret;
 	}
 
 	if (!hmm_range_wait_until_valid(range, HMM_RANGE_DEFAULT_TIMEOUT)) {
-		up_read(&range->vma->vm_mm->mmap_sem);
-		return -EAGAIN;
+		up_read(&svmm->mm->mmap_sem);
+		return -EBUSY;
 	}
 
-	ret = hmm_range_fault(range, true);
+	ret = hmm_range_fault(range, 0);
 	if (ret <= 0) {
 		if (ret == 0)
 			ret = -EBUSY;
-		up_read(&range->vma->vm_mm->mmap_sem);
+		up_read(&svmm->mm->mmap_sem);
 		hmm_range_unregister(range);
 		return ret;
 	}
@@ -682,7 +680,6 @@ nouveau_svm_fault(struct nvif_notify *notify)
 			 args.i.p.addr + args.i.p.size, fn - fi);
 
 		/* Have HMM fault pages within the fault window to the GPU. */
-		range.vma = vma;
 		range.start = args.i.p.addr;
 		range.end = args.i.p.addr + args.i.p.size;
 		range.pfns = args.phys;
@@ -690,7 +687,7 @@ nouveau_svm_fault(struct nvif_notify *notify)
 		range.values = nouveau_svm_pfn_values;
 		range.pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT;
 again:
-		ret = nouveau_range_fault(&svmm->mirror, &range);
+		ret = nouveau_range_fault(svmm, &range);
 		if (ret == 0) {
 			mutex_lock(&svmm->mutex);
 			if (!nouveau_range_done(&range)) {
diff --git a/drivers/gpu/drm/radeon/radeon.h b/drivers/gpu/drm/radeon/radeon.h
index 05b88491ccb9..d59b004f6695 100644
--- a/drivers/gpu/drm/radeon/radeon.h
+++ b/drivers/gpu/drm/radeon/radeon.h
@@ -2449,9 +2449,6 @@ struct radeon_device {
 	/* tracking pinned memory */
 	u64 vram_pin_size;
 	u64 gart_pin_size;
-
-	struct mutex	mn_lock;
-	DECLARE_HASHTABLE(mn_hash, 7);
 };
 
 bool radeon_is_px(struct drm_device *dev);
diff --git a/drivers/gpu/drm/radeon/radeon_device.c b/drivers/gpu/drm/radeon/radeon_device.c
index 88eb7cb522bb..5d017f0aec66 100644
--- a/drivers/gpu/drm/radeon/radeon_device.c
+++ b/drivers/gpu/drm/radeon/radeon_device.c
@@ -1325,8 +1325,6 @@ int radeon_device_init(struct radeon_device *rdev,
 	init_rwsem(&rdev->pm.mclk_lock);
 	init_rwsem(&rdev->exclusive_lock);
 	init_waitqueue_head(&rdev->irq.vblank_queue);
-	mutex_init(&rdev->mn_lock);
-	hash_init(rdev->mn_hash);
 	r = radeon_gem_init(rdev);
 	if (r)
 		return r;
diff --git a/drivers/gpu/drm/radeon/radeon_drv.c b/drivers/gpu/drm/radeon/radeon_drv.c
index 5838162f687f..431e6b64b77d 100644
--- a/drivers/gpu/drm/radeon/radeon_drv.c
+++ b/drivers/gpu/drm/radeon/radeon_drv.c
@@ -35,6 +35,7 @@
 #include <linux/module.h>
 #include <linux/pm_runtime.h>
 #include <linux/vga_switcheroo.h>
+#include <linux/mmu_notifier.h>
 
 #include <drm/drm_crtc_helper.h>
 #include <drm/drm_drv.h>
@@ -623,6 +624,7 @@ static void __exit radeon_exit(void)
 {
 	pci_unregister_driver(pdriver);
 	radeon_unregister_atpx_handler();
+	mmu_notifier_synchronize();
 }
 
 module_init(radeon_init);
diff --git a/drivers/gpu/drm/radeon/radeon_mn.c b/drivers/gpu/drm/radeon/radeon_mn.c
index 6902f998ede9..dbab9a3a969b 100644
--- a/drivers/gpu/drm/radeon/radeon_mn.c
+++ b/drivers/gpu/drm/radeon/radeon_mn.c
@@ -37,17 +37,8 @@
 #include "radeon.h"
 
 struct radeon_mn {
-	/* constant after initialisation */
-	struct radeon_device	*rdev;
-	struct mm_struct	*mm;
 	struct mmu_notifier	mn;
 
-	/* only used on destruction */
-	struct work_struct	work;
-
-	/* protected by rdev->mn_lock */
-	struct hlist_node	node;
-
 	/* objects protected by lock */
 	struct mutex		lock;
 	struct rb_root_cached	objects;
@@ -59,55 +50,6 @@ struct radeon_mn_node {
 };
 
 /**
- * radeon_mn_destroy - destroy the rmn
- *
- * @work: previously sheduled work item
- *
- * Lazy destroys the notifier from a work item
- */
-static void radeon_mn_destroy(struct work_struct *work)
-{
-	struct radeon_mn *rmn = container_of(work, struct radeon_mn, work);
-	struct radeon_device *rdev = rmn->rdev;
-	struct radeon_mn_node *node, *next_node;
-	struct radeon_bo *bo, *next_bo;
-
-	mutex_lock(&rdev->mn_lock);
-	mutex_lock(&rmn->lock);
-	hash_del(&rmn->node);
-	rbtree_postorder_for_each_entry_safe(node, next_node,
-					     &rmn->objects.rb_root, it.rb) {
-
-		interval_tree_remove(&node->it, &rmn->objects);
-		list_for_each_entry_safe(bo, next_bo, &node->bos, mn_list) {
-			bo->mn = NULL;
-			list_del_init(&bo->mn_list);
-		}
-		kfree(node);
-	}
-	mutex_unlock(&rmn->lock);
-	mutex_unlock(&rdev->mn_lock);
-	mmu_notifier_unregister(&rmn->mn, rmn->mm);
-	kfree(rmn);
-}
-
-/**
- * radeon_mn_release - callback to notify about mm destruction
- *
- * @mn: our notifier
- * @mn: the mm this callback is about
- *
- * Shedule a work item to lazy destroy our notifier.
- */
-static void radeon_mn_release(struct mmu_notifier *mn,
-			      struct mm_struct *mm)
-{
-	struct radeon_mn *rmn = container_of(mn, struct radeon_mn, mn);
-	INIT_WORK(&rmn->work, radeon_mn_destroy);
-	schedule_work(&rmn->work);
-}
-
-/**
  * radeon_mn_invalidate_range_start - callback to notify about mm change
  *
  * @mn: our notifier
@@ -183,65 +125,44 @@ out_unlock:
 	return ret;
 }
 
-static const struct mmu_notifier_ops radeon_mn_ops = {
-	.release = radeon_mn_release,
-	.invalidate_range_start = radeon_mn_invalidate_range_start,
-};
+static void radeon_mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
+{
+	struct mmu_notifier_range range = {
+		.mm = mm,
+		.start = 0,
+		.end = ULONG_MAX,
+		.flags = 0,
+		.event = MMU_NOTIFY_UNMAP,
+	};
+
+	radeon_mn_invalidate_range_start(mn, &range);
+}
 
-/**
- * radeon_mn_get - create notifier context
- *
- * @rdev: radeon device pointer
- *
- * Creates a notifier context for current->mm.
- */
-static struct radeon_mn *radeon_mn_get(struct radeon_device *rdev)
+static struct mmu_notifier *radeon_mn_alloc_notifier(struct mm_struct *mm)
 {
-	struct mm_struct *mm = current->mm;
 	struct radeon_mn *rmn;
-	int r;
-
-	if (down_write_killable(&mm->mmap_sem))
-		return ERR_PTR(-EINTR);
-
-	mutex_lock(&rdev->mn_lock);
-
-	hash_for_each_possible(rdev->mn_hash, rmn, node, (unsigned long)mm)
-		if (rmn->mm == mm)
-			goto release_locks;
 
 	rmn = kzalloc(sizeof(*rmn), GFP_KERNEL);
-	if (!rmn) {
-		rmn = ERR_PTR(-ENOMEM);
-		goto release_locks;
-	}
+	if (!rmn)
+		return ERR_PTR(-ENOMEM);
 
-	rmn->rdev = rdev;
-	rmn->mm = mm;
-	rmn->mn.ops = &radeon_mn_ops;
 	mutex_init(&rmn->lock);
 	rmn->objects = RB_ROOT_CACHED;
-	
-	r = __mmu_notifier_register(&rmn->mn, mm);
-	if (r)
-		goto free_rmn;
-
-	hash_add(rdev->mn_hash, &rmn->node, (unsigned long)mm);
-
-release_locks:
-	mutex_unlock(&rdev->mn_lock);
-	up_write(&mm->mmap_sem);
-
-	return rmn;
-
-free_rmn:
-	mutex_unlock(&rdev->mn_lock);
-	up_write(&mm->mmap_sem);
-	kfree(rmn);
+	return &rmn->mn;
+}
 
-	return ERR_PTR(r);
+static void radeon_mn_free_notifier(struct mmu_notifier *mn)
+{
+	kfree(container_of(mn, struct radeon_mn, mn));
 }
 
+static const struct mmu_notifier_ops radeon_mn_ops = {
+	.release = radeon_mn_release,
+	.invalidate_range_start = radeon_mn_invalidate_range_start,
+	.alloc_notifier = radeon_mn_alloc_notifier,
+	.free_notifier = radeon_mn_free_notifier,
+};
+
 /**
  * radeon_mn_register - register a BO for notifier updates
  *
@@ -254,15 +175,16 @@ free_rmn:
 int radeon_mn_register(struct radeon_bo *bo, unsigned long addr)
 {
 	unsigned long end = addr + radeon_bo_size(bo) - 1;
-	struct radeon_device *rdev = bo->rdev;
+	struct mmu_notifier *mn;
 	struct radeon_mn *rmn;
 	struct radeon_mn_node *node = NULL;
 	struct list_head bos;
 	struct interval_tree_node *it;
 
-	rmn = radeon_mn_get(rdev);
-	if (IS_ERR(rmn))
-		return PTR_ERR(rmn);
+	mn = mmu_notifier_get(&radeon_mn_ops, current->mm);
+	if (IS_ERR(mn))
+		return PTR_ERR(mn);
+	rmn = container_of(mn, struct radeon_mn, mn);
 
 	INIT_LIST_HEAD(&bos);
 
@@ -309,22 +231,16 @@ int radeon_mn_register(struct radeon_bo *bo, unsigned long addr)
  */
 void radeon_mn_unregister(struct radeon_bo *bo)
 {
-	struct radeon_device *rdev = bo->rdev;
-	struct radeon_mn *rmn;
+	struct radeon_mn *rmn = bo->mn;
 	struct list_head *head;
 
-	mutex_lock(&rdev->mn_lock);
-	rmn = bo->mn;
-	if (rmn == NULL) {
-		mutex_unlock(&rdev->mn_lock);
+	if (!rmn)
 		return;
-	}
 
 	mutex_lock(&rmn->lock);
 	/* save the next list entry for later */
 	head = bo->mn_list.next;
 
-	bo->mn = NULL;
 	list_del(&bo->mn_list);
 
 	if (list_empty(head)) {
@@ -335,5 +251,7 @@ void radeon_mn_unregister(struct radeon_bo *bo)
 	}
 
 	mutex_unlock(&rmn->lock);
-	mutex_unlock(&rdev->mn_lock);
+
+	mmu_notifier_put(&rmn->mn);
+	bo->mn = NULL;
 }
diff --git a/drivers/infiniband/Kconfig b/drivers/infiniband/Kconfig
index 85e103b147cc..b44b1c322ec8 100644
--- a/drivers/infiniband/Kconfig
+++ b/drivers/infiniband/Kconfig
@@ -55,6 +55,7 @@ config INFINIBAND_ON_DEMAND_PAGING
 	bool "InfiniBand on-demand paging support"
 	depends on INFINIBAND_USER_MEM
 	select MMU_NOTIFIER
+	select INTERVAL_TREE
 	default y
 	---help---
 	  On demand paging support for the InfiniBand subsystem.
diff --git a/drivers/infiniband/core/device.c b/drivers/infiniband/core/device.c
index ea8661a00651..b5631b8a0397 100644
--- a/drivers/infiniband/core/device.c
+++ b/drivers/infiniband/core/device.c
@@ -2562,6 +2562,7 @@ void ib_set_device_ops(struct ib_device *dev, const struct ib_device_ops *ops)
 	SET_DEVICE_OP(dev_ops, get_vf_config);
 	SET_DEVICE_OP(dev_ops, get_vf_stats);
 	SET_DEVICE_OP(dev_ops, init_port);
+	SET_DEVICE_OP(dev_ops, invalidate_range);
 	SET_DEVICE_OP(dev_ops, iw_accept);
 	SET_DEVICE_OP(dev_ops, iw_add_ref);
 	SET_DEVICE_OP(dev_ops, iw_connect);
diff --git a/drivers/infiniband/core/umem.c b/drivers/infiniband/core/umem.c
index 56553668256f..41f9e268e3fb 100644
--- a/drivers/infiniband/core/umem.c
+++ b/drivers/infiniband/core/umem.c
@@ -184,9 +184,6 @@ EXPORT_SYMBOL(ib_umem_find_best_pgsz);
 /**
  * ib_umem_get - Pin and DMA map userspace memory.
  *
- * If access flags indicate ODP memory, avoid pinning. Instead, stores
- * the mm for future page fault handling in conjunction with MMU notifiers.
- *
  * @udata: userspace context to pin memory for
  * @addr: userspace virtual address to start at
  * @size: length of region to pin
@@ -231,36 +228,19 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
 	if (!can_do_mlock())
 		return ERR_PTR(-EPERM);
 
-	if (access & IB_ACCESS_ON_DEMAND) {
-		umem = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
-		if (!umem)
-			return ERR_PTR(-ENOMEM);
-		umem->is_odp = 1;
-	} else {
-		umem = kzalloc(sizeof(*umem), GFP_KERNEL);
-		if (!umem)
-			return ERR_PTR(-ENOMEM);
-	}
+	if (access & IB_ACCESS_ON_DEMAND)
+		return ERR_PTR(-EOPNOTSUPP);
 
-	umem->context    = context;
+	umem = kzalloc(sizeof(*umem), GFP_KERNEL);
+	if (!umem)
+		return ERR_PTR(-ENOMEM);
+	umem->ibdev = context->device;
 	umem->length     = size;
 	umem->address    = addr;
 	umem->writable   = ib_access_writable(access);
 	umem->owning_mm = mm = current->mm;
 	mmgrab(mm);
 
-	if (access & IB_ACCESS_ON_DEMAND) {
-		if (WARN_ON_ONCE(!context->invalidate_range)) {
-			ret = -EINVAL;
-			goto umem_kfree;
-		}
-
-		ret = ib_umem_odp_get(to_ib_umem_odp(umem), access);
-		if (ret)
-			goto umem_kfree;
-		return umem;
-	}
-
 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
 	if (!page_list) {
 		ret = -ENOMEM;
@@ -346,15 +326,6 @@ umem_kfree:
 }
 EXPORT_SYMBOL(ib_umem_get);
 
-static void __ib_umem_release_tail(struct ib_umem *umem)
-{
-	mmdrop(umem->owning_mm);
-	if (umem->is_odp)
-		kfree(to_ib_umem_odp(umem));
-	else
-		kfree(umem);
-}
-
 /**
  * ib_umem_release - release memory pinned with ib_umem_get
  * @umem: umem struct to release
@@ -363,17 +334,14 @@ void ib_umem_release(struct ib_umem *umem)
 {
 	if (!umem)
 		return;
+	if (umem->is_odp)
+		return ib_umem_odp_release(to_ib_umem_odp(umem));
 
-	if (umem->is_odp) {
-		ib_umem_odp_release(to_ib_umem_odp(umem));
-		__ib_umem_release_tail(umem);
-		return;
-	}
-
-	__ib_umem_release(umem->context->device, umem, 1);
+	__ib_umem_release(umem->ibdev, umem, 1);
 
 	atomic64_sub(ib_umem_num_pages(umem), &umem->owning_mm->pinned_vm);
-	__ib_umem_release_tail(umem);
+	mmdrop(umem->owning_mm);
+	kfree(umem);
 }
 EXPORT_SYMBOL(ib_umem_release);
 
diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c
index c0e15db34680..9aebe9ce8b07 100644
--- a/drivers/infiniband/core/umem_odp.c
+++ b/drivers/infiniband/core/umem_odp.c
@@ -39,44 +39,14 @@
 #include <linux/export.h>
 #include <linux/vmalloc.h>
 #include <linux/hugetlb.h>
-#include <linux/interval_tree_generic.h>
+#include <linux/interval_tree.h>
 #include <linux/pagemap.h>
 
 #include <rdma/ib_verbs.h>
 #include <rdma/ib_umem.h>
 #include <rdma/ib_umem_odp.h>
 
-/*
- * The ib_umem list keeps track of memory regions for which the HW
- * device request to receive notification when the related memory
- * mapping is changed.
- *
- * ib_umem_lock protects the list.
- */
-
-static u64 node_start(struct umem_odp_node *n)
-{
-	struct ib_umem_odp *umem_odp =
-			container_of(n, struct ib_umem_odp, interval_tree);
-
-	return ib_umem_start(umem_odp);
-}
-
-/* Note that the representation of the intervals in the interval tree
- * considers the ending point as contained in the interval, while the
- * function ib_umem_end returns the first address which is not contained
- * in the umem.
- */
-static u64 node_last(struct umem_odp_node *n)
-{
-	struct ib_umem_odp *umem_odp =
-			container_of(n, struct ib_umem_odp, interval_tree);
-
-	return ib_umem_end(umem_odp) - 1;
-}
-
-INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
-		     node_start, node_last, static, rbt_ib_umem)
+#include "uverbs.h"
 
 static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
 {
@@ -104,31 +74,34 @@ static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
 	mutex_unlock(&umem_odp->umem_mutex);
 }
 
-static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
-					       u64 start, u64 end, void *cookie)
-{
-	/*
-	 * Increase the number of notifiers running, to
-	 * prevent any further fault handling on this MR.
-	 */
-	ib_umem_notifier_start_account(umem_odp);
-	complete_all(&umem_odp->notifier_completion);
-	umem_odp->umem.context->invalidate_range(
-		umem_odp, ib_umem_start(umem_odp), ib_umem_end(umem_odp));
-	return 0;
-}
-
 static void ib_umem_notifier_release(struct mmu_notifier *mn,
 				     struct mm_struct *mm)
 {
 	struct ib_ucontext_per_mm *per_mm =
 		container_of(mn, struct ib_ucontext_per_mm, mn);
+	struct rb_node *node;
 
 	down_read(&per_mm->umem_rwsem);
-	if (per_mm->active)
-		rbt_ib_umem_for_each_in_range(
-			&per_mm->umem_tree, 0, ULLONG_MAX,
-			ib_umem_notifier_release_trampoline, true, NULL);
+	if (!per_mm->mn.users)
+		goto out;
+
+	for (node = rb_first_cached(&per_mm->umem_tree); node;
+	     node = rb_next(node)) {
+		struct ib_umem_odp *umem_odp =
+			rb_entry(node, struct ib_umem_odp, interval_tree.rb);
+
+		/*
+		 * Increase the number of notifiers running, to prevent any
+		 * further fault handling on this MR.
+		 */
+		ib_umem_notifier_start_account(umem_odp);
+		complete_all(&umem_odp->notifier_completion);
+		umem_odp->umem.ibdev->ops.invalidate_range(
+			umem_odp, ib_umem_start(umem_odp),
+			ib_umem_end(umem_odp));
+	}
+
+out:
 	up_read(&per_mm->umem_rwsem);
 }
 
@@ -136,7 +109,7 @@ static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
 					     u64 start, u64 end, void *cookie)
 {
 	ib_umem_notifier_start_account(item);
-	item->umem.context->invalidate_range(item, start, end);
+	item->umem.ibdev->ops.invalidate_range(item, start, end);
 	return 0;
 }
 
@@ -152,10 +125,10 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
 	else if (!down_read_trylock(&per_mm->umem_rwsem))
 		return -EAGAIN;
 
-	if (!per_mm->active) {
+	if (!per_mm->mn.users) {
 		up_read(&per_mm->umem_rwsem);
 		/*
-		 * At this point active is permanently set and visible to this
+		 * At this point users is permanently zero and visible to this
 		 * CPU without a lock, that fact is relied on to skip the unlock
 		 * in range_end.
 		 */
@@ -185,7 +158,7 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
 	struct ib_ucontext_per_mm *per_mm =
 		container_of(mn, struct ib_ucontext_per_mm, mn);
 
-	if (unlikely(!per_mm->active))
+	if (unlikely(!per_mm->mn.users))
 		return;
 
 	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
@@ -194,212 +167,250 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
 	up_read(&per_mm->umem_rwsem);
 }
 
-static const struct mmu_notifier_ops ib_umem_notifiers = {
-	.release                    = ib_umem_notifier_release,
-	.invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
-	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
-};
-
-static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
-{
-	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-	down_write(&per_mm->umem_rwsem);
-	if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
-		rbt_ib_umem_insert(&umem_odp->interval_tree,
-				   &per_mm->umem_tree);
-	up_write(&per_mm->umem_rwsem);
-}
-
-static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
-{
-	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-	down_write(&per_mm->umem_rwsem);
-	if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
-		rbt_ib_umem_remove(&umem_odp->interval_tree,
-				   &per_mm->umem_tree);
-	complete_all(&umem_odp->notifier_completion);
-
-	up_write(&per_mm->umem_rwsem);
-}
-
-static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
-					       struct mm_struct *mm)
+static struct mmu_notifier *ib_umem_alloc_notifier(struct mm_struct *mm)
 {
 	struct ib_ucontext_per_mm *per_mm;
-	int ret;
 
 	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
 	if (!per_mm)
 		return ERR_PTR(-ENOMEM);
 
-	per_mm->context = ctx;
-	per_mm->mm = mm;
 	per_mm->umem_tree = RB_ROOT_CACHED;
 	init_rwsem(&per_mm->umem_rwsem);
-	per_mm->active = true;
 
+	WARN_ON(mm != current->mm);
 	rcu_read_lock();
 	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
 	rcu_read_unlock();
+	return &per_mm->mn;
+}
 
-	WARN_ON(mm != current->mm);
-
-	per_mm->mn.ops = &ib_umem_notifiers;
-	ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
-	if (ret) {
-		dev_err(&ctx->device->dev,
-			"Failed to register mmu_notifier %d\n", ret);
-		goto out_pid;
-	}
+static void ib_umem_free_notifier(struct mmu_notifier *mn)
+{
+	struct ib_ucontext_per_mm *per_mm =
+		container_of(mn, struct ib_ucontext_per_mm, mn);
 
-	list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
-	return per_mm;
+	WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
 
-out_pid:
 	put_pid(per_mm->tgid);
 	kfree(per_mm);
-	return ERR_PTR(ret);
 }
 
-static int get_per_mm(struct ib_umem_odp *umem_odp)
+static const struct mmu_notifier_ops ib_umem_notifiers = {
+	.release                    = ib_umem_notifier_release,
+	.invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
+	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
+	.alloc_notifier		    = ib_umem_alloc_notifier,
+	.free_notifier		    = ib_umem_free_notifier,
+};
+
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp)
 {
-	struct ib_ucontext *ctx = umem_odp->umem.context;
 	struct ib_ucontext_per_mm *per_mm;
+	struct mmu_notifier *mn;
+	int ret;
 
-	/*
-	 * Generally speaking we expect only one or two per_mm in this list,
-	 * so no reason to optimize this search today.
-	 */
-	mutex_lock(&ctx->per_mm_list_lock);
-	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
-		if (per_mm->mm == umem_odp->umem.owning_mm)
-			goto found;
+	umem_odp->umem.is_odp = 1;
+	if (!umem_odp->is_implicit_odp) {
+		size_t page_size = 1UL << umem_odp->page_shift;
+		size_t pages;
+
+		umem_odp->interval_tree.start =
+			ALIGN_DOWN(umem_odp->umem.address, page_size);
+		if (check_add_overflow(umem_odp->umem.address,
+				       umem_odp->umem.length,
+				       &umem_odp->interval_tree.last))
+			return -EOVERFLOW;
+		umem_odp->interval_tree.last =
+			ALIGN(umem_odp->interval_tree.last, page_size);
+		if (unlikely(umem_odp->interval_tree.last < page_size))
+			return -EOVERFLOW;
+
+		pages = (umem_odp->interval_tree.last -
+			 umem_odp->interval_tree.start) >>
+			umem_odp->page_shift;
+		if (!pages)
+			return -EINVAL;
+
+		/*
+		 * Note that the representation of the intervals in the
+		 * interval tree considers the ending point as contained in
+		 * the interval.
+		 */
+		umem_odp->interval_tree.last--;
+
+		umem_odp->page_list = kvcalloc(
+			pages, sizeof(*umem_odp->page_list), GFP_KERNEL);
+		if (!umem_odp->page_list)
+			return -ENOMEM;
+
+		umem_odp->dma_list = kvcalloc(
+			pages, sizeof(*umem_odp->dma_list), GFP_KERNEL);
+		if (!umem_odp->dma_list) {
+			ret = -ENOMEM;
+			goto out_page_list;
+		}
 	}
 
-	per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
-	if (IS_ERR(per_mm)) {
-		mutex_unlock(&ctx->per_mm_list_lock);
-		return PTR_ERR(per_mm);
+	mn = mmu_notifier_get(&ib_umem_notifiers, umem_odp->umem.owning_mm);
+	if (IS_ERR(mn)) {
+		ret = PTR_ERR(mn);
+		goto out_dma_list;
 	}
+	umem_odp->per_mm = per_mm =
+		container_of(mn, struct ib_ucontext_per_mm, mn);
 
-found:
-	umem_odp->per_mm = per_mm;
-	per_mm->odp_mrs_count++;
-	mutex_unlock(&ctx->per_mm_list_lock);
+	mutex_init(&umem_odp->umem_mutex);
+	init_completion(&umem_odp->notifier_completion);
+
+	if (!umem_odp->is_implicit_odp) {
+		down_write(&per_mm->umem_rwsem);
+		interval_tree_insert(&umem_odp->interval_tree,
+				     &per_mm->umem_tree);
+		up_write(&per_mm->umem_rwsem);
+	}
+	mmgrab(umem_odp->umem.owning_mm);
 
 	return 0;
-}
 
-static void free_per_mm(struct rcu_head *rcu)
-{
-	kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
+out_dma_list:
+	kvfree(umem_odp->dma_list);
+out_page_list:
+	kvfree(umem_odp->page_list);
+	return ret;
 }
 
-static void put_per_mm(struct ib_umem_odp *umem_odp)
+/**
+ * ib_umem_odp_alloc_implicit - Allocate a parent implicit ODP umem
+ *
+ * Implicit ODP umems do not have a VA range and do not have any page lists.
+ * They exist only to hold the per_mm reference to help the driver create
+ * children umems.
+ *
+ * @udata: udata from the syscall being used to create the umem
+ * @access: ib_reg_mr access flags
+ */
+struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
+					       int access)
 {
-	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-	struct ib_ucontext *ctx = umem_odp->umem.context;
-	bool need_free;
-
-	mutex_lock(&ctx->per_mm_list_lock);
-	umem_odp->per_mm = NULL;
-	per_mm->odp_mrs_count--;
-	need_free = per_mm->odp_mrs_count == 0;
-	if (need_free)
-		list_del(&per_mm->ucontext_list);
-	mutex_unlock(&ctx->per_mm_list_lock);
-
-	if (!need_free)
-		return;
+	struct ib_ucontext *context =
+		container_of(udata, struct uverbs_attr_bundle, driver_udata)
+			->context;
+	struct ib_umem *umem;
+	struct ib_umem_odp *umem_odp;
+	int ret;
 
-	/*
-	 * NOTE! mmu_notifier_unregister() can happen between a start/end
-	 * callback, resulting in an start/end, and thus an unbalanced
-	 * lock. This doesn't really matter to us since we are about to kfree
-	 * the memory that holds the lock, however LOCKDEP doesn't like this.
-	 */
-	down_write(&per_mm->umem_rwsem);
-	per_mm->active = false;
-	up_write(&per_mm->umem_rwsem);
+	if (access & IB_ACCESS_HUGETLB)
+		return ERR_PTR(-EINVAL);
 
-	WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
-	mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
-	put_pid(per_mm->tgid);
-	mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
+	if (!context)
+		return ERR_PTR(-EIO);
+	if (WARN_ON_ONCE(!context->device->ops.invalidate_range))
+		return ERR_PTR(-EINVAL);
+
+	umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
+	if (!umem_odp)
+		return ERR_PTR(-ENOMEM);
+	umem = &umem_odp->umem;
+	umem->ibdev = context->device;
+	umem->writable = ib_access_writable(access);
+	umem->owning_mm = current->mm;
+	umem_odp->is_implicit_odp = 1;
+	umem_odp->page_shift = PAGE_SHIFT;
+
+	ret = ib_init_umem_odp(umem_odp);
+	if (ret) {
+		kfree(umem_odp);
+		return ERR_PTR(ret);
+	}
+	return umem_odp;
 }
+EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);
 
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
-				      unsigned long addr, size_t size)
+/**
+ * ib_umem_odp_alloc_child - Allocate a child ODP umem under an implicit
+ *                           parent ODP umem
+ *
+ * @root: The parent umem enclosing the child. This must be allocated using
+ *        ib_alloc_implicit_odp_umem()
+ * @addr: The starting userspace VA
+ * @size: The length of the userspace VA
+ */
+struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
+					    unsigned long addr, size_t size)
 {
-	struct ib_ucontext_per_mm *per_mm = root->per_mm;
-	struct ib_ucontext *ctx = per_mm->context;
+	/*
+	 * Caller must ensure that root cannot be freed during the call to
+	 * ib_alloc_odp_umem.
+	 */
 	struct ib_umem_odp *odp_data;
 	struct ib_umem *umem;
-	int pages = size >> PAGE_SHIFT;
 	int ret;
 
+	if (WARN_ON(!root->is_implicit_odp))
+		return ERR_PTR(-EINVAL);
+
 	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
 	if (!odp_data)
 		return ERR_PTR(-ENOMEM);
 	umem = &odp_data->umem;
-	umem->context    = ctx;
+	umem->ibdev = root->umem.ibdev;
 	umem->length     = size;
 	umem->address    = addr;
-	odp_data->page_shift = PAGE_SHIFT;
 	umem->writable   = root->umem.writable;
-	umem->is_odp = 1;
-	odp_data->per_mm = per_mm;
-	umem->owning_mm  = per_mm->mm;
-	mmgrab(umem->owning_mm);
-
-	mutex_init(&odp_data->umem_mutex);
-	init_completion(&odp_data->notifier_completion);
-
-	odp_data->page_list =
-		vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
-	if (!odp_data->page_list) {
-		ret = -ENOMEM;
-		goto out_odp_data;
-	}
+	umem->owning_mm  = root->umem.owning_mm;
+	odp_data->page_shift = PAGE_SHIFT;
 
-	odp_data->dma_list =
-		vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
-	if (!odp_data->dma_list) {
-		ret = -ENOMEM;
-		goto out_page_list;
+	ret = ib_init_umem_odp(odp_data);
+	if (ret) {
+		kfree(odp_data);
+		return ERR_PTR(ret);
 	}
-
-	/*
-	 * Caller must ensure that the umem_odp that the per_mm came from
-	 * cannot be freed during the call to ib_alloc_odp_umem.
-	 */
-	mutex_lock(&ctx->per_mm_list_lock);
-	per_mm->odp_mrs_count++;
-	mutex_unlock(&ctx->per_mm_list_lock);
-	add_umem_to_per_mm(odp_data);
-
 	return odp_data;
-
-out_page_list:
-	vfree(odp_data->page_list);
-out_odp_data:
-	mmdrop(umem->owning_mm);
-	kfree(odp_data);
-	return ERR_PTR(ret);
 }
-EXPORT_SYMBOL(ib_alloc_odp_umem);
+EXPORT_SYMBOL(ib_umem_odp_alloc_child);
 
-int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
+/**
+ * ib_umem_odp_get - Create a umem_odp for a userspace va
+ *
+ * @udata: userspace context to pin memory for
+ * @addr: userspace virtual address to start at
+ * @size: length of region to pin
+ * @access: IB_ACCESS_xxx flags for memory being pinned
+ *
+ * The driver should use when the access flags indicate ODP memory. It avoids
+ * pinning, instead, stores the mm for future page fault handling in
+ * conjunction with MMU notifiers.
+ */
+struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
+				    size_t size, int access)
 {
-	struct ib_umem *umem = &umem_odp->umem;
-	/*
-	 * NOTE: This must called in a process context where umem->owning_mm
-	 * == current->mm
-	 */
-	struct mm_struct *mm = umem->owning_mm;
-	int ret_val;
+	struct ib_umem_odp *umem_odp;
+	struct ib_ucontext *context;
+	struct mm_struct *mm;
+	int ret;
+
+	if (!udata)
+		return ERR_PTR(-EIO);
+
+	context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
+			  ->context;
+	if (!context)
+		return ERR_PTR(-EIO);
+
+	if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)) ||
+	    WARN_ON_ONCE(!context->device->ops.invalidate_range))
+		return ERR_PTR(-EINVAL);
+
+	umem_odp = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
+	if (!umem_odp)
+		return ERR_PTR(-ENOMEM);
+
+	umem_odp->umem.ibdev = context->device;
+	umem_odp->umem.length = size;
+	umem_odp->umem.address = addr;
+	umem_odp->umem.writable = ib_access_writable(access);
+	umem_odp->umem.owning_mm = mm = current->mm;
 
 	umem_odp->page_shift = PAGE_SHIFT;
 	if (access & IB_ACCESS_HUGETLB) {
@@ -410,63 +421,63 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 		vma = find_vma(mm, ib_umem_start(umem_odp));
 		if (!vma || !is_vm_hugetlb_page(vma)) {
 			up_read(&mm->mmap_sem);
-			return -EINVAL;
+			ret = -EINVAL;
+			goto err_free;
 		}
 		h = hstate_vma(vma);
 		umem_odp->page_shift = huge_page_shift(h);
 		up_read(&mm->mmap_sem);
 	}
 
-	mutex_init(&umem_odp->umem_mutex);
-
-	init_completion(&umem_odp->notifier_completion);
-
-	if (ib_umem_odp_num_pages(umem_odp)) {
-		umem_odp->page_list =
-			vzalloc(array_size(sizeof(*umem_odp->page_list),
-					   ib_umem_odp_num_pages(umem_odp)));
-		if (!umem_odp->page_list)
-			return -ENOMEM;
-
-		umem_odp->dma_list =
-			vzalloc(array_size(sizeof(*umem_odp->dma_list),
-					   ib_umem_odp_num_pages(umem_odp)));
-		if (!umem_odp->dma_list) {
-			ret_val = -ENOMEM;
-			goto out_page_list;
-		}
-	}
-
-	ret_val = get_per_mm(umem_odp);
-	if (ret_val)
-		goto out_dma_list;
-	add_umem_to_per_mm(umem_odp);
-
-	return 0;
+	ret = ib_init_umem_odp(umem_odp);
+	if (ret)
+		goto err_free;
+	return umem_odp;
 
-out_dma_list:
-	vfree(umem_odp->dma_list);
-out_page_list:
-	vfree(umem_odp->page_list);
-	return ret_val;
+err_free:
+	kfree(umem_odp);
+	return ERR_PTR(ret);
 }
+EXPORT_SYMBOL(ib_umem_odp_get);
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
+	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+
 	/*
 	 * Ensure that no more pages are mapped in the umem.
 	 *
 	 * It is the driver's responsibility to ensure, before calling us,
 	 * that the hardware will not attempt to access the MR any more.
 	 */
-	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
-				    ib_umem_end(umem_odp));
+	if (!umem_odp->is_implicit_odp) {
+		ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
+					    ib_umem_end(umem_odp));
+		kvfree(umem_odp->dma_list);
+		kvfree(umem_odp->page_list);
+	}
 
-	remove_umem_from_per_mm(umem_odp);
-	put_per_mm(umem_odp);
-	vfree(umem_odp->dma_list);
-	vfree(umem_odp->page_list);
+	down_write(&per_mm->umem_rwsem);
+	if (!umem_odp->is_implicit_odp) {
+		interval_tree_remove(&umem_odp->interval_tree,
+				     &per_mm->umem_tree);
+		complete_all(&umem_odp->notifier_completion);
+	}
+	/*
+	 * NOTE! mmu_notifier_unregister() can happen between a start/end
+	 * callback, resulting in a missing end, and thus an unbalanced
+	 * lock. This doesn't really matter to us since we are about to kfree
+	 * the memory that holds the lock, however LOCKDEP doesn't like this.
+	 * Thus we call the mmu_notifier_put under the rwsem and test the
+	 * internal users count to reliably see if we are past this point.
+	 */
+	mmu_notifier_put(&per_mm->mn);
+	up_write(&per_mm->umem_rwsem);
+
+	mmdrop(umem_odp->umem.owning_mm);
+	kfree(umem_odp);
 }
+EXPORT_SYMBOL(ib_umem_odp_release);
 
 /*
  * Map for DMA and insert a single page into the on-demand paging page tables.
@@ -493,8 +504,7 @@ static int ib_umem_odp_map_dma_single_page(
 		u64 access_mask,
 		unsigned long current_seq)
 {
-	struct ib_ucontext *context = umem_odp->umem.context;
-	struct ib_device *dev = context->device;
+	struct ib_device *dev = umem_odp->umem.ibdev;
 	dma_addr_t dma_addr;
 	int remove_existing_mapping = 0;
 	int ret = 0;
@@ -534,7 +544,7 @@ out:
 
 	if (remove_existing_mapping) {
 		ib_umem_notifier_start_account(umem_odp);
-		context->invalidate_range(
+		dev->ops.invalidate_range(
 			umem_odp,
 			ib_umem_start(umem_odp) +
 				(page_index << umem_odp->page_shift),
@@ -707,7 +717,7 @@ void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
 {
 	int idx;
 	u64 addr;
-	struct ib_device *dev = umem_odp->umem.context->device;
+	struct ib_device *dev = umem_odp->umem.ibdev;
 
 	virt = max_t(u64, virt, ib_umem_start(umem_odp));
 	bound = min_t(u64, bound, ib_umem_end(umem_odp));
@@ -761,35 +771,21 @@ int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
 				  void *cookie)
 {
 	int ret_val = 0;
-	struct umem_odp_node *node, *next;
+	struct interval_tree_node *node, *next;
 	struct ib_umem_odp *umem;
 
 	if (unlikely(start == last))
 		return ret_val;
 
-	for (node = rbt_ib_umem_iter_first(root, start, last - 1);
+	for (node = interval_tree_iter_first(root, start, last - 1);
 			node; node = next) {
 		/* TODO move the blockable decision up to the callback */
 		if (!blockable)
 			return -EAGAIN;
-		next = rbt_ib_umem_iter_next(node, start, last - 1);
+		next = interval_tree_iter_next(node, start, last - 1);
 		umem = container_of(node, struct ib_umem_odp, interval_tree);
 		ret_val = cb(umem, start, last, cookie) || ret_val;
 	}
 
 	return ret_val;
 }
-EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
-
-struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
-				       u64 addr, u64 length)
-{
-	struct umem_odp_node *node;
-
-	node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
-	if (node)
-		return container_of(node, struct ib_umem_odp, interval_tree);
-	return NULL;
-
-}
-EXPORT_SYMBOL(rbt_ib_umem_lookup);
diff --git a/drivers/infiniband/core/uverbs_cmd.c b/drivers/infiniband/core/uverbs_cmd.c
index 7ddd0e5bc6b3..7c10dfe417a4 100644
--- a/drivers/infiniband/core/uverbs_cmd.c
+++ b/drivers/infiniband/core/uverbs_cmd.c
@@ -252,9 +252,6 @@ static int ib_uverbs_get_context(struct uverbs_attr_bundle *attrs)
 	ucontext->closing = false;
 	ucontext->cleanup_retryable = false;
 
-	mutex_init(&ucontext->per_mm_list_lock);
-	INIT_LIST_HEAD(&ucontext->per_mm_list);
-
 	ret = get_unused_fd_flags(O_CLOEXEC);
 	if (ret < 0)
 		goto err_free;
@@ -275,8 +272,6 @@ static int ib_uverbs_get_context(struct uverbs_attr_bundle *attrs)
 	ret = ib_dev->ops.alloc_ucontext(ucontext, &attrs->driver_udata);
 	if (ret)
 		goto err_file;
-	if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
-		ucontext->invalidate_range = NULL;
 
 	rdma_restrack_uadd(&ucontext->res);
 
diff --git a/drivers/infiniband/core/uverbs_main.c b/drivers/infiniband/core/uverbs_main.c
index 11c13c1381cf..e369ac0d6f51 100644
--- a/drivers/infiniband/core/uverbs_main.c
+++ b/drivers/infiniband/core/uverbs_main.c
@@ -1487,6 +1487,7 @@ static void __exit ib_uverbs_cleanup(void)
 				 IB_UVERBS_NUM_FIXED_MINOR);
 	unregister_chrdev_region(dynamic_uverbs_dev,
 				 IB_UVERBS_NUM_DYNAMIC_MINOR);
+	mmu_notifier_synchronize();
 }
 
 module_init(ib_uverbs_init);
diff --git a/drivers/infiniband/hw/mlx5/main.c b/drivers/infiniband/hw/mlx5/main.c
index 4e9f1507ffd9..bface798ee59 100644
--- a/drivers/infiniband/hw/mlx5/main.c
+++ b/drivers/infiniband/hw/mlx5/main.c
@@ -1867,10 +1867,6 @@ static int mlx5_ib_alloc_ucontext(struct ib_ucontext *uctx,
 	if (err)
 		goto out_sys_pages;
 
-	if (ibdev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING)
-		context->ibucontext.invalidate_range =
-			&mlx5_ib_invalidate_range;
-
 	if (req.flags & MLX5_IB_ALLOC_UCTX_DEVX) {
 		err = mlx5_ib_devx_create(dev, true);
 		if (err < 0)
@@ -1999,11 +1995,6 @@ static void mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext)
 	struct mlx5_ib_dev *dev = to_mdev(ibcontext->device);
 	struct mlx5_bfreg_info *bfregi;
 
-	/* All umem's must be destroyed before destroying the ucontext. */
-	mutex_lock(&ibcontext->per_mm_list_lock);
-	WARN_ON(!list_empty(&ibcontext->per_mm_list));
-	mutex_unlock(&ibcontext->per_mm_list_lock);
-
 	bfregi = &context->bfregi;
 	mlx5_ib_dealloc_transport_domain(dev, context->tdn, context->devx_uid);
 
diff --git a/drivers/infiniband/hw/mlx5/mem.c b/drivers/infiniband/hw/mlx5/mem.c
index a40e0abf2338..b5aece786b36 100644
--- a/drivers/infiniband/hw/mlx5/mem.c
+++ b/drivers/infiniband/hw/mlx5/mem.c
@@ -56,19 +56,6 @@ void mlx5_ib_cont_pages(struct ib_umem *umem, u64 addr,
 	struct scatterlist *sg;
 	int entry;
 
-	if (umem->is_odp) {
-		struct ib_umem_odp *odp = to_ib_umem_odp(umem);
-		unsigned int page_shift = odp->page_shift;
-
-		*ncont = ib_umem_odp_num_pages(odp);
-		*count = *ncont << (page_shift - PAGE_SHIFT);
-		*shift = page_shift;
-		if (order)
-			*order = ilog2(roundup_pow_of_two(*ncont));
-
-		return;
-	}
-
 	addr = addr >> PAGE_SHIFT;
 	tmp = (unsigned long)addr;
 	m = find_first_bit(&tmp, BITS_PER_LONG);
diff --git a/drivers/infiniband/hw/mlx5/mr.c b/drivers/infiniband/hw/mlx5/mr.c
index 3401f5f6792e..1eff031ef048 100644
--- a/drivers/infiniband/hw/mlx5/mr.c
+++ b/drivers/infiniband/hw/mlx5/mr.c
@@ -784,19 +784,37 @@ static int mr_umem_get(struct mlx5_ib_dev *dev, struct ib_udata *udata,
 		       int *ncont, int *order)
 {
 	struct ib_umem *u;
-	int err;
 
 	*umem = NULL;
 
-	u = ib_umem_get(udata, start, length, access_flags, 0);
-	err = PTR_ERR_OR_ZERO(u);
-	if (err) {
-		mlx5_ib_dbg(dev, "umem get failed (%d)\n", err);
-		return err;
+	if (access_flags & IB_ACCESS_ON_DEMAND) {
+		struct ib_umem_odp *odp;
+
+		odp = ib_umem_odp_get(udata, start, length, access_flags);
+		if (IS_ERR(odp)) {
+			mlx5_ib_dbg(dev, "umem get failed (%ld)\n",
+				    PTR_ERR(odp));
+			return PTR_ERR(odp);
+		}
+
+		u = &odp->umem;
+
+		*page_shift = odp->page_shift;
+		*ncont = ib_umem_odp_num_pages(odp);
+		*npages = *ncont << (*page_shift - PAGE_SHIFT);
+		if (order)
+			*order = ilog2(roundup_pow_of_two(*ncont));
+	} else {
+		u = ib_umem_get(udata, start, length, access_flags, 0);
+		if (IS_ERR(u)) {
+			mlx5_ib_dbg(dev, "umem get failed (%ld)\n", PTR_ERR(u));
+			return PTR_ERR(u);
+		}
+
+		mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
+				   page_shift, ncont, order);
 	}
 
-	mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
-			   page_shift, ncont, order);
 	if (!*npages) {
 		mlx5_ib_warn(dev, "avoid zero region\n");
 		ib_umem_release(u);
@@ -1599,7 +1617,7 @@ static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 		/* Wait for all running page-fault handlers to finish. */
 		synchronize_srcu(&dev->mr_srcu);
 		/* Destroy all page mappings */
-		if (umem_odp->page_list)
+		if (!umem_odp->is_implicit_odp)
 			mlx5_ib_invalidate_range(umem_odp,
 						 ib_umem_start(umem_odp),
 						 ib_umem_end(umem_odp));
@@ -1610,7 +1628,7 @@ static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 		 * so that there will not be any invalidations in
 		 * flight, looking at the *mr struct.
 		 */
-		ib_umem_release(umem);
+		ib_umem_odp_release(umem_odp);
 		atomic_sub(npages, &dev->mdev->priv.reg_pages);
 
 		/* Avoid double-freeing the umem. */
diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c
index 0a59912a4cef..dd26e7acb37e 100644
--- a/drivers/infiniband/hw/mlx5/odp.c
+++ b/drivers/infiniband/hw/mlx5/odp.c
@@ -184,7 +184,7 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
 	for (i = 0; i < nentries; i++, pklm++) {
 		pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
 		va = (offset + i) * MLX5_IMR_MTT_SIZE;
-		if (odp && odp->umem.address == va) {
+		if (odp && ib_umem_start(odp) == va) {
 			struct mlx5_ib_mr *mtt = odp->private;
 
 			pklm->key = cpu_to_be32(mtt->ibmr.lkey);
@@ -206,7 +206,7 @@ static void mr_leaf_free_action(struct work_struct *work)
 	mr->parent = NULL;
 	synchronize_srcu(&mr->dev->mr_srcu);
 
-	ib_umem_release(&odp->umem);
+	ib_umem_odp_release(odp);
 	if (imr->live)
 		mlx5_ib_update_xlt(imr, idx, 1, 0,
 				   MLX5_IB_UPD_XLT_INDIRECT |
@@ -386,7 +386,7 @@ static void mlx5_ib_page_fault_resume(struct mlx5_ib_dev *dev,
 }
 
 static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd,
-					    struct ib_umem *umem,
+					    struct ib_umem_odp *umem_odp,
 					    bool ksm, int access_flags)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
@@ -404,7 +404,7 @@ static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd,
 	mr->dev = dev;
 	mr->access_flags = access_flags;
 	mr->mmkey.iova = 0;
-	mr->umem = umem;
+	mr->umem = &umem_odp->umem;
 
 	if (ksm) {
 		err = mlx5_ib_update_xlt(mr, 0,
@@ -464,18 +464,17 @@ next_mr:
 		if (nentries)
 			nentries++;
 	} else {
-		odp = ib_alloc_odp_umem(odp_mr, addr,
-					MLX5_IMR_MTT_SIZE);
+		odp = ib_umem_odp_alloc_child(odp_mr, addr, MLX5_IMR_MTT_SIZE);
 		if (IS_ERR(odp)) {
 			mutex_unlock(&odp_mr->umem_mutex);
 			return ERR_CAST(odp);
 		}
 
-		mtt = implicit_mr_alloc(mr->ibmr.pd, &odp->umem, 0,
+		mtt = implicit_mr_alloc(mr->ibmr.pd, odp, 0,
 					mr->access_flags);
 		if (IS_ERR(mtt)) {
 			mutex_unlock(&odp_mr->umem_mutex);
-			ib_umem_release(&odp->umem);
+			ib_umem_odp_release(odp);
 			return ERR_CAST(mtt);
 		}
 
@@ -497,7 +496,7 @@ next_mr:
 	addr += MLX5_IMR_MTT_SIZE;
 	if (unlikely(addr < io_virt + bcnt)) {
 		odp = odp_next(odp);
-		if (odp && odp->umem.address != addr)
+		if (odp && ib_umem_start(odp) != addr)
 			odp = NULL;
 		goto next_mr;
 	}
@@ -521,19 +520,19 @@ struct mlx5_ib_mr *mlx5_ib_alloc_implicit_mr(struct mlx5_ib_pd *pd,
 					     int access_flags)
 {
 	struct mlx5_ib_mr *imr;
-	struct ib_umem *umem;
+	struct ib_umem_odp *umem_odp;
 
-	umem = ib_umem_get(udata, 0, 0, access_flags, 0);
-	if (IS_ERR(umem))
-		return ERR_CAST(umem);
+	umem_odp = ib_umem_odp_alloc_implicit(udata, access_flags);
+	if (IS_ERR(umem_odp))
+		return ERR_CAST(umem_odp);
 
-	imr = implicit_mr_alloc(&pd->ibpd, umem, 1, access_flags);
+	imr = implicit_mr_alloc(&pd->ibpd, umem_odp, 1, access_flags);
 	if (IS_ERR(imr)) {
-		ib_umem_release(umem);
+		ib_umem_odp_release(umem_odp);
 		return ERR_CAST(imr);
 	}
 
-	imr->umem = umem;
+	imr->umem = &umem_odp->umem;
 	init_waitqueue_head(&imr->q_leaf_free);
 	atomic_set(&imr->num_leaf_free, 0);
 	atomic_set(&imr->num_pending_prefetch, 0);
@@ -541,34 +540,31 @@ struct mlx5_ib_mr *mlx5_ib_alloc_implicit_mr(struct mlx5_ib_pd *pd,
 	return imr;
 }
 
-static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end,
-			void *cookie)
+void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr)
 {
-	struct mlx5_ib_mr *mr = umem_odp->private, *imr = cookie;
-
-	if (mr->parent != imr)
-		return 0;
-
-	ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
-				    ib_umem_end(umem_odp));
+	struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(imr);
+	struct rb_node *node;
 
-	if (umem_odp->dying)
-		return 0;
+	down_read(&per_mm->umem_rwsem);
+	for (node = rb_first_cached(&per_mm->umem_tree); node;
+	     node = rb_next(node)) {
+		struct ib_umem_odp *umem_odp =
+			rb_entry(node, struct ib_umem_odp, interval_tree.rb);
+		struct mlx5_ib_mr *mr = umem_odp->private;
 
-	WRITE_ONCE(umem_odp->dying, 1);
-	atomic_inc(&imr->num_leaf_free);
-	schedule_work(&umem_odp->work);
+		if (mr->parent != imr)
+			continue;
 
-	return 0;
-}
+		ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
+					    ib_umem_end(umem_odp));
 
-void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr)
-{
-	struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(imr);
+		if (umem_odp->dying)
+			continue;
 
-	down_read(&per_mm->umem_rwsem);
-	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0, ULLONG_MAX,
-				      mr_leaf_free, true, imr);
+		WRITE_ONCE(umem_odp->dying, 1);
+		atomic_inc(&imr->num_leaf_free);
+		schedule_work(&umem_odp->work);
+	}
 	up_read(&per_mm->umem_rwsem);
 
 	wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free));
@@ -589,7 +585,7 @@ static int pagefault_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
 	struct ib_umem_odp *odp;
 	size_t size;
 
-	if (!odp_mr->page_list) {
+	if (odp_mr->is_implicit_odp) {
 		odp = implicit_mr_get_data(mr, io_virt, bcnt);
 
 		if (IS_ERR(odp))
@@ -607,7 +603,7 @@ next_mr:
 	start_idx = (io_virt - (mr->mmkey.iova & page_mask)) >> page_shift;
 	access_mask = ODP_READ_ALLOWED_BIT;
 
-	if (prefetch && !downgrade && !mr->umem->writable) {
+	if (prefetch && !downgrade && !odp->umem.writable) {
 		/* prefetch with write-access must
 		 * be supported by the MR
 		 */
@@ -615,7 +611,7 @@ next_mr:
 		goto out;
 	}
 
-	if (mr->umem->writable && !downgrade)
+	if (odp->umem.writable && !downgrade)
 		access_mask |= ODP_WRITE_ALLOWED_BIT;
 
 	current_seq = READ_ONCE(odp->notifiers_seq);
@@ -625,8 +621,8 @@ next_mr:
 	 */
 	smp_rmb();
 
-	ret = ib_umem_odp_map_dma_pages(to_ib_umem_odp(mr->umem), io_virt, size,
-					access_mask, current_seq);
+	ret = ib_umem_odp_map_dma_pages(odp, io_virt, size, access_mask,
+					current_seq);
 
 	if (ret < 0)
 		goto out;
@@ -634,8 +630,7 @@ next_mr:
 	np = ret;
 
 	mutex_lock(&odp->umem_mutex);
-	if (!ib_umem_mmu_notifier_retry(to_ib_umem_odp(mr->umem),
-					current_seq)) {
+	if (!ib_umem_mmu_notifier_retry(odp, current_seq)) {
 		/*
 		 * No need to check whether the MTTs really belong to
 		 * this MR, since ib_umem_odp_map_dma_pages already
@@ -668,7 +663,7 @@ next_mr:
 
 		io_virt += size;
 		next = odp_next(odp);
-		if (unlikely(!next || next->umem.address != io_virt)) {
+		if (unlikely(!next || ib_umem_start(next) != io_virt)) {
 			mlx5_ib_dbg(dev, "next implicit leaf removed at 0x%llx. got %p\n",
 				    io_virt, next);
 			return -EAGAIN;
@@ -1618,6 +1613,7 @@ void mlx5_odp_init_mr_cache_entry(struct mlx5_cache_ent *ent)
 
 static const struct ib_device_ops mlx5_ib_dev_odp_ops = {
 	.advise_mr = mlx5_ib_advise_mr,
+	.invalidate_range = mlx5_ib_invalidate_range,
 };
 
 int mlx5_ib_odp_init_one(struct mlx5_ib_dev *dev)
diff --git a/drivers/misc/sgi-gru/grufile.c b/drivers/misc/sgi-gru/grufile.c
index a2a142ae087b..9d042310214f 100644
--- a/drivers/misc/sgi-gru/grufile.c
+++ b/drivers/misc/sgi-gru/grufile.c
@@ -573,6 +573,7 @@ static void __exit gru_exit(void)
 	gru_free_tables();
 	misc_deregister(&gru_miscdev);
 	gru_proc_exit();
+	mmu_notifier_synchronize();
 }
 
 static const struct file_operations gru_fops = {
diff --git a/drivers/misc/sgi-gru/grutables.h b/drivers/misc/sgi-gru/grutables.h
index 438191c22057..a7e44b2eb413 100644
--- a/drivers/misc/sgi-gru/grutables.h
+++ b/drivers/misc/sgi-gru/grutables.h
@@ -307,10 +307,8 @@ struct gru_mm_tracker {				/* pack to reduce size */
 
 struct gru_mm_struct {
 	struct mmu_notifier	ms_notifier;
-	atomic_t		ms_refcnt;
 	spinlock_t		ms_asid_lock;	/* protects ASID assignment */
 	atomic_t		ms_range_active;/* num range_invals active */
-	char			ms_released;
 	wait_queue_head_t	ms_wait_queue;
 	DECLARE_BITMAP(ms_asidmap, GRU_MAX_GRUS);
 	struct gru_mm_tracker	ms_asids[GRU_MAX_GRUS];
diff --git a/drivers/misc/sgi-gru/grutlbpurge.c b/drivers/misc/sgi-gru/grutlbpurge.c
index 59ba0adf23ce..10921cd2608d 100644
--- a/drivers/misc/sgi-gru/grutlbpurge.c
+++ b/drivers/misc/sgi-gru/grutlbpurge.c
@@ -235,83 +235,47 @@ static void gru_invalidate_range_end(struct mmu_notifier *mn,
 		gms, range->start, range->end);
 }
 
-static void gru_release(struct mmu_notifier *mn, struct mm_struct *mm)
+static struct mmu_notifier *gru_alloc_notifier(struct mm_struct *mm)
 {
-	struct gru_mm_struct *gms = container_of(mn, struct gru_mm_struct,
-						 ms_notifier);
+	struct gru_mm_struct *gms;
+
+	gms = kzalloc(sizeof(*gms), GFP_KERNEL);
+	if (!gms)
+		return ERR_PTR(-ENOMEM);
+	STAT(gms_alloc);
+	spin_lock_init(&gms->ms_asid_lock);
+	init_waitqueue_head(&gms->ms_wait_queue);
 
-	gms->ms_released = 1;
-	gru_dbg(grudev, "gms %p\n", gms);
+	return &gms->ms_notifier;
 }
 
+static void gru_free_notifier(struct mmu_notifier *mn)
+{
+	kfree(container_of(mn, struct gru_mm_struct, ms_notifier));
+	STAT(gms_free);
+}
 
 static const struct mmu_notifier_ops gru_mmuops = {
 	.invalidate_range_start	= gru_invalidate_range_start,
 	.invalidate_range_end	= gru_invalidate_range_end,
-	.release		= gru_release,
+	.alloc_notifier		= gru_alloc_notifier,
+	.free_notifier		= gru_free_notifier,
 };
 
-/* Move this to the basic mmu_notifier file. But for now... */
-static struct mmu_notifier *mmu_find_ops(struct mm_struct *mm,
-			const struct mmu_notifier_ops *ops)
-{
-	struct mmu_notifier *mn, *gru_mn = NULL;
-
-	if (mm->mmu_notifier_mm) {
-		rcu_read_lock();
-		hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list,
-					 hlist)
-		    if (mn->ops == ops) {
-			gru_mn = mn;
-			break;
-		}
-		rcu_read_unlock();
-	}
-	return gru_mn;
-}
-
 struct gru_mm_struct *gru_register_mmu_notifier(void)
 {
-	struct gru_mm_struct *gms;
 	struct mmu_notifier *mn;
-	int err;
-
-	mn = mmu_find_ops(current->mm, &gru_mmuops);
-	if (mn) {
-		gms = container_of(mn, struct gru_mm_struct, ms_notifier);
-		atomic_inc(&gms->ms_refcnt);
-	} else {
-		gms = kzalloc(sizeof(*gms), GFP_KERNEL);
-		if (!gms)
-			return ERR_PTR(-ENOMEM);
-		STAT(gms_alloc);
-		spin_lock_init(&gms->ms_asid_lock);
-		gms->ms_notifier.ops = &gru_mmuops;
-		atomic_set(&gms->ms_refcnt, 1);
-		init_waitqueue_head(&gms->ms_wait_queue);
-		err = __mmu_notifier_register(&gms->ms_notifier, current->mm);
-		if (err)
-			goto error;
-	}
-	if (gms)
-		gru_dbg(grudev, "gms %p, refcnt %d\n", gms,
-			atomic_read(&gms->ms_refcnt));
-	return gms;
-error:
-	kfree(gms);
-	return ERR_PTR(err);
+
+	mn = mmu_notifier_get_locked(&gru_mmuops, current->mm);
+	if (IS_ERR(mn))
+		return ERR_CAST(mn);
+
+	return container_of(mn, struct gru_mm_struct, ms_notifier);
 }
 
 void gru_drop_mmu_notifier(struct gru_mm_struct *gms)
 {
-	gru_dbg(grudev, "gms %p, refcnt %d, released %d\n", gms,
-		atomic_read(&gms->ms_refcnt), gms->ms_released);
-	if (atomic_dec_return(&gms->ms_refcnt) == 0) {
-		if (!gms->ms_released)
-			mmu_notifier_unregister(&gms->ms_notifier, current->mm);
-		kfree(gms);
-		STAT(gms_free);
-	}
+	mmu_notifier_put(&gms->ms_notifier);
 }
 
 /*
diff --git a/drivers/nvdimm/Kconfig b/drivers/nvdimm/Kconfig
index a5fde15e91d3..36af7af6b7cf 100644
--- a/drivers/nvdimm/Kconfig
+++ b/drivers/nvdimm/Kconfig
@@ -118,4 +118,16 @@ config NVDIMM_KEYS
 	depends on ENCRYPTED_KEYS
 	depends on (LIBNVDIMM=ENCRYPTED_KEYS) || LIBNVDIMM=m
 
+config NVDIMM_TEST_BUILD
+	tristate "Build the unit test core"
+	depends on m
+	depends on COMPILE_TEST && X86_64
+	default m if COMPILE_TEST
+	help
+	  Build the core of the unit test infrastructure. The result of
+	  this build is non-functional for unit test execution, but it
+	  otherwise helps catch build errors induced by changes to the
+	  core devm_memremap_pages() implementation and other
+	  infrastructure.
+
 endif
diff --git a/drivers/nvdimm/Makefile b/drivers/nvdimm/Makefile
index cefe233e0b52..29203f3d3069 100644
--- a/drivers/nvdimm/Makefile
+++ b/drivers/nvdimm/Makefile
@@ -29,3 +29,7 @@ libnvdimm-$(CONFIG_BTT) += btt_devs.o
 libnvdimm-$(CONFIG_NVDIMM_PFN) += pfn_devs.o
 libnvdimm-$(CONFIG_NVDIMM_DAX) += dax_devs.o
 libnvdimm-$(CONFIG_NVDIMM_KEYS) += security.o
+
+TOOLS := ../../tools
+TEST_SRC := $(TOOLS)/testing/nvdimm/test
+obj-$(CONFIG_NVDIMM_TEST_BUILD) += $(TEST_SRC)/iomap.o
diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
index 731642e0f5a0..bf43d1d60059 100644
--- a/fs/proc/task_mmu.c
+++ b/fs/proc/task_mmu.c
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: GPL-2.0
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/vmacache.h>
 #include <linux/hugetlb.h>
 #include <linux/huge_mm.h>
@@ -513,7 +513,9 @@ static int smaps_pte_hole(unsigned long addr, unsigned long end,
 
 	return 0;
 }
-#endif
+#else
+#define smaps_pte_hole		NULL
+#endif /* CONFIG_SHMEM */
 
 static void smaps_pte_entry(pte_t *pte, unsigned long addr,
 		struct mm_walk *walk)
@@ -729,21 +731,24 @@ static int smaps_hugetlb_range(pte_t *pte, unsigned long hmask,
 	}
 	return 0;
 }
+#else
+#define smaps_hugetlb_range	NULL
 #endif /* HUGETLB_PAGE */
 
+static const struct mm_walk_ops smaps_walk_ops = {
+	.pmd_entry		= smaps_pte_range,
+	.hugetlb_entry		= smaps_hugetlb_range,
+};
+
+static const struct mm_walk_ops smaps_shmem_walk_ops = {
+	.pmd_entry		= smaps_pte_range,
+	.hugetlb_entry		= smaps_hugetlb_range,
+	.pte_hole		= smaps_pte_hole,
+};
+
 static void smap_gather_stats(struct vm_area_struct *vma,
 			     struct mem_size_stats *mss)
 {
-	struct mm_walk smaps_walk = {
-		.pmd_entry = smaps_pte_range,
-#ifdef CONFIG_HUGETLB_PAGE
-		.hugetlb_entry = smaps_hugetlb_range,
-#endif
-		.mm = vma->vm_mm,
-	};
-
-	smaps_walk.private = mss;
-
 #ifdef CONFIG_SHMEM
 	/* In case of smaps_rollup, reset the value from previous vma */
 	mss->check_shmem_swap = false;
@@ -765,12 +770,13 @@ static void smap_gather_stats(struct vm_area_struct *vma,
 			mss->swap += shmem_swapped;
 		} else {
 			mss->check_shmem_swap = true;
-			smaps_walk.pte_hole = smaps_pte_hole;
+			walk_page_vma(vma, &smaps_shmem_walk_ops, mss);
+			return;
 		}
 	}
 #endif
 	/* mmap_sem is held in m_start */
-	walk_page_vma(vma, &smaps_walk);
+	walk_page_vma(vma, &smaps_walk_ops, mss);
 }
 
 #define SEQ_PUT_DEC(str, val) \
@@ -1118,6 +1124,11 @@ static int clear_refs_test_walk(unsigned long start, unsigned long end,
 	return 0;
 }
 
+static const struct mm_walk_ops clear_refs_walk_ops = {
+	.pmd_entry		= clear_refs_pte_range,
+	.test_walk		= clear_refs_test_walk,
+};
+
 static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 				size_t count, loff_t *ppos)
 {
@@ -1151,12 +1162,6 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 		struct clear_refs_private cp = {
 			.type = type,
 		};
-		struct mm_walk clear_refs_walk = {
-			.pmd_entry = clear_refs_pte_range,
-			.test_walk = clear_refs_test_walk,
-			.mm = mm,
-			.private = &cp,
-		};
 
 		if (type == CLEAR_REFS_MM_HIWATER_RSS) {
 			if (down_write_killable(&mm->mmap_sem)) {
@@ -1217,7 +1222,8 @@ static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 						0, NULL, mm, 0, -1UL);
 			mmu_notifier_invalidate_range_start(&range);
 		}
-		walk_page_range(0, mm->highest_vm_end, &clear_refs_walk);
+		walk_page_range(mm, 0, mm->highest_vm_end, &clear_refs_walk_ops,
+				&cp);
 		if (type == CLEAR_REFS_SOFT_DIRTY)
 			mmu_notifier_invalidate_range_end(&range);
 		tlb_finish_mmu(&tlb, 0, -1);
@@ -1489,8 +1495,16 @@ static int pagemap_hugetlb_range(pte_t *ptep, unsigned long hmask,
 
 	return err;
 }
+#else
+#define pagemap_hugetlb_range	NULL
 #endif /* HUGETLB_PAGE */
 
+static const struct mm_walk_ops pagemap_ops = {
+	.pmd_entry	= pagemap_pmd_range,
+	.pte_hole	= pagemap_pte_hole,
+	.hugetlb_entry	= pagemap_hugetlb_range,
+};
+
 /*
  * /proc/pid/pagemap - an array mapping virtual pages to pfns
  *
@@ -1522,7 +1536,6 @@ static ssize_t pagemap_read(struct file *file, char __user *buf,
 {
 	struct mm_struct *mm = file->private_data;
 	struct pagemapread pm;
-	struct mm_walk pagemap_walk = {};
 	unsigned long src;
 	unsigned long svpfn;
 	unsigned long start_vaddr;
@@ -1550,14 +1563,6 @@ static ssize_t pagemap_read(struct file *file, char __user *buf,
 	if (!pm.buffer)
 		goto out_mm;
 
-	pagemap_walk.pmd_entry = pagemap_pmd_range;
-	pagemap_walk.pte_hole = pagemap_pte_hole;
-#ifdef CONFIG_HUGETLB_PAGE
-	pagemap_walk.hugetlb_entry = pagemap_hugetlb_range;
-#endif
-	pagemap_walk.mm = mm;
-	pagemap_walk.private = &pm;
-
 	src = *ppos;
 	svpfn = src / PM_ENTRY_BYTES;
 	start_vaddr = svpfn << PAGE_SHIFT;
@@ -1586,7 +1591,7 @@ static ssize_t pagemap_read(struct file *file, char __user *buf,
 		ret = down_read_killable(&mm->mmap_sem);
 		if (ret)
 			goto out_free;
-		ret = walk_page_range(start_vaddr, end, &pagemap_walk);
+		ret = walk_page_range(mm, start_vaddr, end, &pagemap_ops, &pm);
 		up_read(&mm->mmap_sem);
 		start_vaddr = end;
 
@@ -1798,6 +1803,11 @@ static int gather_hugetlb_stats(pte_t *pte, unsigned long hmask,
 }
 #endif
 
+static const struct mm_walk_ops show_numa_ops = {
+	.hugetlb_entry = gather_hugetlb_stats,
+	.pmd_entry = gather_pte_stats,
+};
+
 /*
  * Display pages allocated per node and memory policy via /proc.
  */
@@ -1809,12 +1819,6 @@ static int show_numa_map(struct seq_file *m, void *v)
 	struct numa_maps *md = &numa_priv->md;
 	struct file *file = vma->vm_file;
 	struct mm_struct *mm = vma->vm_mm;
-	struct mm_walk walk = {
-		.hugetlb_entry = gather_hugetlb_stats,
-		.pmd_entry = gather_pte_stats,
-		.private = md,
-		.mm = mm,
-	};
 	struct mempolicy *pol;
 	char buffer[64];
 	int nid;
@@ -1848,7 +1852,7 @@ static int show_numa_map(struct seq_file *m, void *v)
 		seq_puts(m, " huge");
 
 	/* mmap_sem is held by m_start */
-	walk_page_vma(vma, &walk);
+	walk_page_vma(vma, &show_numa_ops, md);
 
 	if (!md->pages)
 		goto out;
diff --git a/include/linux/hmm.h b/include/linux/hmm.h
index 7ef56dc18050..3fec513b9c00 100644
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -84,15 +84,12 @@
  * @notifiers: count of active mmu notifiers
  */
 struct hmm {
-	struct mm_struct	*mm;
-	struct kref		kref;
+	struct mmu_notifier	mmu_notifier;
 	spinlock_t		ranges_lock;
 	struct list_head	ranges;
 	struct list_head	mirrors;
-	struct mmu_notifier	mmu_notifier;
 	struct rw_semaphore	mirrors_sem;
 	wait_queue_head_t	wq;
-	struct rcu_head		rcu;
 	long			notifiers;
 };
 
@@ -158,13 +155,11 @@ enum hmm_pfn_value_e {
  * @values: pfn value for some special case (none, special, error, ...)
  * @default_flags: default flags for the range (write, read, ... see hmm doc)
  * @pfn_flags_mask: allows to mask pfn flags so that only default_flags matter
- * @page_shift: device virtual address shift value (should be >= PAGE_SHIFT)
  * @pfn_shifts: pfn shift value (should be <= PAGE_SHIFT)
  * @valid: pfns array did not change since it has been fill by an HMM function
  */
 struct hmm_range {
 	struct hmm		*hmm;
-	struct vm_area_struct	*vma;
 	struct list_head	list;
 	unsigned long		start;
 	unsigned long		end;
@@ -173,32 +168,11 @@ struct hmm_range {
 	const uint64_t		*values;
 	uint64_t		default_flags;
 	uint64_t		pfn_flags_mask;
-	uint8_t			page_shift;
 	uint8_t			pfn_shift;
 	bool			valid;
 };
 
 /*
- * hmm_range_page_shift() - return the page shift for the range
- * @range: range being queried
- * Return: page shift (page size = 1 << page shift) for the range
- */
-static inline unsigned hmm_range_page_shift(const struct hmm_range *range)
-{
-	return range->page_shift;
-}
-
-/*
- * hmm_range_page_size() - return the page size for the range
- * @range: range being queried
- * Return: page size for the range in bytes
- */
-static inline unsigned long hmm_range_page_size(const struct hmm_range *range)
-{
-	return 1UL << hmm_range_page_shift(range);
-}
-
-/*
  * hmm_range_wait_until_valid() - wait for range to be valid
  * @range: range affected by invalidation to wait on
  * @timeout: time out for wait in ms (ie abort wait after that period of time)
@@ -291,40 +265,6 @@ static inline uint64_t hmm_device_entry_from_pfn(const struct hmm_range *range,
 }
 
 /*
- * Old API:
- * hmm_pfn_to_page()
- * hmm_pfn_to_pfn()
- * hmm_pfn_from_page()
- * hmm_pfn_from_pfn()
- *
- * This are the OLD API please use new API, it is here to avoid cross-tree
- * merge painfullness ie we convert things to new API in stages.
- */
-static inline struct page *hmm_pfn_to_page(const struct hmm_range *range,
-					   uint64_t pfn)
-{
-	return hmm_device_entry_to_page(range, pfn);
-}
-
-static inline unsigned long hmm_pfn_to_pfn(const struct hmm_range *range,
-					   uint64_t pfn)
-{
-	return hmm_device_entry_to_pfn(range, pfn);
-}
-
-static inline uint64_t hmm_pfn_from_page(const struct hmm_range *range,
-					 struct page *page)
-{
-	return hmm_device_entry_from_page(range, page);
-}
-
-static inline uint64_t hmm_pfn_from_pfn(const struct hmm_range *range,
-					unsigned long pfn)
-{
-	return hmm_device_entry_from_pfn(range, pfn);
-}
-
-/*
  * Mirroring: how to synchronize device page table with CPU page table.
  *
  * A device driver that is participating in HMM mirroring must always
@@ -375,29 +315,6 @@ static inline uint64_t hmm_pfn_from_pfn(const struct hmm_range *range,
 struct hmm_mirror;
 
 /*
- * enum hmm_update_event - type of update
- * @HMM_UPDATE_INVALIDATE: invalidate range (no indication as to why)
- */
-enum hmm_update_event {
-	HMM_UPDATE_INVALIDATE,
-};
-
-/*
- * struct hmm_update - HMM update information for callback
- *
- * @start: virtual start address of the range to update
- * @end: virtual end address of the range to update
- * @event: event triggering the update (what is happening)
- * @blockable: can the callback block/sleep ?
- */
-struct hmm_update {
-	unsigned long start;
-	unsigned long end;
-	enum hmm_update_event event;
-	bool blockable;
-};
-
-/*
  * struct hmm_mirror_ops - HMM mirror device operations callback
  *
  * @update: callback to update range on a device
@@ -417,9 +334,9 @@ struct hmm_mirror_ops {
 	/* sync_cpu_device_pagetables() - synchronize page tables
 	 *
 	 * @mirror: pointer to struct hmm_mirror
-	 * @update: update information (see struct hmm_update)
-	 * Return: -EAGAIN if update.blockable false and callback need to
-	 *          block, 0 otherwise.
+	 * @update: update information (see struct mmu_notifier_range)
+	 * Return: -EAGAIN if mmu_notifier_range_blockable(update) is false
+	 * and callback needs to block, 0 otherwise.
 	 *
 	 * This callback ultimately originates from mmu_notifiers when the CPU
 	 * page table is updated. The device driver must update its page table
@@ -430,8 +347,9 @@ struct hmm_mirror_ops {
 	 * page tables are completely updated (TLBs flushed, etc); this is a
 	 * synchronous call.
 	 */
-	int (*sync_cpu_device_pagetables)(struct hmm_mirror *mirror,
-					  const struct hmm_update *update);
+	int (*sync_cpu_device_pagetables)(
+		struct hmm_mirror *mirror,
+		const struct mmu_notifier_range *update);
 };
 
 /*
@@ -457,20 +375,24 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror);
 /*
  * Please see Documentation/vm/hmm.rst for how to use the range API.
  */
-int hmm_range_register(struct hmm_range *range,
-		       struct hmm_mirror *mirror,
-		       unsigned long start,
-		       unsigned long end,
-		       unsigned page_shift);
+int hmm_range_register(struct hmm_range *range, struct hmm_mirror *mirror);
 void hmm_range_unregister(struct hmm_range *range);
-long hmm_range_snapshot(struct hmm_range *range);
-long hmm_range_fault(struct hmm_range *range, bool block);
+
+/*
+ * Retry fault if non-blocking, drop mmap_sem and return -EAGAIN in that case.
+ */
+#define HMM_FAULT_ALLOW_RETRY		(1 << 0)
+
+/* Don't fault in missing PTEs, just snapshot the current state. */
+#define HMM_FAULT_SNAPSHOT		(1 << 1)
+
+long hmm_range_fault(struct hmm_range *range, unsigned int flags);
+
 long hmm_range_dma_map(struct hmm_range *range,
 		       struct device *device,
 		       dma_addr_t *daddrs,
-		       bool block);
+		       unsigned int flags);
 long hmm_range_dma_unmap(struct hmm_range *range,
-			 struct vm_area_struct *vma,
 			 struct device *device,
 			 dma_addr_t *daddrs,
 			 bool dirty);
@@ -484,13 +406,6 @@ long hmm_range_dma_unmap(struct hmm_range *range,
  */
 #define HMM_RANGE_DEFAULT_TIMEOUT 1000
 
-/* Below are for HMM internal use only! Not to be used by device driver! */
-static inline void hmm_mm_init(struct mm_struct *mm)
-{
-	mm->hmm = NULL;
-}
-#else /* IS_ENABLED(CONFIG_HMM_MIRROR) */
-static inline void hmm_mm_init(struct mm_struct *mm) {}
 #endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */
 
 #endif /* LINUX_HMM_H */
diff --git a/include/linux/ioport.h b/include/linux/ioport.h
index 5b6a7121c9f0..7bddddfc76d6 100644
--- a/include/linux/ioport.h
+++ b/include/linux/ioport.h
@@ -297,6 +297,8 @@ static inline bool resource_overlaps(struct resource *r1, struct resource *r2)
 
 struct resource *devm_request_free_mem_region(struct device *dev,
 		struct resource *base, unsigned long size);
+struct resource *request_free_mem_region(struct resource *base,
+		unsigned long size, const char *name);
 
 #endif /* __ASSEMBLY__ */
 #endif	/* _LINUX_IOPORT_H */
diff --git a/include/linux/kernel.h b/include/linux/kernel.h
index 4fa360a13c1e..d83d403dac2e 100644
--- a/include/linux/kernel.h
+++ b/include/linux/kernel.h
@@ -217,7 +217,9 @@ extern void __cant_sleep(const char *file, int line, int preempt_offset);
  * might_sleep - annotation for functions that can sleep
  *
  * this macro will print a stack trace if it is executed in an atomic
- * context (spinlock, irq-handler, ...).
+ * context (spinlock, irq-handler, ...). Additional sections where blocking is
+ * not allowed can be annotated with non_block_start() and non_block_end()
+ * pairs.
  *
  * This is a useful debugging help to be able to catch problems early and not
  * be bitten later when the calling function happens to sleep when it is not
@@ -233,6 +235,23 @@ extern void __cant_sleep(const char *file, int line, int preempt_offset);
 # define cant_sleep() \
 	do { __cant_sleep(__FILE__, __LINE__, 0); } while (0)
 # define sched_annotate_sleep()	(current->task_state_change = 0)
+/**
+ * non_block_start - annotate the start of section where sleeping is prohibited
+ *
+ * This is on behalf of the oom reaper, specifically when it is calling the mmu
+ * notifiers. The problem is that if the notifier were to block on, for example,
+ * mutex_lock() and if the process which holds that mutex were to perform a
+ * sleeping memory allocation, the oom reaper is now blocked on completion of
+ * that memory allocation. Other blocking calls like wait_event() pose similar
+ * issues.
+ */
+# define non_block_start() (current->non_block_count++)
+/**
+ * non_block_end - annotate the end of section where sleeping is prohibited
+ *
+ * Closes a section opened by non_block_start().
+ */
+# define non_block_end() WARN_ON(current->non_block_count-- == 0)
 #else
   static inline void ___might_sleep(const char *file, int line,
 				   int preempt_offset) { }
@@ -241,6 +260,8 @@ extern void __cant_sleep(const char *file, int line, int preempt_offset);
 # define might_sleep() do { might_resched(); } while (0)
 # define cant_sleep() do { } while (0)
 # define sched_annotate_sleep() do { } while (0)
+# define non_block_start() do { } while (0)
+# define non_block_end() do { } while (0)
 #endif
 
 #define might_sleep_if(cond) do { if (cond) might_sleep(); } while (0)
diff --git a/include/linux/memremap.h b/include/linux/memremap.h
index f8a5b2a19945..fb2a0bd826b9 100644
--- a/include/linux/memremap.h
+++ b/include/linux/memremap.h
@@ -109,7 +109,6 @@ struct dev_pagemap {
 	struct percpu_ref *ref;
 	struct percpu_ref internal_ref;
 	struct completion done;
-	struct device *dev;
 	enum memory_type type;
 	unsigned int flags;
 	u64 pci_p2pdma_bus_offset;
@@ -124,6 +123,8 @@ static inline struct vmem_altmap *pgmap_altmap(struct dev_pagemap *pgmap)
 }
 
 #ifdef CONFIG_ZONE_DEVICE
+void *memremap_pages(struct dev_pagemap *pgmap, int nid);
+void memunmap_pages(struct dev_pagemap *pgmap);
 void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap);
 void devm_memunmap_pages(struct device *dev, struct dev_pagemap *pgmap);
 struct dev_pagemap *get_dev_pagemap(unsigned long pfn,
diff --git a/include/linux/migrate.h b/include/linux/migrate.h
index 7f04754c7f2b..72120061b7d4 100644
--- a/include/linux/migrate.h
+++ b/include/linux/migrate.h
@@ -166,8 +166,6 @@ static inline int migrate_misplaced_transhuge_page(struct mm_struct *mm,
 #define MIGRATE_PFN_MIGRATE	(1UL << 1)
 #define MIGRATE_PFN_LOCKED	(1UL << 2)
 #define MIGRATE_PFN_WRITE	(1UL << 3)
-#define MIGRATE_PFN_DEVICE	(1UL << 4)
-#define MIGRATE_PFN_ERROR	(1UL << 5)
 #define MIGRATE_PFN_SHIFT	6
 
 static inline struct page *migrate_pfn_to_page(unsigned long mpfn)
@@ -182,107 +180,27 @@ static inline unsigned long migrate_pfn(unsigned long pfn)
 	return (pfn << MIGRATE_PFN_SHIFT) | MIGRATE_PFN_VALID;
 }
 
-/*
- * struct migrate_vma_ops - migrate operation callback
- *
- * @alloc_and_copy: alloc destination memory and copy source memory to it
- * @finalize_and_map: allow caller to map the successfully migrated pages
- *
- *
- * The alloc_and_copy() callback happens once all source pages have been locked,
- * unmapped and checked (checked whether pinned or not). All pages that can be
- * migrated will have an entry in the src array set with the pfn value of the
- * page and with the MIGRATE_PFN_VALID and MIGRATE_PFN_MIGRATE flag set (other
- * flags might be set but should be ignored by the callback).
- *
- * The alloc_and_copy() callback can then allocate destination memory and copy
- * source memory to it for all those entries (ie with MIGRATE_PFN_VALID and
- * MIGRATE_PFN_MIGRATE flag set). Once these are allocated and copied, the
- * callback must update each corresponding entry in the dst array with the pfn
- * value of the destination page and with the MIGRATE_PFN_VALID and
- * MIGRATE_PFN_LOCKED flags set (destination pages must have their struct pages
- * locked, via lock_page()).
- *
- * At this point the alloc_and_copy() callback is done and returns.
- *
- * Note that the callback does not have to migrate all the pages that are
- * marked with MIGRATE_PFN_MIGRATE flag in src array unless this is a migration
- * from device memory to system memory (ie the MIGRATE_PFN_DEVICE flag is also
- * set in the src array entry). If the device driver cannot migrate a device
- * page back to system memory, then it must set the corresponding dst array
- * entry to MIGRATE_PFN_ERROR. This will trigger a SIGBUS if CPU tries to
- * access any of the virtual addresses originally backed by this page. Because
- * a SIGBUS is such a severe result for the userspace process, the device
- * driver should avoid setting MIGRATE_PFN_ERROR unless it is really in an
- * unrecoverable state.
- *
- * For empty entry inside CPU page table (pte_none() or pmd_none() is true) we
- * do set MIGRATE_PFN_MIGRATE flag inside the corresponding source array thus
- * allowing device driver to allocate device memory for those unback virtual
- * address. For this the device driver simply have to allocate device memory
- * and properly set the destination entry like for regular migration. Note that
- * this can still fails and thus inside the device driver must check if the
- * migration was successful for those entry inside the finalize_and_map()
- * callback just like for regular migration.
- *
- * THE alloc_and_copy() CALLBACK MUST NOT CHANGE ANY OF THE SRC ARRAY ENTRIES
- * OR BAD THINGS WILL HAPPEN !
- *
- *
- * The finalize_and_map() callback happens after struct page migration from
- * source to destination (destination struct pages are the struct pages for the
- * memory allocated by the alloc_and_copy() callback).  Migration can fail, and
- * thus the finalize_and_map() allows the driver to inspect which pages were
- * successfully migrated, and which were not. Successfully migrated pages will
- * have the MIGRATE_PFN_MIGRATE flag set for their src array entry.
- *
- * It is safe to update device page table from within the finalize_and_map()
- * callback because both destination and source page are still locked, and the
- * mmap_sem is held in read mode (hence no one can unmap the range being
- * migrated).
- *
- * Once callback is done cleaning up things and updating its page table (if it
- * chose to do so, this is not an obligation) then it returns. At this point,
- * the HMM core will finish up the final steps, and the migration is complete.
- *
- * THE finalize_and_map() CALLBACK MUST NOT CHANGE ANY OF THE SRC OR DST ARRAY
- * ENTRIES OR BAD THINGS WILL HAPPEN !
- */
-struct migrate_vma_ops {
-	void (*alloc_and_copy)(struct vm_area_struct *vma,
-			       const unsigned long *src,
-			       unsigned long *dst,
-			       unsigned long start,
-			       unsigned long end,
-			       void *private);
-	void (*finalize_and_map)(struct vm_area_struct *vma,
-				 const unsigned long *src,
-				 const unsigned long *dst,
-				 unsigned long start,
-				 unsigned long end,
-				 void *private);
+struct migrate_vma {
+	struct vm_area_struct	*vma;
+	/*
+	 * Both src and dst array must be big enough for
+	 * (end - start) >> PAGE_SHIFT entries.
+	 *
+	 * The src array must not be modified by the caller after
+	 * migrate_vma_setup(), and must not change the dst array after
+	 * migrate_vma_pages() returns.
+	 */
+	unsigned long		*dst;
+	unsigned long		*src;
+	unsigned long		cpages;
+	unsigned long		npages;
+	unsigned long		start;
+	unsigned long		end;
 };
 
-#if defined(CONFIG_MIGRATE_VMA_HELPER)
-int migrate_vma(const struct migrate_vma_ops *ops,
-		struct vm_area_struct *vma,
-		unsigned long start,
-		unsigned long end,
-		unsigned long *src,
-		unsigned long *dst,
-		void *private);
-#else
-static inline int migrate_vma(const struct migrate_vma_ops *ops,
-			      struct vm_area_struct *vma,
-			      unsigned long start,
-			      unsigned long end,
-			      unsigned long *src,
-			      unsigned long *dst,
-			      void *private)
-{
-	return -EINVAL;
-}
-#endif /* IS_ENABLED(CONFIG_MIGRATE_VMA_HELPER) */
+int migrate_vma_setup(struct migrate_vma *args);
+void migrate_vma_pages(struct migrate_vma *migrate);
+void migrate_vma_finalize(struct migrate_vma *migrate);
 
 #endif /* CONFIG_MIGRATION */
 
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 0334ca97c584..7cf955feb823 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1430,54 +1430,8 @@ void zap_page_range(struct vm_area_struct *vma, unsigned long address,
 void unmap_vmas(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
 		unsigned long start, unsigned long end);
 
-/**
- * mm_walk - callbacks for walk_page_range
- * @pud_entry: if set, called for each non-empty PUD (2nd-level) entry
- *	       this handler should only handle pud_trans_huge() puds.
- *	       the pmd_entry or pte_entry callbacks will be used for
- *	       regular PUDs.
- * @pmd_entry: if set, called for each non-empty PMD (3rd-level) entry
- *	       this handler is required to be able to handle
- *	       pmd_trans_huge() pmds.  They may simply choose to
- *	       split_huge_page() instead of handling it explicitly.
- * @pte_entry: if set, called for each non-empty PTE (4th-level) entry
- * @pte_hole: if set, called for each hole at all levels
- * @hugetlb_entry: if set, called for each hugetlb entry
- * @test_walk: caller specific callback function to determine whether
- *             we walk over the current vma or not. Returning 0
- *             value means "do page table walk over the current vma,"
- *             and a negative one means "abort current page table walk
- *             right now." 1 means "skip the current vma."
- * @mm:        mm_struct representing the target process of page table walk
- * @vma:       vma currently walked (NULL if walking outside vmas)
- * @private:   private data for callbacks' usage
- *
- * (see the comment on walk_page_range() for more details)
- */
-struct mm_walk {
-	int (*pud_entry)(pud_t *pud, unsigned long addr,
-			 unsigned long next, struct mm_walk *walk);
-	int (*pmd_entry)(pmd_t *pmd, unsigned long addr,
-			 unsigned long next, struct mm_walk *walk);
-	int (*pte_entry)(pte_t *pte, unsigned long addr,
-			 unsigned long next, struct mm_walk *walk);
-	int (*pte_hole)(unsigned long addr, unsigned long next,
-			struct mm_walk *walk);
-	int (*hugetlb_entry)(pte_t *pte, unsigned long hmask,
-			     unsigned long addr, unsigned long next,
-			     struct mm_walk *walk);
-	int (*test_walk)(unsigned long addr, unsigned long next,
-			struct mm_walk *walk);
-	struct mm_struct *mm;
-	struct vm_area_struct *vma;
-	void *private;
-};
-
 struct mmu_notifier_range;
 
-int walk_page_range(unsigned long addr, unsigned long end,
-		struct mm_walk *walk);
-int walk_page_vma(struct vm_area_struct *vma, struct mm_walk *walk);
 void free_pgd_range(struct mmu_gather *tlb, unsigned long addr,
 		unsigned long end, unsigned long floor, unsigned long ceiling);
 int copy_page_range(struct mm_struct *dst, struct mm_struct *src,
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 6a7a1083b6fb..0b739f360cec 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -25,7 +25,6 @@
 
 struct address_space;
 struct mem_cgroup;
-struct hmm;
 
 /*
  * Each physical page in the system has a struct page associated with
@@ -511,11 +510,6 @@ struct mm_struct {
 		atomic_long_t hugetlb_usage;
 #endif
 		struct work_struct async_put_work;
-
-#ifdef CONFIG_HMM_MIRROR
-		/* HMM needs to track a few things per mm */
-		struct hmm *hmm;
-#endif
 	} __randomize_layout;
 
 	/*
diff --git a/include/linux/mmu_notifier.h b/include/linux/mmu_notifier.h
index b6c004bd9f6a..1bd8e6a09a3c 100644
--- a/include/linux/mmu_notifier.h
+++ b/include/linux/mmu_notifier.h
@@ -42,6 +42,10 @@ enum mmu_notifier_event {
 
 #ifdef CONFIG_MMU_NOTIFIER
 
+#ifdef CONFIG_LOCKDEP
+extern struct lockdep_map __mmu_notifier_invalidate_range_start_map;
+#endif
+
 /*
  * The mmu notifier_mm structure is allocated and installed in
  * mm->mmu_notifier_mm inside the mm_take_all_locks() protected
@@ -211,6 +215,19 @@ struct mmu_notifier_ops {
 	 */
 	void (*invalidate_range)(struct mmu_notifier *mn, struct mm_struct *mm,
 				 unsigned long start, unsigned long end);
+
+	/*
+	 * These callbacks are used with the get/put interface to manage the
+	 * lifetime of the mmu_notifier memory. alloc_notifier() returns a new
+	 * notifier for use with the mm.
+	 *
+	 * free_notifier() is only called after the mmu_notifier has been
+	 * fully put, calls to any ops callback are prevented and no ops
+	 * callbacks are currently running. It is called from a SRCU callback
+	 * and cannot sleep.
+	 */
+	struct mmu_notifier *(*alloc_notifier)(struct mm_struct *mm);
+	void (*free_notifier)(struct mmu_notifier *mn);
 };
 
 /*
@@ -227,6 +244,9 @@ struct mmu_notifier_ops {
 struct mmu_notifier {
 	struct hlist_node hlist;
 	const struct mmu_notifier_ops *ops;
+	struct mm_struct *mm;
+	struct rcu_head rcu;
+	unsigned int users;
 };
 
 static inline int mm_has_notifiers(struct mm_struct *mm)
@@ -234,14 +254,27 @@ static inline int mm_has_notifiers(struct mm_struct *mm)
 	return unlikely(mm->mmu_notifier_mm);
 }
 
+struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
+					     struct mm_struct *mm);
+static inline struct mmu_notifier *
+mmu_notifier_get(const struct mmu_notifier_ops *ops, struct mm_struct *mm)
+{
+	struct mmu_notifier *ret;
+
+	down_write(&mm->mmap_sem);
+	ret = mmu_notifier_get_locked(ops, mm);
+	up_write(&mm->mmap_sem);
+	return ret;
+}
+void mmu_notifier_put(struct mmu_notifier *mn);
+void mmu_notifier_synchronize(void);
+
 extern int mmu_notifier_register(struct mmu_notifier *mn,
 				 struct mm_struct *mm);
 extern int __mmu_notifier_register(struct mmu_notifier *mn,
 				   struct mm_struct *mm);
 extern void mmu_notifier_unregister(struct mmu_notifier *mn,
 				    struct mm_struct *mm);
-extern void mmu_notifier_unregister_no_release(struct mmu_notifier *mn,
-					       struct mm_struct *mm);
 extern void __mmu_notifier_mm_destroy(struct mm_struct *mm);
 extern void __mmu_notifier_release(struct mm_struct *mm);
 extern int __mmu_notifier_clear_flush_young(struct mm_struct *mm,
@@ -310,25 +343,36 @@ static inline void mmu_notifier_change_pte(struct mm_struct *mm,
 static inline void
 mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range)
 {
+	might_sleep();
+
+	lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
 	if (mm_has_notifiers(range->mm)) {
 		range->flags |= MMU_NOTIFIER_RANGE_BLOCKABLE;
 		__mmu_notifier_invalidate_range_start(range);
 	}
+	lock_map_release(&__mmu_notifier_invalidate_range_start_map);
 }
 
 static inline int
 mmu_notifier_invalidate_range_start_nonblock(struct mmu_notifier_range *range)
 {
+	int ret = 0;
+
+	lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
 	if (mm_has_notifiers(range->mm)) {
 		range->flags &= ~MMU_NOTIFIER_RANGE_BLOCKABLE;
-		return __mmu_notifier_invalidate_range_start(range);
+		ret = __mmu_notifier_invalidate_range_start(range);
 	}
-	return 0;
+	lock_map_release(&__mmu_notifier_invalidate_range_start_map);
+	return ret;
 }
 
 static inline void
 mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range)
 {
+	if (mmu_notifier_range_blockable(range))
+		might_sleep();
+
 	if (mm_has_notifiers(range->mm))
 		__mmu_notifier_invalidate_range_end(range, false);
 }
@@ -482,9 +526,6 @@ static inline void mmu_notifier_range_init(struct mmu_notifier_range *range,
 	set_pte_at(___mm, ___address, __ptep, ___pte);			\
 })
 
-extern void mmu_notifier_call_srcu(struct rcu_head *rcu,
-				   void (*func)(struct rcu_head *rcu));
-
 #else /* CONFIG_MMU_NOTIFIER */
 
 struct mmu_notifier_range {
@@ -581,6 +622,10 @@ static inline void mmu_notifier_mm_destroy(struct mm_struct *mm)
 #define pudp_huge_clear_flush_notify pudp_huge_clear_flush
 #define set_pte_at_notify set_pte_at
 
+static inline void mmu_notifier_synchronize(void)
+{
+}
+
 #endif /* CONFIG_MMU_NOTIFIER */
 
 #endif /* _LINUX_MMU_NOTIFIER_H */
diff --git a/include/linux/pagewalk.h b/include/linux/pagewalk.h
new file mode 100644
index 000000000000..bddd9759bab9
--- /dev/null
+++ b/include/linux/pagewalk.h
@@ -0,0 +1,66 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef _LINUX_PAGEWALK_H
+#define _LINUX_PAGEWALK_H
+
+#include <linux/mm.h>
+
+struct mm_walk;
+
+/**
+ * mm_walk_ops - callbacks for walk_page_range
+ * @pud_entry:		if set, called for each non-empty PUD (2nd-level) entry
+ *			this handler should only handle pud_trans_huge() puds.
+ *			the pmd_entry or pte_entry callbacks will be used for
+ *			regular PUDs.
+ * @pmd_entry:		if set, called for each non-empty PMD (3rd-level) entry
+ *			this handler is required to be able to handle
+ *			pmd_trans_huge() pmds.  They may simply choose to
+ *			split_huge_page() instead of handling it explicitly.
+ * @pte_entry:		if set, called for each non-empty PTE (4th-level) entry
+ * @pte_hole:		if set, called for each hole at all levels
+ * @hugetlb_entry:	if set, called for each hugetlb entry
+ * @test_walk:		caller specific callback function to determine whether
+ *			we walk over the current vma or not. Returning 0 means
+ *			"do page table walk over the current vma", returning
+ *			a negative value means "abort current page table walk
+ *			right now" and returning 1 means "skip the current vma"
+ */
+struct mm_walk_ops {
+	int (*pud_entry)(pud_t *pud, unsigned long addr,
+			 unsigned long next, struct mm_walk *walk);
+	int (*pmd_entry)(pmd_t *pmd, unsigned long addr,
+			 unsigned long next, struct mm_walk *walk);
+	int (*pte_entry)(pte_t *pte, unsigned long addr,
+			 unsigned long next, struct mm_walk *walk);
+	int (*pte_hole)(unsigned long addr, unsigned long next,
+			struct mm_walk *walk);
+	int (*hugetlb_entry)(pte_t *pte, unsigned long hmask,
+			     unsigned long addr, unsigned long next,
+			     struct mm_walk *walk);
+	int (*test_walk)(unsigned long addr, unsigned long next,
+			struct mm_walk *walk);
+};
+
+/**
+ * mm_walk - walk_page_range data
+ * @ops:	operation to call during the walk
+ * @mm:		mm_struct representing the target process of page table walk
+ * @vma:	vma currently walked (NULL if walking outside vmas)
+ * @private:	private data for callbacks' usage
+ *
+ * (see the comment on walk_page_range() for more details)
+ */
+struct mm_walk {
+	const struct mm_walk_ops *ops;
+	struct mm_struct *mm;
+	struct vm_area_struct *vma;
+	void *private;
+};
+
+int walk_page_range(struct mm_struct *mm, unsigned long start,
+		unsigned long end, const struct mm_walk_ops *ops,
+		void *private);
+int walk_page_vma(struct vm_area_struct *vma, const struct mm_walk_ops *ops,
+		void *private);
+
+#endif /* _LINUX_PAGEWALK_H */
diff --git a/include/linux/sched.h b/include/linux/sched.h
index b75b28287005..70db597d6fd4 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -958,6 +958,10 @@ struct task_struct {
 	struct mutex_waiter		*blocked_on;
 #endif
 
+#ifdef CONFIG_DEBUG_ATOMIC_SLEEP
+	int				non_block_count;
+#endif
+
 #ifdef CONFIG_TRACE_IRQFLAGS
 	unsigned int			irq_events;
 	unsigned long			hardirq_enable_ip;
diff --git a/include/rdma/ib_umem.h b/include/rdma/ib_umem.h
index 1052d0d62be7..a91b2af64ec4 100644
--- a/include/rdma/ib_umem.h
+++ b/include/rdma/ib_umem.h
@@ -42,7 +42,7 @@ struct ib_ucontext;
 struct ib_umem_odp;
 
 struct ib_umem {
-	struct ib_ucontext     *context;
+	struct ib_device       *ibdev;
 	struct mm_struct       *owning_mm;
 	size_t			length;
 	unsigned long		address;
diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h
index 479db5c98ff6..253df1a1fa54 100644
--- a/include/rdma/ib_umem_odp.h
+++ b/include/rdma/ib_umem_odp.h
@@ -37,11 +37,6 @@
 #include <rdma/ib_verbs.h>
 #include <linux/interval_tree.h>
 
-struct umem_odp_node {
-	u64 __subtree_last;
-	struct rb_node rb;
-};
-
 struct ib_umem_odp {
 	struct ib_umem umem;
 	struct ib_ucontext_per_mm *per_mm;
@@ -72,7 +67,15 @@ struct ib_umem_odp {
 	int npages;
 
 	/* Tree tracking */
-	struct umem_odp_node	interval_tree;
+	struct interval_tree_node interval_tree;
+
+	/*
+	 * An implicit odp umem cannot be DMA mapped, has 0 length, and serves
+	 * only as an anchor for the driver to hold onto the per_mm. FIXME:
+	 * This should be removed and drivers should work with the per_mm
+	 * directly.
+	 */
+	bool is_implicit_odp;
 
 	struct completion	notifier_completion;
 	int			dying;
@@ -88,14 +91,13 @@ static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem)
 /* Returns the first page of an ODP umem. */
 static inline unsigned long ib_umem_start(struct ib_umem_odp *umem_odp)
 {
-	return ALIGN_DOWN(umem_odp->umem.address, 1UL << umem_odp->page_shift);
+	return umem_odp->interval_tree.start;
 }
 
 /* Returns the address of the page after the last one of an ODP umem. */
 static inline unsigned long ib_umem_end(struct ib_umem_odp *umem_odp)
 {
-	return ALIGN(umem_odp->umem.address + umem_odp->umem.length,
-		     1UL << umem_odp->page_shift);
+	return umem_odp->interval_tree.last + 1;
 }
 
 static inline size_t ib_umem_odp_num_pages(struct ib_umem_odp *umem_odp)
@@ -120,25 +122,20 @@ static inline size_t ib_umem_odp_num_pages(struct ib_umem_odp *umem_odp)
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
 
 struct ib_ucontext_per_mm {
-	struct ib_ucontext *context;
-	struct mm_struct *mm;
+	struct mmu_notifier mn;
 	struct pid *tgid;
-	bool active;
 
 	struct rb_root_cached umem_tree;
 	/* Protects umem_tree */
 	struct rw_semaphore umem_rwsem;
-
-	struct mmu_notifier mn;
-	unsigned int odp_mrs_count;
-
-	struct list_head ucontext_list;
-	struct rcu_head rcu;
 };
 
-int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
-struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root_umem,
-				      unsigned long addr, size_t size);
+struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
+				    size_t size, int access);
+struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
+					       int access);
+struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root_umem,
+					    unsigned long addr, size_t size);
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
 
 int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 start_offset,
@@ -163,8 +160,17 @@ int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
  * Find first region intersecting with address range.
  * Return NULL if not found
  */
-struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
-				       u64 addr, u64 length);
+static inline struct ib_umem_odp *
+rbt_ib_umem_lookup(struct rb_root_cached *root, u64 addr, u64 length)
+{
+	struct interval_tree_node *node;
+
+	node = interval_tree_iter_first(root, addr, addr + length - 1);
+	if (!node)
+		return NULL;
+	return container_of(node, struct ib_umem_odp, interval_tree);
+
+}
 
 static inline int ib_umem_mmu_notifier_retry(struct ib_umem_odp *umem_odp,
 					     unsigned long mmu_seq)
@@ -185,9 +191,11 @@ static inline int ib_umem_mmu_notifier_retry(struct ib_umem_odp *umem_odp,
 
 #else /* CONFIG_INFINIBAND_ON_DEMAND_PAGING */
 
-static inline int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
+static inline struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata,
+						  unsigned long addr,
+						  size_t size, int access)
 {
-	return -EINVAL;
+	return ERR_PTR(-EINVAL);
 }
 
 static inline void ib_umem_odp_release(struct ib_umem_odp *umem_odp) {}
diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h
index 4f225175cb91..f659f4a02aa9 100644
--- a/include/rdma/ib_verbs.h
+++ b/include/rdma/ib_verbs.h
@@ -1417,11 +1417,6 @@ struct ib_ucontext {
 
 	bool cleanup_retryable;
 
-	void (*invalidate_range)(struct ib_umem_odp *umem_odp,
-				 unsigned long start, unsigned long end);
-	struct mutex per_mm_list_lock;
-	struct list_head per_mm_list;
-
 	struct ib_rdmacg_object	cg_obj;
 	/*
 	 * Implementation details of the RDMA core, don't use in drivers:
@@ -2378,6 +2373,8 @@ struct ib_device_ops {
 			    u64 iova);
 	int (*unmap_fmr)(struct list_head *fmr_list);
 	int (*dealloc_fmr)(struct ib_fmr *fmr);
+	void (*invalidate_range)(struct ib_umem_odp *umem_odp,
+				 unsigned long start, unsigned long end);
 	int (*attach_mcast)(struct ib_qp *qp, union ib_gid *gid, u16 lid);
 	int (*detach_mcast)(struct ib_qp *qp, union ib_gid *gid, u16 lid);
 	struct ib_xrcd *(*alloc_xrcd)(struct ib_device *device,
diff --git a/kernel/fork.c b/kernel/fork.c
index 53e780748fe3..5a0fd518e04e 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1009,7 +1009,6 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 	mm_init_owner(mm, p);
 	RCU_INIT_POINTER(mm->exe_file, NULL);
 	mmu_notifier_mm_init(mm);
-	hmm_mm_init(mm);
 	init_tlb_flush_pending(mm);
 #if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS
 	mm->pmd_huge_pte = NULL;
diff --git a/kernel/resource.c b/kernel/resource.c
index 7ea4306503c5..74877e9d90ca 100644
--- a/kernel/resource.c
+++ b/kernel/resource.c
@@ -1644,19 +1644,8 @@ void resource_list_free(struct list_head *head)
 EXPORT_SYMBOL(resource_list_free);
 
 #ifdef CONFIG_DEVICE_PRIVATE
-/**
- * devm_request_free_mem_region - find free region for device private memory
- *
- * @dev: device struct to bind the resource to
- * @size: size in bytes of the device memory to add
- * @base: resource tree to look in
- *
- * This function tries to find an empty range of physical address big enough to
- * contain the new resource, so that it can later be hotplugged as ZONE_DEVICE
- * memory, which in turn allocates struct pages.
- */
-struct resource *devm_request_free_mem_region(struct device *dev,
-		struct resource *base, unsigned long size)
+static struct resource *__request_free_mem_region(struct device *dev,
+		struct resource *base, unsigned long size, const char *name)
 {
 	resource_size_t end, addr;
 	struct resource *res;
@@ -1670,7 +1659,10 @@ struct resource *devm_request_free_mem_region(struct device *dev,
 				REGION_DISJOINT)
 			continue;
 
-		res = devm_request_mem_region(dev, addr, size, dev_name(dev));
+		if (dev)
+			res = devm_request_mem_region(dev, addr, size, name);
+		else
+			res = request_mem_region(addr, size, name);
 		if (!res)
 			return ERR_PTR(-ENOMEM);
 		res->desc = IORES_DESC_DEVICE_PRIVATE_MEMORY;
@@ -1679,7 +1671,32 @@ struct resource *devm_request_free_mem_region(struct device *dev,
 
 	return ERR_PTR(-ERANGE);
 }
+
+/**
+ * devm_request_free_mem_region - find free region for device private memory
+ *
+ * @dev: device struct to bind the resource to
+ * @size: size in bytes of the device memory to add
+ * @base: resource tree to look in
+ *
+ * This function tries to find an empty range of physical address big enough to
+ * contain the new resource, so that it can later be hotplugged as ZONE_DEVICE
+ * memory, which in turn allocates struct pages.
+ */
+struct resource *devm_request_free_mem_region(struct device *dev,
+		struct resource *base, unsigned long size)
+{
+	return __request_free_mem_region(dev, base, size, dev_name(dev));
+}
 EXPORT_SYMBOL_GPL(devm_request_free_mem_region);
+
+struct resource *request_free_mem_region(struct resource *base,
+		unsigned long size, const char *name)
+{
+	return __request_free_mem_region(NULL, base, size, name);
+}
+EXPORT_SYMBOL_GPL(request_free_mem_region);
+
 #endif /* CONFIG_DEVICE_PRIVATE */
 
 static int __init strict_iomem(char *str)
diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index 5e8387bdd09c..f9a1346a5fa9 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -3871,13 +3871,22 @@ static noinline void __schedule_bug(struct task_struct *prev)
 /*
  * Various schedule()-time debugging checks and statistics:
  */
-static inline void schedule_debug(struct task_struct *prev)
+static inline void schedule_debug(struct task_struct *prev, bool preempt)
 {
 #ifdef CONFIG_SCHED_STACK_END_CHECK
 	if (task_stack_end_corrupted(prev))
 		panic("corrupted stack end detected inside scheduler\n");
 #endif
 
+#ifdef CONFIG_DEBUG_ATOMIC_SLEEP
+	if (!preempt && prev->state && prev->non_block_count) {
+		printk(KERN_ERR "BUG: scheduling in a non-blocking section: %s/%d/%i\n",
+			prev->comm, prev->pid, prev->non_block_count);
+		dump_stack();
+		add_taint(TAINT_WARN, LOCKDEP_STILL_OK);
+	}
+#endif
+
 	if (unlikely(in_atomic_preempt_off())) {
 		__schedule_bug(prev);
 		preempt_count_set(PREEMPT_DISABLED);
@@ -3989,7 +3998,7 @@ static void __sched notrace __schedule(bool preempt)
 	rq = cpu_rq(cpu);
 	prev = rq->curr;
 
-	schedule_debug(prev);
+	schedule_debug(prev, preempt);
 
 	if (sched_feat(HRTICK))
 		hrtick_clear(rq);
@@ -6763,7 +6772,7 @@ void ___might_sleep(const char *file, int line, int preempt_offset)
 	rcu_sleep_check();
 
 	if ((preempt_count_equals(preempt_offset) && !irqs_disabled() &&
-	     !is_idle_task(current)) ||
+	     !is_idle_task(current) && !current->non_block_count) ||
 	    system_state == SYSTEM_BOOTING || system_state > SYSTEM_RUNNING ||
 	    oops_in_progress)
 		return;
@@ -6779,8 +6788,8 @@ void ___might_sleep(const char *file, int line, int preempt_offset)
 		"BUG: sleeping function called from invalid context at %s:%d\n",
 			file, line);
 	printk(KERN_ERR
-		"in_atomic(): %d, irqs_disabled(): %d, pid: %d, name: %s\n",
-			in_atomic(), irqs_disabled(),
+		"in_atomic(): %d, irqs_disabled(): %d, non_block: %d, pid: %d, name: %s\n",
+			in_atomic(), irqs_disabled(), current->non_block_count,
 			current->pid, current->comm);
 
 	if (task_stack_end_corrupted(current))
diff --git a/mm/Kconfig b/mm/Kconfig
index 56cec636a1fc..2fe4902ad755 100644
--- a/mm/Kconfig
+++ b/mm/Kconfig
@@ -669,23 +669,17 @@ config ZONE_DEVICE
 
 	  If FS_DAX is enabled, then say Y.
 
-config MIGRATE_VMA_HELPER
-	bool
-
 config DEV_PAGEMAP_OPS
 	bool
 
+#
+# Helpers to mirror range of the CPU page tables of a process into device page
+# tables.
+#
 config HMM_MIRROR
-	bool "HMM mirror CPU page table into a device page table"
-	depends on (X86_64 || PPC64)
-	depends on MMU && 64BIT
-	select MMU_NOTIFIER
-	help
-	  Select HMM_MIRROR if you want to mirror range of the CPU page table of a
-	  process into a device page table. Here, mirror means "keep synchronized".
-	  Prerequisites: the device must provide the ability to write-protect its
-	  page tables (at PAGE_SIZE granularity), and must be able to recover from
-	  the resulting potential page faults.
+	bool
+	depends on MMU
+	depends on MMU_NOTIFIER
 
 config DEVICE_PRIVATE
 	bool "Unaddressable device memory (GPU memory, ...)"
diff --git a/mm/hmm.c b/mm/hmm.c
index 16b6731a34db..902f5fa6bf93 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -8,7 +8,7 @@
  * Refer to include/linux/hmm.h for information about heterogeneous memory
  * management or HMM for short.
  */
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/hmm.h>
 #include <linux/init.h>
 #include <linux/rmap.h>
@@ -26,101 +26,37 @@
 #include <linux/mmu_notifier.h>
 #include <linux/memory_hotplug.h>
 
-static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
-
-/**
- * hmm_get_or_create - register HMM against an mm (HMM internal)
- *
- * @mm: mm struct to attach to
- * Returns: returns an HMM object, either by referencing the existing
- *          (per-process) object, or by creating a new one.
- *
- * This is not intended to be used directly by device drivers. If mm already
- * has an HMM struct then it get a reference on it and returns it. Otherwise
- * it allocates an HMM struct, initializes it, associate it with the mm and
- * returns it.
- */
-static struct hmm *hmm_get_or_create(struct mm_struct *mm)
+static struct mmu_notifier *hmm_alloc_notifier(struct mm_struct *mm)
 {
 	struct hmm *hmm;
 
-	lockdep_assert_held_write(&mm->mmap_sem);
-
-	/* Abuse the page_table_lock to also protect mm->hmm. */
-	spin_lock(&mm->page_table_lock);
-	hmm = mm->hmm;
-	if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
-		goto out_unlock;
-	spin_unlock(&mm->page_table_lock);
-
-	hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
+	hmm = kzalloc(sizeof(*hmm), GFP_KERNEL);
 	if (!hmm)
-		return NULL;
+		return ERR_PTR(-ENOMEM);
+
 	init_waitqueue_head(&hmm->wq);
 	INIT_LIST_HEAD(&hmm->mirrors);
 	init_rwsem(&hmm->mirrors_sem);
-	hmm->mmu_notifier.ops = NULL;
 	INIT_LIST_HEAD(&hmm->ranges);
 	spin_lock_init(&hmm->ranges_lock);
-	kref_init(&hmm->kref);
 	hmm->notifiers = 0;
-	hmm->mm = mm;
-
-	hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
-	if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
-		kfree(hmm);
-		return NULL;
-	}
-
-	mmgrab(hmm->mm);
-
-	/*
-	 * We hold the exclusive mmap_sem here so we know that mm->hmm is
-	 * still NULL or 0 kref, and is safe to update.
-	 */
-	spin_lock(&mm->page_table_lock);
-	mm->hmm = hmm;
-
-out_unlock:
-	spin_unlock(&mm->page_table_lock);
-	return hmm;
+	return &hmm->mmu_notifier;
 }
 
-static void hmm_free_rcu(struct rcu_head *rcu)
+static void hmm_free_notifier(struct mmu_notifier *mn)
 {
-	struct hmm *hmm = container_of(rcu, struct hmm, rcu);
+	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
 
-	mmdrop(hmm->mm);
+	WARN_ON(!list_empty(&hmm->ranges));
+	WARN_ON(!list_empty(&hmm->mirrors));
 	kfree(hmm);
 }
 
-static void hmm_free(struct kref *kref)
-{
-	struct hmm *hmm = container_of(kref, struct hmm, kref);
-
-	spin_lock(&hmm->mm->page_table_lock);
-	if (hmm->mm->hmm == hmm)
-		hmm->mm->hmm = NULL;
-	spin_unlock(&hmm->mm->page_table_lock);
-
-	mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
-	mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
-}
-
-static inline void hmm_put(struct hmm *hmm)
-{
-	kref_put(&hmm->kref, hmm_free);
-}
-
 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
 	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
 	struct hmm_mirror *mirror;
 
-	/* Bail out if hmm is in the process of being freed */
-	if (!kref_get_unless_zero(&hmm->kref))
-		return;
-
 	/*
 	 * Since hmm_range_register() holds the mmget() lock hmm_release() is
 	 * prevented as long as a range exists.
@@ -137,8 +73,6 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 			mirror->ops->release(mirror);
 	}
 	up_read(&hmm->mirrors_sem);
-
-	hmm_put(hmm);
 }
 
 static void notifiers_decrement(struct hmm *hmm)
@@ -165,23 +99,14 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
 {
 	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
 	struct hmm_mirror *mirror;
-	struct hmm_update update;
 	struct hmm_range *range;
 	unsigned long flags;
 	int ret = 0;
 
-	if (!kref_get_unless_zero(&hmm->kref))
-		return 0;
-
-	update.start = nrange->start;
-	update.end = nrange->end;
-	update.event = HMM_UPDATE_INVALIDATE;
-	update.blockable = mmu_notifier_range_blockable(nrange);
-
 	spin_lock_irqsave(&hmm->ranges_lock, flags);
 	hmm->notifiers++;
 	list_for_each_entry(range, &hmm->ranges, list) {
-		if (update.end < range->start || update.start >= range->end)
+		if (nrange->end < range->start || nrange->start >= range->end)
 			continue;
 
 		range->valid = false;
@@ -198,9 +123,10 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
 	list_for_each_entry(mirror, &hmm->mirrors, list) {
 		int rc;
 
-		rc = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
+		rc = mirror->ops->sync_cpu_device_pagetables(mirror, nrange);
 		if (rc) {
-			if (WARN_ON(update.blockable || rc != -EAGAIN))
+			if (WARN_ON(mmu_notifier_range_blockable(nrange) ||
+			    rc != -EAGAIN))
 				continue;
 			ret = -EAGAIN;
 			break;
@@ -211,7 +137,6 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
 out:
 	if (ret)
 		notifiers_decrement(hmm);
-	hmm_put(hmm);
 	return ret;
 }
 
@@ -220,17 +145,15 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
 {
 	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
 
-	if (!kref_get_unless_zero(&hmm->kref))
-		return;
-
 	notifiers_decrement(hmm);
-	hmm_put(hmm);
 }
 
 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
 	.release		= hmm_release,
 	.invalidate_range_start	= hmm_invalidate_range_start,
 	.invalidate_range_end	= hmm_invalidate_range_end,
+	.alloc_notifier		= hmm_alloc_notifier,
+	.free_notifier		= hmm_free_notifier,
 };
 
 /*
@@ -242,18 +165,27 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
  *
  * To start mirroring a process address space, the device driver must register
  * an HMM mirror struct.
+ *
+ * The caller cannot unregister the hmm_mirror while any ranges are
+ * registered.
+ *
+ * Callers using this function must put a call to mmu_notifier_synchronize()
+ * in their module exit functions.
  */
 int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
 {
+	struct mmu_notifier *mn;
+
 	lockdep_assert_held_write(&mm->mmap_sem);
 
 	/* Sanity check */
 	if (!mm || !mirror || !mirror->ops)
 		return -EINVAL;
 
-	mirror->hmm = hmm_get_or_create(mm);
-	if (!mirror->hmm)
-		return -ENOMEM;
+	mn = mmu_notifier_get_locked(&hmm_mmu_notifier_ops, mm);
+	if (IS_ERR(mn))
+		return PTR_ERR(mn);
+	mirror->hmm = container_of(mn, struct hmm, mmu_notifier);
 
 	down_write(&mirror->hmm->mirrors_sem);
 	list_add(&mirror->list, &mirror->hmm->mirrors);
@@ -277,7 +209,7 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror)
 	down_write(&hmm->mirrors_sem);
 	list_del(&mirror->list);
 	up_write(&hmm->mirrors_sem);
-	hmm_put(hmm);
+	mmu_notifier_put(&hmm->mmu_notifier);
 }
 EXPORT_SYMBOL(hmm_mirror_unregister);
 
@@ -285,8 +217,7 @@ struct hmm_vma_walk {
 	struct hmm_range	*range;
 	struct dev_pagemap	*pgmap;
 	unsigned long		last;
-	bool			fault;
-	bool			block;
+	unsigned int		flags;
 };
 
 static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
@@ -298,17 +229,27 @@ static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
 	struct vm_area_struct *vma = walk->vma;
 	vm_fault_t ret;
 
-	flags |= hmm_vma_walk->block ? 0 : FAULT_FLAG_ALLOW_RETRY;
-	flags |= write_fault ? FAULT_FLAG_WRITE : 0;
+	if (!vma)
+		goto err;
+
+	if (hmm_vma_walk->flags & HMM_FAULT_ALLOW_RETRY)
+		flags |= FAULT_FLAG_ALLOW_RETRY;
+	if (write_fault)
+		flags |= FAULT_FLAG_WRITE;
+
 	ret = handle_mm_fault(vma, addr, flags);
-	if (ret & VM_FAULT_RETRY)
+	if (ret & VM_FAULT_RETRY) {
+		/* Note, handle_mm_fault did up_read(&mm->mmap_sem)) */
 		return -EAGAIN;
-	if (ret & VM_FAULT_ERROR) {
-		*pfn = range->values[HMM_PFN_ERROR];
-		return -EFAULT;
 	}
+	if (ret & VM_FAULT_ERROR)
+		goto err;
 
 	return -EBUSY;
+
+err:
+	*pfn = range->values[HMM_PFN_ERROR];
+	return -EFAULT;
 }
 
 static int hmm_pfns_bad(unsigned long addr,
@@ -328,8 +269,8 @@ static int hmm_pfns_bad(unsigned long addr,
 }
 
 /*
- * hmm_vma_walk_hole() - handle a range lacking valid pmd or pte(s)
- * @start: range virtual start address (inclusive)
+ * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
+ * @addr: range virtual start address (inclusive)
  * @end: range virtual end address (exclusive)
  * @fault: should we fault or not ?
  * @write_fault: write fault ?
@@ -346,13 +287,15 @@ static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
 	uint64_t *pfns = range->pfns;
-	unsigned long i, page_size;
+	unsigned long i;
 
 	hmm_vma_walk->last = addr;
-	page_size = hmm_range_page_size(range);
-	i = (addr - range->start) >> range->page_shift;
+	i = (addr - range->start) >> PAGE_SHIFT;
+
+	if (write_fault && walk->vma && !(walk->vma->vm_flags & VM_WRITE))
+		return -EPERM;
 
-	for (; addr < end; addr += page_size, i++) {
+	for (; addr < end; addr += PAGE_SIZE, i++) {
 		pfns[i] = range->values[HMM_PFN_NONE];
 		if (fault || write_fault) {
 			int ret;
@@ -373,15 +316,15 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 {
 	struct hmm_range *range = hmm_vma_walk->range;
 
-	if (!hmm_vma_walk->fault)
+	if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT)
 		return;
 
 	/*
 	 * So we not only consider the individual per page request we also
 	 * consider the default flags requested for the range. The API can
-	 * be use in 2 fashions. The first one where the HMM user coalesce
-	 * multiple page fault into one request and set flags per pfns for
-	 * of those faults. The second one where the HMM user want to pre-
+	 * be used 2 ways. The first one where the HMM user coalesces
+	 * multiple page faults into one request and sets flags per pfn for
+	 * those faults. The second one where the HMM user wants to pre-
 	 * fault a range with specific flags. For the latter one it is a
 	 * waste to have the user pre-fill the pfn arrays with a default
 	 * flags value.
@@ -391,7 +334,7 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 	/* We aren't ask to do anything ... */
 	if (!(pfns & range->flags[HMM_PFN_VALID]))
 		return;
-	/* If this is device memory than only fault if explicitly requested */
+	/* If this is device memory then only fault if explicitly requested */
 	if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
 		/* Do we fault on device memory ? */
 		if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
@@ -418,7 +361,7 @@ static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 {
 	unsigned long i;
 
-	if (!hmm_vma_walk->fault) {
+	if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT) {
 		*fault = *write_fault = false;
 		return;
 	}
@@ -458,22 +401,10 @@ static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
 				range->flags[HMM_PFN_VALID];
 }
 
-static inline uint64_t pud_to_hmm_pfn_flags(struct hmm_range *range, pud_t pud)
-{
-	if (!pud_present(pud))
-		return 0;
-	return pud_write(pud) ? range->flags[HMM_PFN_VALID] |
-				range->flags[HMM_PFN_WRITE] :
-				range->flags[HMM_PFN_VALID];
-}
-
-static int hmm_vma_handle_pmd(struct mm_walk *walk,
-			      unsigned long addr,
-			      unsigned long end,
-			      uint64_t *pfns,
-			      pmd_t pmd)
-{
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
+static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
+		unsigned long end, uint64_t *pfns, pmd_t pmd)
+{
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
 	unsigned long pfn, npages, i;
@@ -488,7 +419,7 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk,
 	if (pmd_protnone(pmd) || fault || write_fault)
 		return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
 
-	pfn = pmd_pfn(pmd) + pte_index(addr);
+	pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
 	for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
 		if (pmd_devmap(pmd)) {
 			hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
@@ -504,11 +435,12 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk,
 	}
 	hmm_vma_walk->last = end;
 	return 0;
-#else
-	/* If THP is not enabled then we should never reach that code ! */
-	return -EINVAL;
-#endif
 }
+#else /* CONFIG_TRANSPARENT_HUGEPAGE */
+/* stub to allow the code below to compile */
+int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
+		unsigned long end, uint64_t *pfns, pmd_t pmd);
+#endif /* CONFIG_TRANSPARENT_HUGEPAGE */
 
 static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
 {
@@ -525,7 +457,6 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
-	struct vm_area_struct *vma = walk->vma;
 	bool fault, write_fault;
 	uint64_t cpu_flags;
 	pte_t pte = *ptep;
@@ -546,6 +477,9 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 		swp_entry_t entry = pte_to_swp_entry(pte);
 
 		if (!non_swap_entry(entry)) {
+			cpu_flags = pte_to_hmm_pfn_flags(range, pte);
+			hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
+					   &fault, &write_fault);
 			if (fault || write_fault)
 				goto fault;
 			return 0;
@@ -574,8 +508,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 			if (fault || write_fault) {
 				pte_unmap(ptep);
 				hmm_vma_walk->last = addr;
-				migration_entry_wait(vma->vm_mm,
-						     pmdp, addr);
+				migration_entry_wait(walk->mm, pmdp, addr);
 				return -EBUSY;
 			}
 			return 0;
@@ -623,21 +556,16 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
-	struct vm_area_struct *vma = walk->vma;
 	uint64_t *pfns = range->pfns;
 	unsigned long addr = start, i;
 	pte_t *ptep;
 	pmd_t pmd;
 
-
 again:
 	pmd = READ_ONCE(*pmdp);
 	if (pmd_none(pmd))
 		return hmm_vma_walk_hole(start, end, walk);
 
-	if (pmd_huge(pmd) && (range->vma->vm_flags & VM_HUGETLB))
-		return hmm_pfns_bad(start, end, walk);
-
 	if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
 		bool fault, write_fault;
 		unsigned long npages;
@@ -651,7 +579,7 @@ again:
 				     0, &fault, &write_fault);
 		if (fault || write_fault) {
 			hmm_vma_walk->last = addr;
-			pmd_migration_entry_wait(vma->vm_mm, pmdp);
+			pmd_migration_entry_wait(walk->mm, pmdp);
 			return -EBUSY;
 		}
 		return 0;
@@ -660,11 +588,11 @@ again:
 
 	if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
 		/*
-		 * No need to take pmd_lock here, even if some other threads
+		 * No need to take pmd_lock here, even if some other thread
 		 * is splitting the huge pmd we will get that event through
 		 * mmu_notifier callback.
 		 *
-		 * So just read pmd value and check again its a transparent
+		 * So just read pmd value and check again it's a transparent
 		 * huge or device mapping one and compute corresponding pfn
 		 * values.
 		 */
@@ -678,7 +606,7 @@ again:
 	}
 
 	/*
-	 * We have handled all the valid case above ie either none, migration,
+	 * We have handled all the valid cases above ie either none, migration,
 	 * huge or transparent huge. At this point either it is a valid pmd
 	 * entry pointing to pte directory or it is a bad pmd that will not
 	 * recover.
@@ -714,10 +642,19 @@ again:
 	return 0;
 }
 
-static int hmm_vma_walk_pud(pud_t *pudp,
-			    unsigned long start,
-			    unsigned long end,
-			    struct mm_walk *walk)
+#if defined(CONFIG_ARCH_HAS_PTE_DEVMAP) && \
+    defined(CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD)
+static inline uint64_t pud_to_hmm_pfn_flags(struct hmm_range *range, pud_t pud)
+{
+	if (!pud_present(pud))
+		return 0;
+	return pud_write(pud) ? range->flags[HMM_PFN_VALID] |
+				range->flags[HMM_PFN_WRITE] :
+				range->flags[HMM_PFN_VALID];
+}
+
+static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
+		struct mm_walk *walk)
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
@@ -781,42 +718,29 @@ again:
 
 	return 0;
 }
+#else
+#define hmm_vma_walk_pud	NULL
+#endif
 
+#ifdef CONFIG_HUGETLB_PAGE
 static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
 				      unsigned long start, unsigned long end,
 				      struct mm_walk *walk)
 {
-#ifdef CONFIG_HUGETLB_PAGE
-	unsigned long addr = start, i, pfn, mask, size, pfn_inc;
+	unsigned long addr = start, i, pfn;
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
 	struct hmm_range *range = hmm_vma_walk->range;
 	struct vm_area_struct *vma = walk->vma;
-	struct hstate *h = hstate_vma(vma);
 	uint64_t orig_pfn, cpu_flags;
 	bool fault, write_fault;
 	spinlock_t *ptl;
 	pte_t entry;
 	int ret = 0;
 
-	size = 1UL << huge_page_shift(h);
-	mask = size - 1;
-	if (range->page_shift != PAGE_SHIFT) {
-		/* Make sure we are looking at full page. */
-		if (start & mask)
-			return -EINVAL;
-		if (end < (start + size))
-			return -EINVAL;
-		pfn_inc = size >> PAGE_SHIFT;
-	} else {
-		pfn_inc = 1;
-		size = PAGE_SIZE;
-	}
-
-
-	ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
+	ptl = huge_pte_lock(hstate_vma(vma), walk->mm, pte);
 	entry = huge_ptep_get(pte);
 
-	i = (start - range->start) >> range->page_shift;
+	i = (start - range->start) >> PAGE_SHIFT;
 	orig_pfn = range->pfns[i];
 	range->pfns[i] = range->values[HMM_PFN_NONE];
 	cpu_flags = pte_to_hmm_pfn_flags(range, entry);
@@ -828,8 +752,8 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
 		goto unlock;
 	}
 
-	pfn = pte_pfn(entry) + ((start & mask) >> range->page_shift);
-	for (; addr < end; addr += size, i++, pfn += pfn_inc)
+	pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
+	for (; addr < end; addr += PAGE_SIZE, i++, pfn++)
 		range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
 				 cpu_flags;
 	hmm_vma_walk->last = end;
@@ -841,10 +765,10 @@ unlock:
 		return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
 
 	return ret;
-#else /* CONFIG_HUGETLB_PAGE */
-	return -EINVAL;
-#endif
 }
+#else
+#define hmm_vma_walk_hugetlb_entry NULL
+#endif /* CONFIG_HUGETLB_PAGE */
 
 static void hmm_pfns_clear(struct hmm_range *range,
 			   uint64_t *pfns,
@@ -859,44 +783,32 @@ static void hmm_pfns_clear(struct hmm_range *range,
  * hmm_range_register() - start tracking change to CPU page table over a range
  * @range: range
  * @mm: the mm struct for the range of virtual address
- * @start: start virtual address (inclusive)
- * @end: end virtual address (exclusive)
- * @page_shift: expect page shift for the range
- * Returns 0 on success, -EFAULT if the address space is no longer valid
+ *
+ * Return: 0 on success, -EFAULT if the address space is no longer valid
  *
  * Track updates to the CPU page table see include/linux/hmm.h
  */
-int hmm_range_register(struct hmm_range *range,
-		       struct hmm_mirror *mirror,
-		       unsigned long start,
-		       unsigned long end,
-		       unsigned page_shift)
+int hmm_range_register(struct hmm_range *range, struct hmm_mirror *mirror)
 {
-	unsigned long mask = ((1UL << page_shift) - 1UL);
 	struct hmm *hmm = mirror->hmm;
 	unsigned long flags;
 
 	range->valid = false;
 	range->hmm = NULL;
 
-	if ((start & mask) || (end & mask))
+	if ((range->start & (PAGE_SIZE - 1)) || (range->end & (PAGE_SIZE - 1)))
 		return -EINVAL;
-	if (start >= end)
+	if (range->start >= range->end)
 		return -EINVAL;
 
-	range->page_shift = page_shift;
-	range->start = start;
-	range->end = end;
-
 	/* Prevent hmm_release() from running while the range is valid */
-	if (!mmget_not_zero(hmm->mm))
+	if (!mmget_not_zero(hmm->mmu_notifier.mm))
 		return -EFAULT;
 
 	/* Initialize range to track CPU page table updates. */
 	spin_lock_irqsave(&hmm->ranges_lock, flags);
 
 	range->hmm = hmm;
-	kref_get(&hmm->kref);
 	list_add(&range->list, &hmm->ranges);
 
 	/*
@@ -928,8 +840,7 @@ void hmm_range_unregister(struct hmm_range *range)
 	spin_unlock_irqrestore(&hmm->ranges_lock, flags);
 
 	/* Drop reference taken by hmm_range_register() */
-	mmput(hmm->mm);
-	hmm_put(hmm);
+	mmput(hmm->mmu_notifier.mm);
 
 	/*
 	 * The range is now invalid and the ref on the hmm is dropped, so
@@ -941,105 +852,33 @@ void hmm_range_unregister(struct hmm_range *range)
 }
 EXPORT_SYMBOL(hmm_range_unregister);
 
-/*
- * hmm_range_snapshot() - snapshot CPU page table for a range
- * @range: range
- * Return: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
- *          permission (for instance asking for write and range is read only),
- *          -EBUSY if you need to retry, -EFAULT invalid (ie either no valid
- *          vma or it is illegal to access that range), number of valid pages
- *          in range->pfns[] (from range start address).
- *
- * This snapshots the CPU page table for a range of virtual addresses. Snapshot
- * validity is tracked by range struct. See in include/linux/hmm.h for example
- * on how to use.
- */
-long hmm_range_snapshot(struct hmm_range *range)
-{
-	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
-	unsigned long start = range->start, end;
-	struct hmm_vma_walk hmm_vma_walk;
-	struct hmm *hmm = range->hmm;
-	struct vm_area_struct *vma;
-	struct mm_walk mm_walk;
-
-	lockdep_assert_held(&hmm->mm->mmap_sem);
-	do {
-		/* If range is no longer valid force retry. */
-		if (!range->valid)
-			return -EBUSY;
-
-		vma = find_vma(hmm->mm, start);
-		if (vma == NULL || (vma->vm_flags & device_vma))
-			return -EFAULT;
-
-		if (is_vm_hugetlb_page(vma)) {
-			if (huge_page_shift(hstate_vma(vma)) !=
-				    range->page_shift &&
-			    range->page_shift != PAGE_SHIFT)
-				return -EINVAL;
-		} else {
-			if (range->page_shift != PAGE_SHIFT)
-				return -EINVAL;
-		}
-
-		if (!(vma->vm_flags & VM_READ)) {
-			/*
-			 * If vma do not allow read access, then assume that it
-			 * does not allow write access, either. HMM does not
-			 * support architecture that allow write without read.
-			 */
-			hmm_pfns_clear(range, range->pfns,
-				range->start, range->end);
-			return -EPERM;
-		}
-
-		range->vma = vma;
-		hmm_vma_walk.pgmap = NULL;
-		hmm_vma_walk.last = start;
-		hmm_vma_walk.fault = false;
-		hmm_vma_walk.range = range;
-		mm_walk.private = &hmm_vma_walk;
-		end = min(range->end, vma->vm_end);
-
-		mm_walk.vma = vma;
-		mm_walk.mm = vma->vm_mm;
-		mm_walk.pte_entry = NULL;
-		mm_walk.test_walk = NULL;
-		mm_walk.hugetlb_entry = NULL;
-		mm_walk.pud_entry = hmm_vma_walk_pud;
-		mm_walk.pmd_entry = hmm_vma_walk_pmd;
-		mm_walk.pte_hole = hmm_vma_walk_hole;
-		mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
-
-		walk_page_range(start, end, &mm_walk);
-		start = end;
-	} while (start < range->end);
-
-	return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
-}
-EXPORT_SYMBOL(hmm_range_snapshot);
+static const struct mm_walk_ops hmm_walk_ops = {
+	.pud_entry	= hmm_vma_walk_pud,
+	.pmd_entry	= hmm_vma_walk_pmd,
+	.pte_hole	= hmm_vma_walk_hole,
+	.hugetlb_entry	= hmm_vma_walk_hugetlb_entry,
+};
 
-/*
- * hmm_range_fault() - try to fault some address in a virtual address range
- * @range: range being faulted
- * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem)
- * Return: number of valid pages in range->pfns[] (from range start
- *          address). This may be zero. If the return value is negative,
- *          then one of the following values may be returned:
+/**
+ * hmm_range_fault - try to fault some address in a virtual address range
+ * @range:	range being faulted
+ * @flags:	HMM_FAULT_* flags
  *
- *           -EINVAL  invalid arguments or mm or virtual address are in an
- *                    invalid vma (for instance device file vma).
- *           -ENOMEM: Out of memory.
- *           -EPERM:  Invalid permission (for instance asking for write and
- *                    range is read only).
- *           -EAGAIN: If you need to retry and mmap_sem was drop. This can only
- *                    happens if block argument is false.
- *           -EBUSY:  If the the range is being invalidated and you should wait
- *                    for invalidation to finish.
- *           -EFAULT: Invalid (ie either no valid vma or it is illegal to access
- *                    that range), number of valid pages in range->pfns[] (from
- *                    range start address).
+ * Return: the number of valid pages in range->pfns[] (from range start
+ * address), which may be zero.  On error one of the following status codes
+ * can be returned:
+ *
+ * -EINVAL:	Invalid arguments or mm or virtual address is in an invalid vma
+ *		(e.g., device file vma).
+ * -ENOMEM:	Out of memory.
+ * -EPERM:	Invalid permission (e.g., asking for write and range is read
+ *		only).
+ * -EAGAIN:	A page fault needs to be retried and mmap_sem was dropped.
+ * -EBUSY:	The range has been invalidated and the caller needs to wait for
+ *		the invalidation to finish.
+ * -EFAULT:	Invalid (i.e., either no valid vma or it is illegal to access
+ *		that range) number of valid pages in range->pfns[] (from
+ *              range start address).
  *
  * This is similar to a regular CPU page fault except that it will not trigger
  * any memory migration if the memory being faulted is not accessible by CPUs
@@ -1048,37 +887,26 @@ EXPORT_SYMBOL(hmm_range_snapshot);
  * On error, for one virtual address in the range, the function will mark the
  * corresponding HMM pfn entry with an error flag.
  */
-long hmm_range_fault(struct hmm_range *range, bool block)
+long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 {
 	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
 	unsigned long start = range->start, end;
 	struct hmm_vma_walk hmm_vma_walk;
 	struct hmm *hmm = range->hmm;
 	struct vm_area_struct *vma;
-	struct mm_walk mm_walk;
 	int ret;
 
-	lockdep_assert_held(&hmm->mm->mmap_sem);
+	lockdep_assert_held(&hmm->mmu_notifier.mm->mmap_sem);
 
 	do {
 		/* If range is no longer valid force retry. */
 		if (!range->valid)
 			return -EBUSY;
 
-		vma = find_vma(hmm->mm, start);
+		vma = find_vma(hmm->mmu_notifier.mm, start);
 		if (vma == NULL || (vma->vm_flags & device_vma))
 			return -EFAULT;
 
-		if (is_vm_hugetlb_page(vma)) {
-			if (huge_page_shift(hstate_vma(vma)) !=
-			    range->page_shift &&
-			    range->page_shift != PAGE_SHIFT)
-				return -EINVAL;
-		} else {
-			if (range->page_shift != PAGE_SHIFT)
-				return -EINVAL;
-		}
-
 		if (!(vma->vm_flags & VM_READ)) {
 			/*
 			 * If vma do not allow read access, then assume that it
@@ -1090,27 +918,18 @@ long hmm_range_fault(struct hmm_range *range, bool block)
 			return -EPERM;
 		}
 
-		range->vma = vma;
 		hmm_vma_walk.pgmap = NULL;
 		hmm_vma_walk.last = start;
-		hmm_vma_walk.fault = true;
-		hmm_vma_walk.block = block;
+		hmm_vma_walk.flags = flags;
 		hmm_vma_walk.range = range;
-		mm_walk.private = &hmm_vma_walk;
 		end = min(range->end, vma->vm_end);
 
-		mm_walk.vma = vma;
-		mm_walk.mm = vma->vm_mm;
-		mm_walk.pte_entry = NULL;
-		mm_walk.test_walk = NULL;
-		mm_walk.hugetlb_entry = NULL;
-		mm_walk.pud_entry = hmm_vma_walk_pud;
-		mm_walk.pmd_entry = hmm_vma_walk_pmd;
-		mm_walk.pte_hole = hmm_vma_walk_hole;
-		mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
+		walk_page_range(vma->vm_mm, start, end, &hmm_walk_ops,
+				&hmm_vma_walk);
 
 		do {
-			ret = walk_page_range(start, end, &mm_walk);
+			ret = walk_page_range(vma->vm_mm, start, end,
+					&hmm_walk_ops, &hmm_vma_walk);
 			start = hmm_vma_walk.last;
 
 			/* Keep trying while the range is valid. */
@@ -1133,25 +952,22 @@ long hmm_range_fault(struct hmm_range *range, bool block)
 EXPORT_SYMBOL(hmm_range_fault);
 
 /**
- * hmm_range_dma_map() - hmm_range_fault() and dma map page all in one.
- * @range: range being faulted
- * @device: device against to dma map page to
- * @daddrs: dma address of mapped pages
- * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem)
- * Return: number of pages mapped on success, -EAGAIN if mmap_sem have been
- *          drop and you need to try again, some other error value otherwise
+ * hmm_range_dma_map - hmm_range_fault() and dma map page all in one.
+ * @range:	range being faulted
+ * @device:	device to map page to
+ * @daddrs:	array of dma addresses for the mapped pages
+ * @flags:	HMM_FAULT_*
  *
- * Note same usage pattern as hmm_range_fault().
+ * Return: the number of pages mapped on success (including zero), or any
+ * status return from hmm_range_fault() otherwise.
  */
-long hmm_range_dma_map(struct hmm_range *range,
-		       struct device *device,
-		       dma_addr_t *daddrs,
-		       bool block)
+long hmm_range_dma_map(struct hmm_range *range, struct device *device,
+		dma_addr_t *daddrs, unsigned int flags)
 {
 	unsigned long i, npages, mapped;
 	long ret;
 
-	ret = hmm_range_fault(range, block);
+	ret = hmm_range_fault(range, flags);
 	if (ret <= 0)
 		return ret ? ret : -EBUSY;
 
@@ -1222,7 +1038,6 @@ EXPORT_SYMBOL(hmm_range_dma_map);
 /**
  * hmm_range_dma_unmap() - unmap range of that was map with hmm_range_dma_map()
  * @range: range being unmapped
- * @vma: the vma against which the range (optional)
  * @device: device against which dma map was done
  * @daddrs: dma address of mapped pages
  * @dirty: dirty page if it had the write flag set
@@ -1234,7 +1049,6 @@ EXPORT_SYMBOL(hmm_range_dma_map);
  * concurrent mmu notifier or sync_cpu_device_pagetables() to make progress.
  */
 long hmm_range_dma_unmap(struct hmm_range *range,
-			 struct vm_area_struct *vma,
 			 struct device *device,
 			 dma_addr_t *daddrs,
 			 bool dirty)
diff --git a/mm/madvise.c b/mm/madvise.c
index bac973b9f2cc..88babcc384b9 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -21,6 +21,7 @@
 #include <linux/file.h>
 #include <linux/blkdev.h>
 #include <linux/backing-dev.h>
+#include <linux/pagewalk.h>
 #include <linux/swap.h>
 #include <linux/swapops.h>
 #include <linux/shmem_fs.h>
@@ -226,19 +227,9 @@ static int swapin_walk_pmd_entry(pmd_t *pmd, unsigned long start,
 	return 0;
 }
 
-static void force_swapin_readahead(struct vm_area_struct *vma,
-		unsigned long start, unsigned long end)
-{
-	struct mm_walk walk = {
-		.mm = vma->vm_mm,
-		.pmd_entry = swapin_walk_pmd_entry,
-		.private = vma,
-	};
-
-	walk_page_range(start, end, &walk);
-
-	lru_add_drain();	/* Push any new pages onto the LRU now */
-}
+static const struct mm_walk_ops swapin_walk_ops = {
+	.pmd_entry		= swapin_walk_pmd_entry,
+};
 
 static void force_shm_swapin_readahead(struct vm_area_struct *vma,
 		unsigned long start, unsigned long end,
@@ -281,7 +272,8 @@ static long madvise_willneed(struct vm_area_struct *vma,
 	*prev = vma;
 #ifdef CONFIG_SWAP
 	if (!file) {
-		force_swapin_readahead(vma, start, end);
+		walk_page_range(vma->vm_mm, start, end, &swapin_walk_ops, vma);
+		lru_add_drain(); /* Push any new pages onto the LRU now */
 		return 0;
 	}
 
@@ -450,20 +442,9 @@ next:
 	return 0;
 }
 
-static void madvise_free_page_range(struct mmu_gather *tlb,
-			     struct vm_area_struct *vma,
-			     unsigned long addr, unsigned long end)
-{
-	struct mm_walk free_walk = {
-		.pmd_entry = madvise_free_pte_range,
-		.mm = vma->vm_mm,
-		.private = tlb,
-	};
-
-	tlb_start_vma(tlb, vma);
-	walk_page_range(addr, end, &free_walk);
-	tlb_end_vma(tlb, vma);
-}
+static const struct mm_walk_ops madvise_free_walk_ops = {
+	.pmd_entry		= madvise_free_pte_range,
+};
 
 static int madvise_free_single_vma(struct vm_area_struct *vma,
 			unsigned long start_addr, unsigned long end_addr)
@@ -490,7 +471,10 @@ static int madvise_free_single_vma(struct vm_area_struct *vma,
 	update_hiwater_rss(mm);
 
 	mmu_notifier_invalidate_range_start(&range);
-	madvise_free_page_range(&tlb, vma, range.start, range.end);
+	tlb_start_vma(&tlb, vma);
+	walk_page_range(vma->vm_mm, range.start, range.end,
+			&madvise_free_walk_ops, &tlb);
+	tlb_end_vma(&tlb, vma);
 	mmu_notifier_invalidate_range_end(&range);
 	tlb_finish_mmu(&tlb, range.start, range.end);
 
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 597d58101872..f3c15bb07cce 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -25,7 +25,7 @@
 #include <linux/page_counter.h>
 #include <linux/memcontrol.h>
 #include <linux/cgroup.h>
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/sched/mm.h>
 #include <linux/shmem_fs.h>
 #include <linux/hugetlb.h>
@@ -5499,17 +5499,16 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
 	return 0;
 }
 
+static const struct mm_walk_ops precharge_walk_ops = {
+	.pmd_entry	= mem_cgroup_count_precharge_pte_range,
+};
+
 static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
 {
 	unsigned long precharge;
 
-	struct mm_walk mem_cgroup_count_precharge_walk = {
-		.pmd_entry = mem_cgroup_count_precharge_pte_range,
-		.mm = mm,
-	};
 	down_read(&mm->mmap_sem);
-	walk_page_range(0, mm->highest_vm_end,
-			&mem_cgroup_count_precharge_walk);
+	walk_page_range(mm, 0, mm->highest_vm_end, &precharge_walk_ops, NULL);
 	up_read(&mm->mmap_sem);
 
 	precharge = mc.precharge;
@@ -5778,13 +5777,12 @@ put:			/* get_mctgt_type() gets the page */
 	return ret;
 }
 
+static const struct mm_walk_ops charge_walk_ops = {
+	.pmd_entry	= mem_cgroup_move_charge_pte_range,
+};
+
 static void mem_cgroup_move_charge(void)
 {
-	struct mm_walk mem_cgroup_move_charge_walk = {
-		.pmd_entry = mem_cgroup_move_charge_pte_range,
-		.mm = mc.mm,
-	};
-
 	lru_add_drain_all();
 	/*
 	 * Signal lock_page_memcg() to take the memcg's move_lock
@@ -5810,7 +5808,8 @@ retry:
 	 * When we have consumed all precharges and failed in doing
 	 * additional charge, the page walk just aborts.
 	 */
-	walk_page_range(0, mc.mm->highest_vm_end, &mem_cgroup_move_charge_walk);
+	walk_page_range(mc.mm, 0, mc.mm->highest_vm_end, &charge_walk_ops,
+			NULL);
 
 	up_read(&mc.mm->mmap_sem);
 	atomic_dec(&mc.from->moving_account);
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 65e0874fce17..f000771558d8 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -68,7 +68,7 @@
 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
 
 #include <linux/mempolicy.h>
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/highmem.h>
 #include <linux/hugetlb.h>
 #include <linux/kernel.h>
@@ -655,6 +655,12 @@ static int queue_pages_test_walk(unsigned long start, unsigned long end,
 	return 1;
 }
 
+static const struct mm_walk_ops queue_pages_walk_ops = {
+	.hugetlb_entry		= queue_pages_hugetlb,
+	.pmd_entry		= queue_pages_pte_range,
+	.test_walk		= queue_pages_test_walk,
+};
+
 /*
  * Walk through page tables and collect pages to be migrated.
  *
@@ -679,15 +685,8 @@ queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
 		.nmask = nodes,
 		.prev = NULL,
 	};
-	struct mm_walk queue_pages_walk = {
-		.hugetlb_entry = queue_pages_hugetlb,
-		.pmd_entry = queue_pages_pte_range,
-		.test_walk = queue_pages_test_walk,
-		.mm = mm,
-		.private = &qp,
-	};
 
-	return walk_page_range(start, end, &queue_pages_walk);
+	return walk_page_range(mm, start, end, &queue_pages_walk_ops, &qp);
 }
 
 /*
diff --git a/mm/memremap.c b/mm/memremap.c
index ed70c4e8e52a..32c79b51af86 100644
--- a/mm/memremap.c
+++ b/mm/memremap.c
@@ -21,13 +21,13 @@ DEFINE_STATIC_KEY_FALSE(devmap_managed_key);
 EXPORT_SYMBOL(devmap_managed_key);
 static atomic_t devmap_managed_enable;
 
-static void devmap_managed_enable_put(void *data)
+static void devmap_managed_enable_put(void)
 {
 	if (atomic_dec_and_test(&devmap_managed_enable))
 		static_branch_disable(&devmap_managed_key);
 }
 
-static int devmap_managed_enable_get(struct device *dev, struct dev_pagemap *pgmap)
+static int devmap_managed_enable_get(struct dev_pagemap *pgmap)
 {
 	if (!pgmap->ops || !pgmap->ops->page_free) {
 		WARN(1, "Missing page_free method\n");
@@ -36,13 +36,16 @@ static int devmap_managed_enable_get(struct device *dev, struct dev_pagemap *pgm
 
 	if (atomic_inc_return(&devmap_managed_enable) == 1)
 		static_branch_enable(&devmap_managed_key);
-	return devm_add_action_or_reset(dev, devmap_managed_enable_put, NULL);
+	return 0;
 }
 #else
-static int devmap_managed_enable_get(struct device *dev, struct dev_pagemap *pgmap)
+static int devmap_managed_enable_get(struct dev_pagemap *pgmap)
 {
 	return -EINVAL;
 }
+static void devmap_managed_enable_put(void)
+{
+}
 #endif /* CONFIG_DEV_PAGEMAP_OPS */
 
 static void pgmap_array_delete(struct resource *res)
@@ -99,10 +102,8 @@ static void dev_pagemap_cleanup(struct dev_pagemap *pgmap)
 		pgmap->ref = NULL;
 }
 
-static void devm_memremap_pages_release(void *data)
+void memunmap_pages(struct dev_pagemap *pgmap)
 {
-	struct dev_pagemap *pgmap = data;
-	struct device *dev = pgmap->dev;
 	struct resource *res = &pgmap->res;
 	unsigned long pfn;
 	int nid;
@@ -129,8 +130,14 @@ static void devm_memremap_pages_release(void *data)
 
 	untrack_pfn(NULL, PHYS_PFN(res->start), resource_size(res));
 	pgmap_array_delete(res);
-	dev_WARN_ONCE(dev, pgmap->altmap.alloc,
-		      "%s: failed to free all reserved pages\n", __func__);
+	WARN_ONCE(pgmap->altmap.alloc, "failed to free all reserved pages\n");
+	devmap_managed_enable_put();
+}
+EXPORT_SYMBOL_GPL(memunmap_pages);
+
+static void devm_memremap_pages_release(void *data)
+{
+	memunmap_pages(data);
 }
 
 static void dev_pagemap_percpu_release(struct percpu_ref *ref)
@@ -141,27 +148,12 @@ static void dev_pagemap_percpu_release(struct percpu_ref *ref)
 	complete(&pgmap->done);
 }
 
-/**
- * devm_memremap_pages - remap and provide memmap backing for the given resource
- * @dev: hosting device for @res
- * @pgmap: pointer to a struct dev_pagemap
- *
- * Notes:
- * 1/ At a minimum the res and type members of @pgmap must be initialized
- *    by the caller before passing it to this function
- *
- * 2/ The altmap field may optionally be initialized, in which case
- *    PGMAP_ALTMAP_VALID must be set in pgmap->flags.
- *
- * 3/ The ref field may optionally be provided, in which pgmap->ref must be
- *    'live' on entry and will be killed and reaped at
- *    devm_memremap_pages_release() time, or if this routine fails.
- *
- * 4/ res is expected to be a host memory range that could feasibly be
- *    treated as a "System RAM" range, i.e. not a device mmio range, but
- *    this is not enforced.
+/*
+ * Not device managed version of dev_memremap_pages, undone by
+ * memunmap_pages().  Please use dev_memremap_pages if you have a struct
+ * device available.
  */
-void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
+void *memremap_pages(struct dev_pagemap *pgmap, int nid)
 {
 	struct resource *res = &pgmap->res;
 	struct dev_pagemap *conflict_pgmap;
@@ -172,7 +164,7 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
 		.altmap = pgmap_altmap(pgmap),
 	};
 	pgprot_t pgprot = PAGE_KERNEL;
-	int error, nid, is_ram;
+	int error, is_ram;
 	bool need_devmap_managed = true;
 
 	switch (pgmap->type) {
@@ -220,14 +212,14 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
 	}
 
 	if (need_devmap_managed) {
-		error = devmap_managed_enable_get(dev, pgmap);
+		error = devmap_managed_enable_get(pgmap);
 		if (error)
 			return ERR_PTR(error);
 	}
 
 	conflict_pgmap = get_dev_pagemap(PHYS_PFN(res->start), NULL);
 	if (conflict_pgmap) {
-		dev_WARN(dev, "Conflicting mapping in same section\n");
+		WARN(1, "Conflicting mapping in same section\n");
 		put_dev_pagemap(conflict_pgmap);
 		error = -ENOMEM;
 		goto err_array;
@@ -235,7 +227,7 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
 
 	conflict_pgmap = get_dev_pagemap(PHYS_PFN(res->end), NULL);
 	if (conflict_pgmap) {
-		dev_WARN(dev, "Conflicting mapping in same section\n");
+		WARN(1, "Conflicting mapping in same section\n");
 		put_dev_pagemap(conflict_pgmap);
 		error = -ENOMEM;
 		goto err_array;
@@ -251,14 +243,11 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
 		goto err_array;
 	}
 
-	pgmap->dev = dev;
-
 	error = xa_err(xa_store_range(&pgmap_array, PHYS_PFN(res->start),
 				PHYS_PFN(res->end), pgmap, GFP_KERNEL));
 	if (error)
 		goto err_array;
 
-	nid = dev_to_node(dev);
 	if (nid < 0)
 		nid = numa_mem_id();
 
@@ -314,12 +303,6 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
 				PHYS_PFN(res->start),
 				PHYS_PFN(resource_size(res)), pgmap);
 	percpu_ref_get_many(pgmap->ref, pfn_end(pgmap) - pfn_first(pgmap));
-
-	error = devm_add_action_or_reset(dev, devm_memremap_pages_release,
-			pgmap);
-	if (error)
-		return ERR_PTR(error);
-
 	return __va(res->start);
 
  err_add_memory:
@@ -331,8 +314,46 @@ void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
  err_array:
 	dev_pagemap_kill(pgmap);
 	dev_pagemap_cleanup(pgmap);
+	devmap_managed_enable_put();
 	return ERR_PTR(error);
 }
+EXPORT_SYMBOL_GPL(memremap_pages);
+
+/**
+ * devm_memremap_pages - remap and provide memmap backing for the given resource
+ * @dev: hosting device for @res
+ * @pgmap: pointer to a struct dev_pagemap
+ *
+ * Notes:
+ * 1/ At a minimum the res and type members of @pgmap must be initialized
+ *    by the caller before passing it to this function
+ *
+ * 2/ The altmap field may optionally be initialized, in which case
+ *    PGMAP_ALTMAP_VALID must be set in pgmap->flags.
+ *
+ * 3/ The ref field may optionally be provided, in which pgmap->ref must be
+ *    'live' on entry and will be killed and reaped at
+ *    devm_memremap_pages_release() time, or if this routine fails.
+ *
+ * 4/ res is expected to be a host memory range that could feasibly be
+ *    treated as a "System RAM" range, i.e. not a device mmio range, but
+ *    this is not enforced.
+ */
+void *devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
+{
+	int error;
+	void *ret;
+
+	ret = memremap_pages(pgmap, dev_to_node(dev));
+	if (IS_ERR(ret))
+		return ret;
+
+	error = devm_add_action_or_reset(dev, devm_memremap_pages_release,
+			pgmap);
+	if (error)
+		return ERR_PTR(error);
+	return ret;
+}
 EXPORT_SYMBOL_GPL(devm_memremap_pages);
 
 void devm_memunmap_pages(struct device *dev, struct dev_pagemap *pgmap)
diff --git a/mm/migrate.c b/mm/migrate.c
index a42858d8e00b..9f4ed4e985c1 100644
--- a/mm/migrate.c
+++ b/mm/migrate.c
@@ -38,6 +38,7 @@
 #include <linux/hugetlb.h>
 #include <linux/hugetlb_cgroup.h>
 #include <linux/gfp.h>
+#include <linux/pagewalk.h>
 #include <linux/pfn_t.h>
 #include <linux/memremap.h>
 #include <linux/userfaultfd_k.h>
@@ -2119,17 +2120,7 @@ out_unlock:
 
 #endif /* CONFIG_NUMA */
 
-#if defined(CONFIG_MIGRATE_VMA_HELPER)
-struct migrate_vma {
-	struct vm_area_struct	*vma;
-	unsigned long		*dst;
-	unsigned long		*src;
-	unsigned long		cpages;
-	unsigned long		npages;
-	unsigned long		start;
-	unsigned long		end;
-};
-
+#ifdef CONFIG_DEVICE_PRIVATE
 static int migrate_vma_collect_hole(unsigned long start,
 				    unsigned long end,
 				    struct mm_walk *walk)
@@ -2249,8 +2240,8 @@ again:
 				goto next;
 
 			page = device_private_entry_to_page(entry);
-			mpfn = migrate_pfn(page_to_pfn(page))|
-				MIGRATE_PFN_DEVICE | MIGRATE_PFN_MIGRATE;
+			mpfn = migrate_pfn(page_to_pfn(page)) |
+					MIGRATE_PFN_MIGRATE;
 			if (is_write_device_private_entry(entry))
 				mpfn |= MIGRATE_PFN_WRITE;
 		} else {
@@ -2329,6 +2320,11 @@ next:
 	return 0;
 }
 
+static const struct mm_walk_ops migrate_vma_walk_ops = {
+	.pmd_entry		= migrate_vma_collect_pmd,
+	.pte_hole		= migrate_vma_collect_hole,
+};
+
 /*
  * migrate_vma_collect() - collect pages over a range of virtual addresses
  * @migrate: migrate struct containing all migration information
@@ -2340,21 +2336,15 @@ next:
 static void migrate_vma_collect(struct migrate_vma *migrate)
 {
 	struct mmu_notifier_range range;
-	struct mm_walk mm_walk = {
-		.pmd_entry = migrate_vma_collect_pmd,
-		.pte_hole = migrate_vma_collect_hole,
-		.vma = migrate->vma,
-		.mm = migrate->vma->vm_mm,
-		.private = migrate,
-	};
 
-	mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, NULL, mm_walk.mm,
-				migrate->start,
-				migrate->end);
+	mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, NULL,
+			migrate->vma->vm_mm, migrate->start, migrate->end);
 	mmu_notifier_invalidate_range_start(&range);
-	walk_page_range(migrate->start, migrate->end, &mm_walk);
-	mmu_notifier_invalidate_range_end(&range);
 
+	walk_page_range(migrate->vma->vm_mm, migrate->start, migrate->end,
+			&migrate_vma_walk_ops, migrate);
+
+	mmu_notifier_invalidate_range_end(&range);
 	migrate->end = migrate->start + (migrate->npages << PAGE_SHIFT);
 }
 
@@ -2577,6 +2567,110 @@ restore:
 	}
 }
 
+/**
+ * migrate_vma_setup() - prepare to migrate a range of memory
+ * @args: contains the vma, start, and and pfns arrays for the migration
+ *
+ * Returns: negative errno on failures, 0 when 0 or more pages were migrated
+ * without an error.
+ *
+ * Prepare to migrate a range of memory virtual address range by collecting all
+ * the pages backing each virtual address in the range, saving them inside the
+ * src array.  Then lock those pages and unmap them. Once the pages are locked
+ * and unmapped, check whether each page is pinned or not.  Pages that aren't
+ * pinned have the MIGRATE_PFN_MIGRATE flag set (by this function) in the
+ * corresponding src array entry.  Then restores any pages that are pinned, by
+ * remapping and unlocking those pages.
+ *
+ * The caller should then allocate destination memory and copy source memory to
+ * it for all those entries (ie with MIGRATE_PFN_VALID and MIGRATE_PFN_MIGRATE
+ * flag set).  Once these are allocated and copied, the caller must update each
+ * corresponding entry in the dst array with the pfn value of the destination
+ * page and with the MIGRATE_PFN_VALID and MIGRATE_PFN_LOCKED flags set
+ * (destination pages must have their struct pages locked, via lock_page()).
+ *
+ * Note that the caller does not have to migrate all the pages that are marked
+ * with MIGRATE_PFN_MIGRATE flag in src array unless this is a migration from
+ * device memory to system memory.  If the caller cannot migrate a device page
+ * back to system memory, then it must return VM_FAULT_SIGBUS, which has severe
+ * consequences for the userspace process, so it must be avoided if at all
+ * possible.
+ *
+ * For empty entries inside CPU page table (pte_none() or pmd_none() is true) we
+ * do set MIGRATE_PFN_MIGRATE flag inside the corresponding source array thus
+ * allowing the caller to allocate device memory for those unback virtual
+ * address.  For this the caller simply has to allocate device memory and
+ * properly set the destination entry like for regular migration.  Note that
+ * this can still fails and thus inside the device driver must check if the
+ * migration was successful for those entries after calling migrate_vma_pages()
+ * just like for regular migration.
+ *
+ * After that, the callers must call migrate_vma_pages() to go over each entry
+ * in the src array that has the MIGRATE_PFN_VALID and MIGRATE_PFN_MIGRATE flag
+ * set. If the corresponding entry in dst array has MIGRATE_PFN_VALID flag set,
+ * then migrate_vma_pages() to migrate struct page information from the source
+ * struct page to the destination struct page.  If it fails to migrate the
+ * struct page information, then it clears the MIGRATE_PFN_MIGRATE flag in the
+ * src array.
+ *
+ * At this point all successfully migrated pages have an entry in the src
+ * array with MIGRATE_PFN_VALID and MIGRATE_PFN_MIGRATE flag set and the dst
+ * array entry with MIGRATE_PFN_VALID flag set.
+ *
+ * Once migrate_vma_pages() returns the caller may inspect which pages were
+ * successfully migrated, and which were not.  Successfully migrated pages will
+ * have the MIGRATE_PFN_MIGRATE flag set for their src array entry.
+ *
+ * It is safe to update device page table after migrate_vma_pages() because
+ * both destination and source page are still locked, and the mmap_sem is held
+ * in read mode (hence no one can unmap the range being migrated).
+ *
+ * Once the caller is done cleaning up things and updating its page table (if it
+ * chose to do so, this is not an obligation) it finally calls
+ * migrate_vma_finalize() to update the CPU page table to point to new pages
+ * for successfully migrated pages or otherwise restore the CPU page table to
+ * point to the original source pages.
+ */
+int migrate_vma_setup(struct migrate_vma *args)
+{
+	long nr_pages = (args->end - args->start) >> PAGE_SHIFT;
+
+	args->start &= PAGE_MASK;
+	args->end &= PAGE_MASK;
+	if (!args->vma || is_vm_hugetlb_page(args->vma) ||
+	    (args->vma->vm_flags & VM_SPECIAL) || vma_is_dax(args->vma))
+		return -EINVAL;
+	if (nr_pages <= 0)
+		return -EINVAL;
+	if (args->start < args->vma->vm_start ||
+	    args->start >= args->vma->vm_end)
+		return -EINVAL;
+	if (args->end <= args->vma->vm_start || args->end > args->vma->vm_end)
+		return -EINVAL;
+	if (!args->src || !args->dst)
+		return -EINVAL;
+
+	memset(args->src, 0, sizeof(*args->src) * nr_pages);
+	args->cpages = 0;
+	args->npages = 0;
+
+	migrate_vma_collect(args);
+
+	if (args->cpages)
+		migrate_vma_prepare(args);
+	if (args->cpages)
+		migrate_vma_unmap(args);
+
+	/*
+	 * At this point pages are locked and unmapped, and thus they have
+	 * stable content and can safely be copied to destination memory that
+	 * is allocated by the drivers.
+	 */
+	return 0;
+
+}
+EXPORT_SYMBOL(migrate_vma_setup);
+
 static void migrate_vma_insert_page(struct migrate_vma *migrate,
 				    unsigned long addr,
 				    struct page *page,
@@ -2708,7 +2802,7 @@ abort:
 	*src &= ~MIGRATE_PFN_MIGRATE;
 }
 
-/*
+/**
  * migrate_vma_pages() - migrate meta-data from src page to dst page
  * @migrate: migrate struct containing all migration information
  *
@@ -2716,7 +2810,7 @@ abort:
  * struct page. This effectively finishes the migration from source page to the
  * destination page.
  */
-static void migrate_vma_pages(struct migrate_vma *migrate)
+void migrate_vma_pages(struct migrate_vma *migrate)
 {
 	const unsigned long npages = migrate->npages;
 	const unsigned long start = migrate->start;
@@ -2790,8 +2884,9 @@ static void migrate_vma_pages(struct migrate_vma *migrate)
 	if (notified)
 		mmu_notifier_invalidate_range_only_end(&range);
 }
+EXPORT_SYMBOL(migrate_vma_pages);
 
-/*
+/**
  * migrate_vma_finalize() - restore CPU page table entry
  * @migrate: migrate struct containing all migration information
  *
@@ -2802,7 +2897,7 @@ static void migrate_vma_pages(struct migrate_vma *migrate)
  * This also unlocks the pages and puts them back on the lru, or drops the extra
  * refcount, for device pages.
  */
-static void migrate_vma_finalize(struct migrate_vma *migrate)
+void migrate_vma_finalize(struct migrate_vma *migrate)
 {
 	const unsigned long npages = migrate->npages;
 	unsigned long i;
@@ -2845,124 +2940,5 @@ static void migrate_vma_finalize(struct migrate_vma *migrate)
 		}
 	}
 }
-
-/*
- * migrate_vma() - migrate a range of memory inside vma
- *
- * @ops: migration callback for allocating destination memory and copying
- * @vma: virtual memory area containing the range to be migrated
- * @start: start address of the range to migrate (inclusive)
- * @end: end address of the range to migrate (exclusive)
- * @src: array of hmm_pfn_t containing source pfns
- * @dst: array of hmm_pfn_t containing destination pfns
- * @private: pointer passed back to each of the callback
- * Returns: 0 on success, error code otherwise
- *
- * This function tries to migrate a range of memory virtual address range, using
- * callbacks to allocate and copy memory from source to destination. First it
- * collects all the pages backing each virtual address in the range, saving this
- * inside the src array. Then it locks those pages and unmaps them. Once the pages
- * are locked and unmapped, it checks whether each page is pinned or not. Pages
- * that aren't pinned have the MIGRATE_PFN_MIGRATE flag set (by this function)
- * in the corresponding src array entry. It then restores any pages that are
- * pinned, by remapping and unlocking those pages.
- *
- * At this point it calls the alloc_and_copy() callback. For documentation on
- * what is expected from that callback, see struct migrate_vma_ops comments in
- * include/linux/migrate.h
- *
- * After the alloc_and_copy() callback, this function goes over each entry in
- * the src array that has the MIGRATE_PFN_VALID and MIGRATE_PFN_MIGRATE flag
- * set. If the corresponding entry in dst array has MIGRATE_PFN_VALID flag set,
- * then the function tries to migrate struct page information from the source
- * struct page to the destination struct page. If it fails to migrate the struct
- * page information, then it clears the MIGRATE_PFN_MIGRATE flag in the src
- * array.
- *
- * At this point all successfully migrated pages have an entry in the src
- * array with MIGRATE_PFN_VALID and MIGRATE_PFN_MIGRATE flag set and the dst
- * array entry with MIGRATE_PFN_VALID flag set.
- *
- * It then calls the finalize_and_map() callback. See comments for "struct
- * migrate_vma_ops", in include/linux/migrate.h for details about
- * finalize_and_map() behavior.
- *
- * After the finalize_and_map() callback, for successfully migrated pages, this
- * function updates the CPU page table to point to new pages, otherwise it
- * restores the CPU page table to point to the original source pages.
- *
- * Function returns 0 after the above steps, even if no pages were migrated
- * (The function only returns an error if any of the arguments are invalid.)
- *
- * Both src and dst array must be big enough for (end - start) >> PAGE_SHIFT
- * unsigned long entries.
- */
-int migrate_vma(const struct migrate_vma_ops *ops,
-		struct vm_area_struct *vma,
-		unsigned long start,
-		unsigned long end,
-		unsigned long *src,
-		unsigned long *dst,
-		void *private)
-{
-	struct migrate_vma migrate;
-
-	/* Sanity check the arguments */
-	start &= PAGE_MASK;
-	end &= PAGE_MASK;
-	if (!vma || is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
-			vma_is_dax(vma))
-		return -EINVAL;
-	if (start < vma->vm_start || start >= vma->vm_end)
-		return -EINVAL;
-	if (end <= vma->vm_start || end > vma->vm_end)
-		return -EINVAL;
-	if (!ops || !src || !dst || start >= end)
-		return -EINVAL;
-
-	memset(src, 0, sizeof(*src) * ((end - start) >> PAGE_SHIFT));
-	migrate.src = src;
-	migrate.dst = dst;
-	migrate.start = start;
-	migrate.npages = 0;
-	migrate.cpages = 0;
-	migrate.end = end;
-	migrate.vma = vma;
-
-	/* Collect, and try to unmap source pages */
-	migrate_vma_collect(&migrate);
-	if (!migrate.cpages)
-		return 0;
-
-	/* Lock and isolate page */
-	migrate_vma_prepare(&migrate);
-	if (!migrate.cpages)
-		return 0;
-
-	/* Unmap pages */
-	migrate_vma_unmap(&migrate);
-	if (!migrate.cpages)
-		return 0;
-
-	/*
-	 * At this point pages are locked and unmapped, and thus they have
-	 * stable content and can safely be copied to destination memory that
-	 * is allocated by the callback.
-	 *
-	 * Note that migration can fail in migrate_vma_struct_page() for each
-	 * individual page.
-	 */
-	ops->alloc_and_copy(vma, src, dst, start, end, private);
-
-	/* This does the real migration of struct page */
-	migrate_vma_pages(&migrate);
-
-	ops->finalize_and_map(vma, src, dst, start, end, private);
-
-	/* Unlock and remap pages */
-	migrate_vma_finalize(&migrate);
-
-	return 0;
-}
-EXPORT_SYMBOL(migrate_vma);
-#endif /* defined(MIGRATE_VMA_HELPER) */
+EXPORT_SYMBOL(migrate_vma_finalize);
+#endif /* CONFIG_DEVICE_PRIVATE */
diff --git a/mm/mincore.c b/mm/mincore.c
index 4fe91d497436..f9a9dbe8cd33 100644
--- a/mm/mincore.c
+++ b/mm/mincore.c
@@ -10,7 +10,7 @@
  */
 #include <linux/pagemap.h>
 #include <linux/gfp.h>
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/mman.h>
 #include <linux/syscalls.h>
 #include <linux/swap.h>
@@ -193,6 +193,12 @@ static inline bool can_do_mincore(struct vm_area_struct *vma)
 		inode_permission(file_inode(vma->vm_file), MAY_WRITE) == 0;
 }
 
+static const struct mm_walk_ops mincore_walk_ops = {
+	.pmd_entry		= mincore_pte_range,
+	.pte_hole		= mincore_unmapped_range,
+	.hugetlb_entry		= mincore_hugetlb,
+};
+
 /*
  * Do a chunk of "sys_mincore()". We've already checked
  * all the arguments, we hold the mmap semaphore: we should
@@ -203,12 +209,6 @@ static long do_mincore(unsigned long addr, unsigned long pages, unsigned char *v
 	struct vm_area_struct *vma;
 	unsigned long end;
 	int err;
-	struct mm_walk mincore_walk = {
-		.pmd_entry = mincore_pte_range,
-		.pte_hole = mincore_unmapped_range,
-		.hugetlb_entry = mincore_hugetlb,
-		.private = vec,
-	};
 
 	vma = find_vma(current->mm, addr);
 	if (!vma || addr < vma->vm_start)
@@ -219,8 +219,7 @@ static long do_mincore(unsigned long addr, unsigned long pages, unsigned char *v
 		memset(vec, 1, pages);
 		return pages;
 	}
-	mincore_walk.mm = vma->vm_mm;
-	err = walk_page_range(addr, end, &mincore_walk);
+	err = walk_page_range(vma->vm_mm, addr, end, &mincore_walk_ops, vec);
 	if (err < 0)
 		return err;
 	return (end - addr) >> PAGE_SHIFT;
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index b5670620aea0..7fde88695f35 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -21,17 +21,11 @@
 /* global SRCU for all MMs */
 DEFINE_STATIC_SRCU(srcu);
 
-/*
- * This function allows mmu_notifier::release callback to delay a call to
- * a function that will free appropriate resources. The function must be
- * quick and must not block.
- */
-void mmu_notifier_call_srcu(struct rcu_head *rcu,
-			    void (*func)(struct rcu_head *rcu))
-{
-	call_srcu(&srcu, rcu, func);
-}
-EXPORT_SYMBOL_GPL(mmu_notifier_call_srcu);
+#ifdef CONFIG_LOCKDEP
+struct lockdep_map __mmu_notifier_invalidate_range_start_map = {
+	.name = "mmu_notifier_invalidate_range_start"
+};
+#endif
 
 /*
  * This function can't run concurrently against mmu_notifier_register
@@ -174,11 +168,19 @@ int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range)
 	id = srcu_read_lock(&srcu);
 	hlist_for_each_entry_rcu(mn, &range->mm->mmu_notifier_mm->list, hlist) {
 		if (mn->ops->invalidate_range_start) {
-			int _ret = mn->ops->invalidate_range_start(mn, range);
+			int _ret;
+
+			if (!mmu_notifier_range_blockable(range))
+				non_block_start();
+			_ret = mn->ops->invalidate_range_start(mn, range);
+			if (!mmu_notifier_range_blockable(range))
+				non_block_end();
 			if (_ret) {
 				pr_info("%pS callback failed with %d in %sblockable context.\n",
 					mn->ops->invalidate_range_start, _ret,
 					!mmu_notifier_range_blockable(range) ? "non-" : "");
+				WARN_ON(mmu_notifier_range_blockable(range) ||
+					ret != -EAGAIN);
 				ret = _ret;
 			}
 		}
@@ -187,7 +189,6 @@ int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range)
 
 	return ret;
 }
-EXPORT_SYMBOL_GPL(__mmu_notifier_invalidate_range_start);
 
 void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
 					 bool only_end)
@@ -195,6 +196,7 @@ void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
 	struct mmu_notifier *mn;
 	int id;
 
+	lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
 	id = srcu_read_lock(&srcu);
 	hlist_for_each_entry_rcu(mn, &range->mm->mmu_notifier_mm->list, hlist) {
 		/*
@@ -214,12 +216,17 @@ void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
 			mn->ops->invalidate_range(mn, range->mm,
 						  range->start,
 						  range->end);
-		if (mn->ops->invalidate_range_end)
+		if (mn->ops->invalidate_range_end) {
+			if (!mmu_notifier_range_blockable(range))
+				non_block_start();
 			mn->ops->invalidate_range_end(mn, range);
+			if (!mmu_notifier_range_blockable(range))
+				non_block_end();
+		}
 	}
 	srcu_read_unlock(&srcu, id);
+	lock_map_release(&__mmu_notifier_invalidate_range_start_map);
 }
-EXPORT_SYMBOL_GPL(__mmu_notifier_invalidate_range_end);
 
 void __mmu_notifier_invalidate_range(struct mm_struct *mm,
 				  unsigned long start, unsigned long end)
@@ -234,35 +241,49 @@ void __mmu_notifier_invalidate_range(struct mm_struct *mm,
 	}
 	srcu_read_unlock(&srcu, id);
 }
-EXPORT_SYMBOL_GPL(__mmu_notifier_invalidate_range);
 
-static int do_mmu_notifier_register(struct mmu_notifier *mn,
-				    struct mm_struct *mm,
-				    int take_mmap_sem)
+/*
+ * Same as mmu_notifier_register but here the caller must hold the
+ * mmap_sem in write mode.
+ */
+int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-	struct mmu_notifier_mm *mmu_notifier_mm;
+	struct mmu_notifier_mm *mmu_notifier_mm = NULL;
 	int ret;
 
+	lockdep_assert_held_write(&mm->mmap_sem);
 	BUG_ON(atomic_read(&mm->mm_users) <= 0);
 
-	ret = -ENOMEM;
-	mmu_notifier_mm = kmalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
-	if (unlikely(!mmu_notifier_mm))
-		goto out;
+	if (IS_ENABLED(CONFIG_LOCKDEP)) {
+		fs_reclaim_acquire(GFP_KERNEL);
+		lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
+		lock_map_release(&__mmu_notifier_invalidate_range_start_map);
+		fs_reclaim_release(GFP_KERNEL);
+	}
 
-	if (take_mmap_sem)
-		down_write(&mm->mmap_sem);
-	ret = mm_take_all_locks(mm);
-	if (unlikely(ret))
-		goto out_clean;
+	mn->mm = mm;
+	mn->users = 1;
+
+	if (!mm->mmu_notifier_mm) {
+		/*
+		 * kmalloc cannot be called under mm_take_all_locks(), but we
+		 * know that mm->mmu_notifier_mm can't change while we hold
+		 * the write side of the mmap_sem.
+		 */
+		mmu_notifier_mm =
+			kmalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
+		if (!mmu_notifier_mm)
+			return -ENOMEM;
 
-	if (!mm_has_notifiers(mm)) {
 		INIT_HLIST_HEAD(&mmu_notifier_mm->list);
 		spin_lock_init(&mmu_notifier_mm->lock);
-
-		mm->mmu_notifier_mm = mmu_notifier_mm;
-		mmu_notifier_mm = NULL;
 	}
+
+	ret = mm_take_all_locks(mm);
+	if (unlikely(ret))
+		goto out_clean;
+
+	/* Pairs with the mmdrop in mmu_notifier_unregister_* */
 	mmgrab(mm);
 
 	/*
@@ -273,48 +294,118 @@ static int do_mmu_notifier_register(struct mmu_notifier *mn,
 	 * We can't race against any other mmu notifier method either
 	 * thanks to mm_take_all_locks().
 	 */
+	if (mmu_notifier_mm)
+		mm->mmu_notifier_mm = mmu_notifier_mm;
+
 	spin_lock(&mm->mmu_notifier_mm->lock);
 	hlist_add_head_rcu(&mn->hlist, &mm->mmu_notifier_mm->list);
 	spin_unlock(&mm->mmu_notifier_mm->lock);
 
 	mm_drop_all_locks(mm);
+	BUG_ON(atomic_read(&mm->mm_users) <= 0);
+	return 0;
+
 out_clean:
-	if (take_mmap_sem)
-		up_write(&mm->mmap_sem);
 	kfree(mmu_notifier_mm);
-out:
-	BUG_ON(atomic_read(&mm->mm_users) <= 0);
 	return ret;
 }
+EXPORT_SYMBOL_GPL(__mmu_notifier_register);
 
-/*
+/**
+ * mmu_notifier_register - Register a notifier on a mm
+ * @mn: The notifier to attach
+ * @mm: The mm to attach the notifier to
+ *
  * Must not hold mmap_sem nor any other VM related lock when calling
  * this registration function. Must also ensure mm_users can't go down
  * to zero while this runs to avoid races with mmu_notifier_release,
  * so mm has to be current->mm or the mm should be pinned safely such
  * as with get_task_mm(). If the mm is not current->mm, the mm_users
  * pin should be released by calling mmput after mmu_notifier_register
- * returns. mmu_notifier_unregister must be always called to
- * unregister the notifier. mm_count is automatically pinned to allow
- * mmu_notifier_unregister to safely run at any time later, before or
- * after exit_mmap. ->release will always be called before exit_mmap
- * frees the pages.
+ * returns.
+ *
+ * mmu_notifier_unregister() or mmu_notifier_put() must be always called to
+ * unregister the notifier.
+ *
+ * While the caller has a mmu_notifier get the mn->mm pointer will remain
+ * valid, and can be converted to an active mm pointer via mmget_not_zero().
  */
 int mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-	return do_mmu_notifier_register(mn, mm, 1);
+	int ret;
+
+	down_write(&mm->mmap_sem);
+	ret = __mmu_notifier_register(mn, mm);
+	up_write(&mm->mmap_sem);
+	return ret;
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_register);
 
-/*
- * Same as mmu_notifier_register but here the caller must hold the
- * mmap_sem in write mode.
+static struct mmu_notifier *
+find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops)
+{
+	struct mmu_notifier *mn;
+
+	spin_lock(&mm->mmu_notifier_mm->lock);
+	hlist_for_each_entry_rcu (mn, &mm->mmu_notifier_mm->list, hlist) {
+		if (mn->ops != ops)
+			continue;
+
+		if (likely(mn->users != UINT_MAX))
+			mn->users++;
+		else
+			mn = ERR_PTR(-EOVERFLOW);
+		spin_unlock(&mm->mmu_notifier_mm->lock);
+		return mn;
+	}
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+	return NULL;
+}
+
+/**
+ * mmu_notifier_get_locked - Return the single struct mmu_notifier for
+ *                           the mm & ops
+ * @ops: The operations struct being subscribe with
+ * @mm : The mm to attach notifiers too
+ *
+ * This function either allocates a new mmu_notifier via
+ * ops->alloc_notifier(), or returns an already existing notifier on the
+ * list. The value of the ops pointer is used to determine when two notifiers
+ * are the same.
+ *
+ * Each call to mmu_notifier_get() must be paired with a call to
+ * mmu_notifier_put(). The caller must hold the write side of mm->mmap_sem.
+ *
+ * While the caller has a mmu_notifier get the mm pointer will remain valid,
+ * and can be converted to an active mm pointer via mmget_not_zero().
  */
-int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
+struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
+					     struct mm_struct *mm)
 {
-	return do_mmu_notifier_register(mn, mm, 0);
+	struct mmu_notifier *mn;
+	int ret;
+
+	lockdep_assert_held_write(&mm->mmap_sem);
+
+	if (mm->mmu_notifier_mm) {
+		mn = find_get_mmu_notifier(mm, ops);
+		if (mn)
+			return mn;
+	}
+
+	mn = ops->alloc_notifier(mm);
+	if (IS_ERR(mn))
+		return mn;
+	mn->ops = ops;
+	ret = __mmu_notifier_register(mn, mm);
+	if (ret)
+		goto out_free;
+	return mn;
+out_free:
+	mn->ops->free_notifier(mn);
+	return ERR_PTR(ret);
 }
-EXPORT_SYMBOL_GPL(__mmu_notifier_register);
+EXPORT_SYMBOL_GPL(mmu_notifier_get_locked);
 
 /* this is called after the last mmu_notifier_unregister() returned */
 void __mmu_notifier_mm_destroy(struct mm_struct *mm)
@@ -375,24 +466,74 @@ void mmu_notifier_unregister(struct mmu_notifier *mn, struct mm_struct *mm)
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_unregister);
 
-/*
- * Same as mmu_notifier_unregister but no callback and no srcu synchronization.
+static void mmu_notifier_free_rcu(struct rcu_head *rcu)
+{
+	struct mmu_notifier *mn = container_of(rcu, struct mmu_notifier, rcu);
+	struct mm_struct *mm = mn->mm;
+
+	mn->ops->free_notifier(mn);
+	/* Pairs with the get in __mmu_notifier_register() */
+	mmdrop(mm);
+}
+
+/**
+ * mmu_notifier_put - Release the reference on the notifier
+ * @mn: The notifier to act on
+ *
+ * This function must be paired with each mmu_notifier_get(), it releases the
+ * reference obtained by the get. If this is the last reference then process
+ * to free the notifier will be run asynchronously.
+ *
+ * Unlike mmu_notifier_unregister() the get/put flow only calls ops->release
+ * when the mm_struct is destroyed. Instead free_notifier is always called to
+ * release any resources held by the user.
+ *
+ * As ops->release is not guaranteed to be called, the user must ensure that
+ * all sptes are dropped, and no new sptes can be established before
+ * mmu_notifier_put() is called.
+ *
+ * This function can be called from the ops->release callback, however the
+ * caller must still ensure it is called pairwise with mmu_notifier_get().
+ *
+ * Modules calling this function must call mmu_notifier_synchronize() in
+ * their __exit functions to ensure the async work is completed.
  */
-void mmu_notifier_unregister_no_release(struct mmu_notifier *mn,
-					struct mm_struct *mm)
+void mmu_notifier_put(struct mmu_notifier *mn)
 {
+	struct mm_struct *mm = mn->mm;
+
 	spin_lock(&mm->mmu_notifier_mm->lock);
-	/*
-	 * Can not use list_del_rcu() since __mmu_notifier_release
-	 * can delete it before we hold the lock.
-	 */
+	if (WARN_ON(!mn->users) || --mn->users)
+		goto out_unlock;
 	hlist_del_init_rcu(&mn->hlist);
 	spin_unlock(&mm->mmu_notifier_mm->lock);
 
-	BUG_ON(atomic_read(&mm->mm_count) <= 0);
-	mmdrop(mm);
+	call_srcu(&srcu, &mn->rcu, mmu_notifier_free_rcu);
+	return;
+
+out_unlock:
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+}
+EXPORT_SYMBOL_GPL(mmu_notifier_put);
+
+/**
+ * mmu_notifier_synchronize - Ensure all mmu_notifiers are freed
+ *
+ * This function ensures that all outstanding async SRU work from
+ * mmu_notifier_put() is completed. After it returns any mmu_notifier_ops
+ * associated with an unused mmu_notifier will no longer be called.
+ *
+ * Before using the caller must ensure that all of its mmu_notifiers have been
+ * fully released via mmu_notifier_put().
+ *
+ * Modules using the mmu_notifier_put() API should call this in their __exit
+ * function to avoid module unloading races.
+ */
+void mmu_notifier_synchronize(void)
+{
+	synchronize_srcu(&srcu);
 }
-EXPORT_SYMBOL_GPL(mmu_notifier_unregister_no_release);
+EXPORT_SYMBOL_GPL(mmu_notifier_synchronize);
 
 bool
 mmu_notifier_range_update_to_read_only(const struct mmu_notifier_range *range)
diff --git a/mm/mprotect.c b/mm/mprotect.c
index bf38dfbbb4b4..675e5d34a507 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -9,7 +9,7 @@
  *  (C) Copyright 2002 Red Hat Inc, All Rights Reserved
  */
 
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/hugetlb.h>
 #include <linux/shm.h>
 #include <linux/mman.h>
@@ -329,20 +329,11 @@ static int prot_none_test(unsigned long addr, unsigned long next,
 	return 0;
 }
 
-static int prot_none_walk(struct vm_area_struct *vma, unsigned long start,
-			   unsigned long end, unsigned long newflags)
-{
-	pgprot_t new_pgprot = vm_get_page_prot(newflags);
-	struct mm_walk prot_none_walk = {
-		.pte_entry = prot_none_pte_entry,
-		.hugetlb_entry = prot_none_hugetlb_entry,
-		.test_walk = prot_none_test,
-		.mm = current->mm,
-		.private = &new_pgprot,
-	};
-
-	return walk_page_range(start, end, &prot_none_walk);
-}
+static const struct mm_walk_ops prot_none_walk_ops = {
+	.pte_entry		= prot_none_pte_entry,
+	.hugetlb_entry		= prot_none_hugetlb_entry,
+	.test_walk		= prot_none_test,
+};
 
 int
 mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
@@ -369,7 +360,10 @@ mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
 	if (arch_has_pfn_modify_check() &&
 	    (vma->vm_flags & (VM_PFNMAP|VM_MIXEDMAP)) &&
 	    (newflags & (VM_READ|VM_WRITE|VM_EXEC)) == 0) {
-		error = prot_none_walk(vma, start, end, newflags);
+		pgprot_t new_pgprot = vm_get_page_prot(newflags);
+
+		error = walk_page_range(current->mm, start, end,
+				&prot_none_walk_ops, &new_pgprot);
 		if (error)
 			return error;
 	}
diff --git a/mm/page_alloc.c b/mm/page_alloc.c
index 6991ccec9c32..ff5484fdbdf9 100644
--- a/mm/page_alloc.c
+++ b/mm/page_alloc.c
@@ -5971,7 +5971,7 @@ void __ref memmap_init_zone_device(struct zone *zone,
 		}
 	}
 
-	pr_info("%s initialised, %lu pages in %ums\n", dev_name(pgmap->dev),
+	pr_info("%s initialised %lu pages in %ums\n", __func__,
 		size, jiffies_to_msecs(jiffies - start));
 }
 
diff --git a/mm/pagewalk.c b/mm/pagewalk.c
index c3084ff2569d..d48c2a986ea3 100644
--- a/mm/pagewalk.c
+++ b/mm/pagewalk.c
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: GPL-2.0
-#include <linux/mm.h>
+#include <linux/pagewalk.h>
 #include <linux/highmem.h>
 #include <linux/sched.h>
 #include <linux/hugetlb.h>
@@ -9,10 +9,11 @@ static int walk_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
 {
 	pte_t *pte;
 	int err = 0;
+	const struct mm_walk_ops *ops = walk->ops;
 
 	pte = pte_offset_map(pmd, addr);
 	for (;;) {
-		err = walk->pte_entry(pte, addr, addr + PAGE_SIZE, walk);
+		err = ops->pte_entry(pte, addr, addr + PAGE_SIZE, walk);
 		if (err)
 		       break;
 		addr += PAGE_SIZE;
@@ -30,6 +31,7 @@ static int walk_pmd_range(pud_t *pud, unsigned long addr, unsigned long end,
 {
 	pmd_t *pmd;
 	unsigned long next;
+	const struct mm_walk_ops *ops = walk->ops;
 	int err = 0;
 
 	pmd = pmd_offset(pud, addr);
@@ -37,8 +39,8 @@ static int walk_pmd_range(pud_t *pud, unsigned long addr, unsigned long end,
 again:
 		next = pmd_addr_end(addr, end);
 		if (pmd_none(*pmd) || !walk->vma) {
-			if (walk->pte_hole)
-				err = walk->pte_hole(addr, next, walk);
+			if (ops->pte_hole)
+				err = ops->pte_hole(addr, next, walk);
 			if (err)
 				break;
 			continue;
@@ -47,8 +49,8 @@ again:
 		 * This implies that each ->pmd_entry() handler
 		 * needs to know about pmd_trans_huge() pmds
 		 */
-		if (walk->pmd_entry)
-			err = walk->pmd_entry(pmd, addr, next, walk);
+		if (ops->pmd_entry)
+			err = ops->pmd_entry(pmd, addr, next, walk);
 		if (err)
 			break;
 
@@ -56,7 +58,7 @@ again:
 		 * Check this here so we only break down trans_huge
 		 * pages when we _need_ to
 		 */
-		if (!walk->pte_entry)
+		if (!ops->pte_entry)
 			continue;
 
 		split_huge_pmd(walk->vma, pmd, addr);
@@ -75,6 +77,7 @@ static int walk_pud_range(p4d_t *p4d, unsigned long addr, unsigned long end,
 {
 	pud_t *pud;
 	unsigned long next;
+	const struct mm_walk_ops *ops = walk->ops;
 	int err = 0;
 
 	pud = pud_offset(p4d, addr);
@@ -82,18 +85,18 @@ static int walk_pud_range(p4d_t *p4d, unsigned long addr, unsigned long end,
  again:
 		next = pud_addr_end(addr, end);
 		if (pud_none(*pud) || !walk->vma) {
-			if (walk->pte_hole)
-				err = walk->pte_hole(addr, next, walk);
+			if (ops->pte_hole)
+				err = ops->pte_hole(addr, next, walk);
 			if (err)
 				break;
 			continue;
 		}
 
-		if (walk->pud_entry) {
+		if (ops->pud_entry) {
 			spinlock_t *ptl = pud_trans_huge_lock(pud, walk->vma);
 
 			if (ptl) {
-				err = walk->pud_entry(pud, addr, next, walk);
+				err = ops->pud_entry(pud, addr, next, walk);
 				spin_unlock(ptl);
 				if (err)
 					break;
@@ -105,7 +108,7 @@ static int walk_pud_range(p4d_t *p4d, unsigned long addr, unsigned long end,
 		if (pud_none(*pud))
 			goto again;
 
-		if (walk->pmd_entry || walk->pte_entry)
+		if (ops->pmd_entry || ops->pte_entry)
 			err = walk_pmd_range(pud, addr, next, walk);
 		if (err)
 			break;
@@ -119,19 +122,20 @@ static int walk_p4d_range(pgd_t *pgd, unsigned long addr, unsigned long end,
 {
 	p4d_t *p4d;
 	unsigned long next;
+	const struct mm_walk_ops *ops = walk->ops;
 	int err = 0;
 
 	p4d = p4d_offset(pgd, addr);
 	do {
 		next = p4d_addr_end(addr, end);
 		if (p4d_none_or_clear_bad(p4d)) {
-			if (walk->pte_hole)
-				err = walk->pte_hole(addr, next, walk);
+			if (ops->pte_hole)
+				err = ops->pte_hole(addr, next, walk);
 			if (err)
 				break;
 			continue;
 		}
-		if (walk->pmd_entry || walk->pte_entry)
+		if (ops->pmd_entry || ops->pte_entry)
 			err = walk_pud_range(p4d, addr, next, walk);
 		if (err)
 			break;
@@ -145,19 +149,20 @@ static int walk_pgd_range(unsigned long addr, unsigned long end,
 {
 	pgd_t *pgd;
 	unsigned long next;
+	const struct mm_walk_ops *ops = walk->ops;
 	int err = 0;
 
 	pgd = pgd_offset(walk->mm, addr);
 	do {
 		next = pgd_addr_end(addr, end);
 		if (pgd_none_or_clear_bad(pgd)) {
-			if (walk->pte_hole)
-				err = walk->pte_hole(addr, next, walk);
+			if (ops->pte_hole)
+				err = ops->pte_hole(addr, next, walk);
 			if (err)
 				break;
 			continue;
 		}
-		if (walk->pmd_entry || walk->pte_entry)
+		if (ops->pmd_entry || ops->pte_entry)
 			err = walk_p4d_range(pgd, addr, next, walk);
 		if (err)
 			break;
@@ -183,6 +188,7 @@ static int walk_hugetlb_range(unsigned long addr, unsigned long end,
 	unsigned long hmask = huge_page_mask(h);
 	unsigned long sz = huge_page_size(h);
 	pte_t *pte;
+	const struct mm_walk_ops *ops = walk->ops;
 	int err = 0;
 
 	do {
@@ -190,9 +196,9 @@ static int walk_hugetlb_range(unsigned long addr, unsigned long end,
 		pte = huge_pte_offset(walk->mm, addr & hmask, sz);
 
 		if (pte)
-			err = walk->hugetlb_entry(pte, hmask, addr, next, walk);
-		else if (walk->pte_hole)
-			err = walk->pte_hole(addr, next, walk);
+			err = ops->hugetlb_entry(pte, hmask, addr, next, walk);
+		else if (ops->pte_hole)
+			err = ops->pte_hole(addr, next, walk);
 
 		if (err)
 			break;
@@ -220,9 +226,10 @@ static int walk_page_test(unsigned long start, unsigned long end,
 			struct mm_walk *walk)
 {
 	struct vm_area_struct *vma = walk->vma;
+	const struct mm_walk_ops *ops = walk->ops;
 
-	if (walk->test_walk)
-		return walk->test_walk(start, end, walk);
+	if (ops->test_walk)
+		return ops->test_walk(start, end, walk);
 
 	/*
 	 * vma(VM_PFNMAP) doesn't have any valid struct pages behind VM_PFNMAP
@@ -234,8 +241,8 @@ static int walk_page_test(unsigned long start, unsigned long end,
 	 */
 	if (vma->vm_flags & VM_PFNMAP) {
 		int err = 1;
-		if (walk->pte_hole)
-			err = walk->pte_hole(start, end, walk);
+		if (ops->pte_hole)
+			err = ops->pte_hole(start, end, walk);
 		return err ? err : 1;
 	}
 	return 0;
@@ -248,7 +255,7 @@ static int __walk_page_range(unsigned long start, unsigned long end,
 	struct vm_area_struct *vma = walk->vma;
 
 	if (vma && is_vm_hugetlb_page(vma)) {
-		if (walk->hugetlb_entry)
+		if (walk->ops->hugetlb_entry)
 			err = walk_hugetlb_range(start, end, walk);
 	} else
 		err = walk_pgd_range(start, end, walk);
@@ -258,11 +265,13 @@ static int __walk_page_range(unsigned long start, unsigned long end,
 
 /**
  * walk_page_range - walk page table with caller specific callbacks
- * @start: start address of the virtual address range
- * @end: end address of the virtual address range
- * @walk: mm_walk structure defining the callbacks and the target address space
+ * @mm:		mm_struct representing the target process of page table walk
+ * @start:	start address of the virtual address range
+ * @end:	end address of the virtual address range
+ * @ops:	operation to call during the walk
+ * @private:	private data for callbacks' usage
  *
- * Recursively walk the page table tree of the process represented by @walk->mm
+ * Recursively walk the page table tree of the process represented by @mm
  * within the virtual address range [@start, @end). During walking, we can do
  * some caller-specific works for each entry, by setting up pmd_entry(),
  * pte_entry(), and/or hugetlb_entry(). If you don't set up for some of these
@@ -278,47 +287,52 @@ static int __walk_page_range(unsigned long start, unsigned long end,
  *
  * Before starting to walk page table, some callers want to check whether
  * they really want to walk over the current vma, typically by checking
- * its vm_flags. walk_page_test() and @walk->test_walk() are used for this
+ * its vm_flags. walk_page_test() and @ops->test_walk() are used for this
  * purpose.
  *
  * struct mm_walk keeps current values of some common data like vma and pmd,
  * which are useful for the access from callbacks. If you want to pass some
- * caller-specific data to callbacks, @walk->private should be helpful.
+ * caller-specific data to callbacks, @private should be helpful.
  *
  * Locking:
- *   Callers of walk_page_range() and walk_page_vma() should hold
- *   @walk->mm->mmap_sem, because these function traverse vma list and/or
- *   access to vma's data.
+ *   Callers of walk_page_range() and walk_page_vma() should hold @mm->mmap_sem,
+ *   because these function traverse vma list and/or access to vma's data.
  */
-int walk_page_range(unsigned long start, unsigned long end,
-		    struct mm_walk *walk)
+int walk_page_range(struct mm_struct *mm, unsigned long start,
+		unsigned long end, const struct mm_walk_ops *ops,
+		void *private)
 {
 	int err = 0;
 	unsigned long next;
 	struct vm_area_struct *vma;
+	struct mm_walk walk = {
+		.ops		= ops,
+		.mm		= mm,
+		.private	= private,
+	};
 
 	if (start >= end)
 		return -EINVAL;
 
-	if (!walk->mm)
+	if (!walk.mm)
 		return -EINVAL;
 
-	VM_BUG_ON_MM(!rwsem_is_locked(&walk->mm->mmap_sem), walk->mm);
+	lockdep_assert_held(&walk.mm->mmap_sem);
 
-	vma = find_vma(walk->mm, start);
+	vma = find_vma(walk.mm, start);
 	do {
 		if (!vma) { /* after the last vma */
-			walk->vma = NULL;
+			walk.vma = NULL;
 			next = end;
 		} else if (start < vma->vm_start) { /* outside vma */
-			walk->vma = NULL;
+			walk.vma = NULL;
 			next = min(end, vma->vm_start);
 		} else { /* inside vma */
-			walk->vma = vma;
+			walk.vma = vma;
 			next = min(end, vma->vm_end);
 			vma = vma->vm_next;
 
-			err = walk_page_test(start, next, walk);
+			err = walk_page_test(start, next, &walk);
 			if (err > 0) {
 				/*
 				 * positive return values are purely for
@@ -331,28 +345,34 @@ int walk_page_range(unsigned long start, unsigned long end,
 			if (err < 0)
 				break;
 		}
-		if (walk->vma || walk->pte_hole)
-			err = __walk_page_range(start, next, walk);
+		if (walk.vma || walk.ops->pte_hole)
+			err = __walk_page_range(start, next, &walk);
 		if (err)
 			break;
 	} while (start = next, start < end);
 	return err;
 }
 
-int walk_page_vma(struct vm_area_struct *vma, struct mm_walk *walk)
+int walk_page_vma(struct vm_area_struct *vma, const struct mm_walk_ops *ops,
+		void *private)
 {
+	struct mm_walk walk = {
+		.ops		= ops,
+		.mm		= vma->vm_mm,
+		.vma		= vma,
+		.private	= private,
+	};
 	int err;
 
-	if (!walk->mm)
+	if (!walk.mm)
 		return -EINVAL;
 
-	VM_BUG_ON(!rwsem_is_locked(&walk->mm->mmap_sem));
-	VM_BUG_ON(!vma);
-	walk->vma = vma;
-	err = walk_page_test(vma->vm_start, vma->vm_end, walk);
+	lockdep_assert_held(&walk.mm->mmap_sem);
+
+	err = walk_page_test(vma->vm_start, vma->vm_end, &walk);
 	if (err > 0)
 		return 0;
 	if (err < 0)
 		return err;
-	return __walk_page_range(vma->vm_start, vma->vm_end, walk);
+	return __walk_page_range(vma->vm_start, vma->vm_end, &walk);
 }
diff --git a/tools/testing/nvdimm/test/iomap.c b/tools/testing/nvdimm/test/iomap.c
index cd040b5abffe..3f55f2f99112 100644
--- a/tools/testing/nvdimm/test/iomap.c
+++ b/tools/testing/nvdimm/test/iomap.c
@@ -132,7 +132,6 @@ void *__wrap_devm_memremap_pages(struct device *dev, struct dev_pagemap *pgmap)
 	if (!nfit_res)
 		return devm_memremap_pages(dev, pgmap);
 
-	pgmap->dev = dev;
 	if (!pgmap->ref) {
 		if (pgmap->ops && (pgmap->ops->kill || pgmap->ops->cleanup))
 			return ERR_PTR(-EINVAL);