summary refs log tree commit diff
path: root/drivers/iommu/intel-svm.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu/intel-svm.c')
-rw-r--r--drivers/iommu/intel-svm.c33
1 files changed, 27 insertions, 6 deletions
diff --git a/drivers/iommu/intel-svm.c b/drivers/iommu/intel-svm.c
index 50464833d0b8..97a818992d6d 100644
--- a/drivers/iommu/intel-svm.c
+++ b/drivers/iommu/intel-svm.c
@@ -249,12 +249,30 @@ static void intel_flush_pasid_dev(struct intel_svm *svm, struct intel_svm_dev *s
 static void intel_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
 	struct intel_svm *svm = container_of(mn, struct intel_svm, notifier);
+	struct intel_svm_dev *sdev;
 
+	/* This might end up being called from exit_mmap(), *before* the page
+	 * tables are cleared. And __mmu_notifier_release() will delete us from
+	 * the list of notifiers so that our invalidate_range() callback doesn't
+	 * get called when the page tables are cleared. So we need to protect
+	 * against hardware accessing those page tables.
+	 *
+	 * We do it by clearing the entry in the PASID table and then flushing
+	 * the IOTLB and the PASID table caches. This might upset hardware;
+	 * perhaps we'll want to point the PASID to a dummy PGD (like the zero
+	 * page) so that we end up taking a fault that the hardware really
+	 * *has* to handle gracefully without affecting other processes.
+	 */
 	svm->iommu->pasid_table[svm->pasid].val = 0;
+	wmb();
+
+	rcu_read_lock();
+	list_for_each_entry_rcu(sdev, &svm->devs, list) {
+		intel_flush_pasid_dev(svm, sdev, svm->pasid);
+		intel_flush_svm_range_dev(svm, sdev, 0, -1, 0, !svm->mm);
+	}
+	rcu_read_unlock();
 
-	/* There's no need to do any flush because we can't get here if there
-	 * are any devices left anyway. */
-	WARN_ON(!list_empty(&svm->devs));
 }
 
 static const struct mmu_notifier_ops intel_mmuops = {
@@ -379,7 +397,6 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 				goto out;
 			}
 			iommu->pasid_table[svm->pasid].val = (u64)__pa(mm->pgd) | 1;
-			mm = NULL;
 		} else
 			iommu->pasid_table[svm->pasid].val = (u64)__pa(init_mm.pgd) | 1 | (1ULL << 11);
 		wmb();
@@ -442,11 +459,11 @@ int intel_svm_unbind_mm(struct device *dev, int pasid)
 				kfree_rcu(sdev, rcu);
 
 				if (list_empty(&svm->devs)) {
-					mmu_notifier_unregister(&svm->notifier, svm->mm);
 
 					idr_remove(&svm->iommu->pasid_idr, svm->pasid);
 					if (svm->mm)
-						mmput(svm->mm);
+						mmu_notifier_unregister(&svm->notifier, svm->mm);
+
 					/* We mandate that no page faults may be outstanding
 					 * for the PASID when intel_svm_unbind_mm() is called.
 					 * If that is not obeyed, subtle errors will happen.
@@ -551,6 +568,9 @@ static irqreturn_t prq_event_thread(int irq, void *d)
 		 * any faults on kernel addresses. */
 		if (!svm->mm)
 			goto bad_req;
+		/* If the mm is already defunct, don't handle faults. */
+		if (!atomic_inc_not_zero(&svm->mm->mm_users))
+			goto bad_req;
 		down_read(&svm->mm->mmap_sem);
 		vma = find_extend_vma(svm->mm, address);
 		if (!vma || address < vma->vm_start)
@@ -567,6 +587,7 @@ static irqreturn_t prq_event_thread(int irq, void *d)
 		result = QI_RESP_SUCCESS;
 	invalid:
 		up_read(&svm->mm->mmap_sem);
+		mmput(svm->mm);
 	bad_req:
 		/* Accounting for major/minor faults? */
 		rcu_read_lock();