summary refs log tree commit diff
path: root/drivers/iommu/intel-svm.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/iommu/intel-svm.c')
-rw-r--r--drivers/iommu/intel-svm.c43
1 files changed, 31 insertions, 12 deletions
diff --git a/drivers/iommu/intel-svm.c b/drivers/iommu/intel-svm.c
index 89d4d47d0ab3..817be769e94f 100644
--- a/drivers/iommu/intel-svm.c
+++ b/drivers/iommu/intel-svm.c
@@ -269,11 +269,10 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 	struct intel_iommu *iommu = intel_svm_device_to_iommu(dev);
 	struct intel_svm_dev *sdev;
 	struct intel_svm *svm = NULL;
+	struct mm_struct *mm = NULL;
 	int pasid_max;
 	int ret;
 
-	BUG_ON(pasid && !current->mm);
-
 	if (WARN_ON(!iommu))
 		return -EINVAL;
 
@@ -284,12 +283,20 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 	} else
 		pasid_max = 1 << 20;
 
+	if ((flags & SVM_FLAG_SUPERVISOR_MODE)) {
+		if (!ecap_srs(iommu->ecap))
+			return -EINVAL;
+	} else if (pasid) {
+		mm = get_task_mm(current);
+		BUG_ON(!mm);
+	}
+
 	mutex_lock(&pasid_mutex);
 	if (pasid && !(flags & SVM_FLAG_PRIVATE_PASID)) {
 		int i;
 
 		idr_for_each_entry(&iommu->pasid_idr, svm, i) {
-			if (svm->mm != current->mm ||
+			if (svm->mm != mm ||
 			    (svm->flags & SVM_FLAG_PRIVATE_PASID))
 				continue;
 
@@ -355,17 +362,22 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 		}
 		svm->pasid = ret;
 		svm->notifier.ops = &intel_mmuops;
-		svm->mm = get_task_mm(current);
+		svm->mm = mm;
 		svm->flags = flags;
 		INIT_LIST_HEAD_RCU(&svm->devs);
 		ret = -ENOMEM;
-		if (!svm->mm || (ret = mmu_notifier_register(&svm->notifier, svm->mm))) {
-			idr_remove(&svm->iommu->pasid_idr, svm->pasid);
-			kfree(svm);
-			kfree(sdev);
-			goto out;
-		}
-		iommu->pasid_table[svm->pasid].val = (u64)__pa(svm->mm->pgd) | 1;
+		if (mm) {
+			ret = mmu_notifier_register(&svm->notifier, mm);
+			if (ret) {
+				idr_remove(&svm->iommu->pasid_idr, svm->pasid);
+				kfree(svm);
+				kfree(sdev);
+				goto out;
+			}
+			iommu->pasid_table[svm->pasid].val = (u64)__pa(mm->pgd) | 1;
+			mm = NULL;
+		} else
+			iommu->pasid_table[svm->pasid].val = (u64)__pa(init_mm.pgd) | 1 | (1ULL << 11);
 		wmb();
 	}
 	list_add_rcu(&sdev->list, &svm->devs);
@@ -375,6 +387,8 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 	ret = 0;
  out:
 	mutex_unlock(&pasid_mutex);
+	if (mm)
+		mmput(mm);
 	return ret;
 }
 EXPORT_SYMBOL_GPL(intel_svm_bind_mm);
@@ -416,7 +430,8 @@ int intel_svm_unbind_mm(struct device *dev, int pasid)
 					mmu_notifier_unregister(&svm->notifier, svm->mm);
 
 					idr_remove(&svm->iommu->pasid_idr, svm->pasid);
-					mmput(svm->mm);
+					if (svm->mm)
+						mmput(svm->mm);
 					/* We mandate that no page faults may be outstanding
 					 * for the PASID when intel_svm_unbind_mm() is called.
 					 * If that is not obeyed, subtle errors will happen.
@@ -500,6 +515,10 @@ static irqreturn_t prq_event_thread(int irq, void *d)
 		}
 
 		result = QI_RESP_INVALID;
+		/* Since we're using init_mm.pgd directly, we should never take
+		 * any faults on kernel addresses. */
+		if (!svm->mm)
+			goto bad_req;
 		down_read(&svm->mm->mmap_sem);
 		vma = find_extend_vma(svm->mm, address);
 		if (!vma || address < vma->vm_start)