summary refs log tree commit diff
path: root/mm
diff options
context:
space:
mode:
Diffstat (limited to 'mm')
-rw-r--r--mm/hmm.c23
-rw-r--r--mm/mmu_notifier.c21
2 files changed, 28 insertions, 16 deletions
diff --git a/mm/hmm.c b/mm/hmm.c
index 361f3706962f..789587731217 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -189,35 +189,30 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 }
 
 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
-				      struct mm_struct *mm,
-				      unsigned long start,
-				      unsigned long end,
-				      bool blockable)
+			const struct mmu_notifier_range *range)
 {
 	struct hmm_update update;
-	struct hmm *hmm = mm->hmm;
+	struct hmm *hmm = range->mm->hmm;
 
 	VM_BUG_ON(!hmm);
 
-	update.start = start;
-	update.end = end;
+	update.start = range->start;
+	update.end = range->end;
 	update.event = HMM_UPDATE_INVALIDATE;
-	update.blockable = blockable;
+	update.blockable = range->blockable;
 	return hmm_invalidate_range(hmm, true, &update);
 }
 
 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
-				     struct mm_struct *mm,
-				     unsigned long start,
-				     unsigned long end)
+			const struct mmu_notifier_range *range)
 {
 	struct hmm_update update;
-	struct hmm *hmm = mm->hmm;
+	struct hmm *hmm = range->mm->hmm;
 
 	VM_BUG_ON(!hmm);
 
-	update.start = start;
-	update.end = end;
+	update.start = range->start;
+	update.end = range->end;
 	update.event = HMM_UPDATE_INVALIDATE;
 	update.blockable = true;
 	hmm_invalidate_range(hmm, false, &update);
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index 755466cd289a..74a7dc3d11c8 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -171,14 +171,20 @@ int __mmu_notifier_invalidate_range_start(struct mm_struct *mm,
 				  unsigned long start, unsigned long end,
 				  bool blockable)
 {
+	struct mmu_notifier_range _range, *range = &_range;
 	struct mmu_notifier *mn;
 	int ret = 0;
 	int id;
 
+	range->blockable = blockable;
+	range->start = start;
+	range->end = end;
+	range->mm = mm;
+
 	id = srcu_read_lock(&srcu);
 	hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) {
 		if (mn->ops->invalidate_range_start) {
-			int _ret = mn->ops->invalidate_range_start(mn, mm, start, end, blockable);
+			int _ret = mn->ops->invalidate_range_start(mn, range);
 			if (_ret) {
 				pr_info("%pS callback failed with %d in %sblockable context.\n",
 						mn->ops->invalidate_range_start, _ret,
@@ -198,9 +204,20 @@ void __mmu_notifier_invalidate_range_end(struct mm_struct *mm,
 					 unsigned long end,
 					 bool only_end)
 {
+	struct mmu_notifier_range _range, *range = &_range;
 	struct mmu_notifier *mn;
 	int id;
 
+	/*
+	 * The end call back will never be call if the start refused to go
+	 * through because of blockable was false so here assume that we
+	 * can block.
+	 */
+	range->blockable = true;
+	range->start = start;
+	range->end = end;
+	range->mm = mm;
+
 	id = srcu_read_lock(&srcu);
 	hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) {
 		/*
@@ -219,7 +236,7 @@ void __mmu_notifier_invalidate_range_end(struct mm_struct *mm,
 		if (!only_end && mn->ops->invalidate_range)
 			mn->ops->invalidate_range(mn, mm, start, end);
 		if (mn->ops->invalidate_range_end)
-			mn->ops->invalidate_range_end(mn, mm, start, end);
+			mn->ops->invalidate_range_end(mn, range);
 	}
 	srcu_read_unlock(&srcu, id);
 }