summary refs log tree commit diff
path: root/arch/x86/kvm/mmu/tdp_mmu.c
diff options
context:
space:
mode:
Diffstat (limited to 'arch/x86/kvm/mmu/tdp_mmu.c')
-rw-r--r--arch/x86/kvm/mmu/tdp_mmu.c740
1 files changed, 422 insertions, 318 deletions
diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index 34207b874886..88f69a6cc492 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -27,6 +27,15 @@ void kvm_mmu_init_tdp_mmu(struct kvm *kvm)
 	INIT_LIST_HEAD(&kvm->arch.tdp_mmu_pages);
 }
 
+static __always_inline void kvm_lockdep_assert_mmu_lock_held(struct kvm *kvm,
+							     bool shared)
+{
+	if (shared)
+		lockdep_assert_held_read(&kvm->mmu_lock);
+	else
+		lockdep_assert_held_write(&kvm->mmu_lock);
+}
+
 void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
 {
 	if (!kvm->arch.tdp_mmu_enabled)
@@ -41,32 +50,85 @@ void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
 	rcu_barrier();
 }
 
-static void tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root)
+static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+			  gfn_t start, gfn_t end, bool can_yield, bool flush,
+			  bool shared);
+
+static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
 {
-	if (kvm_mmu_put_root(kvm, root))
-		kvm_tdp_mmu_free_root(kvm, root);
+	free_page((unsigned long)sp->spt);
+	kmem_cache_free(mmu_page_header_cache, sp);
 }
 
-static inline bool tdp_mmu_next_root_valid(struct kvm *kvm,
-					   struct kvm_mmu_page *root)
+/*
+ * This is called through call_rcu in order to free TDP page table memory
+ * safely with respect to other kernel threads that may be operating on
+ * the memory.
+ * By only accessing TDP MMU page table memory in an RCU read critical
+ * section, and freeing it after a grace period, lockless access to that
+ * memory won't use it after it is freed.
+ */
+static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
 {
-	lockdep_assert_held_write(&kvm->mmu_lock);
+	struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
+					       rcu_head);
 
-	if (list_entry_is_head(root, &kvm->arch.tdp_mmu_roots, link))
-		return false;
+	tdp_mmu_free_sp(sp);
+}
 
-	kvm_mmu_get_root(kvm, root);
-	return true;
+void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
+			  bool shared)
+{
+	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+
+	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
 
+	if (!refcount_dec_and_test(&root->tdp_mmu_root_count))
+		return;
+
+	WARN_ON(!root->tdp_mmu_page);
+
+	spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+	list_del_rcu(&root->link);
+	spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
+
+	zap_gfn_range(kvm, root, 0, max_gfn, false, false, shared);
+
+	call_rcu(&root->rcu_head, tdp_mmu_free_sp_rcu_callback);
 }
 
-static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
-						     struct kvm_mmu_page *root)
+/*
+ * Finds the next valid root after root (or the first valid root if root
+ * is NULL), takes a reference on it, and returns that next root. If root
+ * is not NULL, this thread should have already taken a reference on it, and
+ * that reference will be dropped. If no valid root is found, this
+ * function will return NULL.
+ */
+static struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
+					      struct kvm_mmu_page *prev_root,
+					      bool shared)
 {
 	struct kvm_mmu_page *next_root;
 
-	next_root = list_next_entry(root, link);
-	tdp_mmu_put_root(kvm, root);
+	rcu_read_lock();
+
+	if (prev_root)
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &prev_root->link,
+						  typeof(*prev_root), link);
+	else
+		next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						   typeof(*next_root), link);
+
+	while (next_root && !kvm_tdp_mmu_get_root(kvm, next_root))
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+				&next_root->link, typeof(*next_root), link);
+
+	rcu_read_unlock();
+
+	if (prev_root)
+		kvm_tdp_mmu_put_root(kvm, prev_root, shared);
+
 	return next_root;
 }
 
@@ -75,35 +137,24 @@ static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
  * This makes it safe to release the MMU lock and yield within the loop, but
  * if exiting the loop early, the caller must drop the reference to the most
  * recent root. (Unless keeping a live reference is desirable.)
+ *
+ * If shared is set, this function is operating under the MMU lock in read
+ * mode. In the unlikely event that this thread must free a root, the lock
+ * will be temporarily dropped and reacquired in write mode.
  */
-#define for_each_tdp_mmu_root_yield_safe(_kvm, _root)				\
-	for (_root = list_first_entry(&_kvm->arch.tdp_mmu_roots,	\
-				      typeof(*_root), link);		\
-	     tdp_mmu_next_root_valid(_kvm, _root);			\
-	     _root = tdp_mmu_next_root(_kvm, _root))
-
-#define for_each_tdp_mmu_root(_kvm, _root)				\
-	list_for_each_entry(_root, &_kvm->arch.tdp_mmu_roots, link)
-
-static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-			  gfn_t start, gfn_t end, bool can_yield, bool flush);
-
-void kvm_tdp_mmu_free_root(struct kvm *kvm, struct kvm_mmu_page *root)
-{
-	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-
-	lockdep_assert_held_write(&kvm->mmu_lock);
-
-	WARN_ON(root->root_count);
-	WARN_ON(!root->tdp_mmu_page);
-
-	list_del(&root->link);
-
-	zap_gfn_range(kvm, root, 0, max_gfn, false, false);
-
-	free_page((unsigned long)root->spt);
-	kmem_cache_free(mmu_page_header_cache, root);
-}
+#define for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared)	\
+	for (_root = tdp_mmu_next_root(_kvm, NULL, _shared);		\
+	     _root;							\
+	     _root = tdp_mmu_next_root(_kvm, _root, _shared))		\
+		if (kvm_mmu_page_as_id(_root) != _as_id) {		\
+		} else
+
+#define for_each_tdp_mmu_root(_kvm, _root, _as_id)				\
+	list_for_each_entry_rcu(_root, &_kvm->arch.tdp_mmu_roots, link,		\
+				lockdep_is_held_type(&kvm->mmu_lock, 0) ||	\
+				lockdep_is_held(&kvm->arch.tdp_mmu_pages_lock))	\
+		if (kvm_mmu_page_as_id(_root) != _as_id) {		\
+		} else
 
 static union kvm_mmu_page_role page_role_for_level(struct kvm_vcpu *vcpu,
 						   int level)
@@ -137,81 +188,46 @@ static struct kvm_mmu_page *alloc_tdp_mmu_page(struct kvm_vcpu *vcpu, gfn_t gfn,
 	return sp;
 }
 
-static struct kvm_mmu_page *get_tdp_mmu_vcpu_root(struct kvm_vcpu *vcpu)
+hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
 {
 	union kvm_mmu_page_role role;
 	struct kvm *kvm = vcpu->kvm;
 	struct kvm_mmu_page *root;
 
-	role = page_role_for_level(vcpu, vcpu->arch.mmu->shadow_root_level);
+	lockdep_assert_held_write(&kvm->mmu_lock);
 
-	write_lock(&kvm->mmu_lock);
+	role = page_role_for_level(vcpu, vcpu->arch.mmu->shadow_root_level);
 
 	/* Check for an existing root before allocating a new one. */
-	for_each_tdp_mmu_root(kvm, root) {
-		if (root->role.word == role.word) {
-			kvm_mmu_get_root(kvm, root);
-			write_unlock(&kvm->mmu_lock);
-			return root;
-		}
+	for_each_tdp_mmu_root(kvm, root, kvm_mmu_role_as_id(role)) {
+		if (root->role.word == role.word &&
+		    kvm_tdp_mmu_get_root(kvm, root))
+			goto out;
 	}
 
 	root = alloc_tdp_mmu_page(vcpu, 0, vcpu->arch.mmu->shadow_root_level);
-	root->root_count = 1;
-
-	list_add(&root->link, &kvm->arch.tdp_mmu_roots);
-
-	write_unlock(&kvm->mmu_lock);
-
-	return root;
-}
-
-hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
-{
-	struct kvm_mmu_page *root;
+	refcount_set(&root->tdp_mmu_root_count, 1);
 
-	root = get_tdp_mmu_vcpu_root(vcpu);
-	if (!root)
-		return INVALID_PAGE;
+	spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+	list_add_rcu(&root->link, &kvm->arch.tdp_mmu_roots);
+	spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
 
+out:
 	return __pa(root->spt);
 }
 
-static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
-{
-	free_page((unsigned long)sp->spt);
-	kmem_cache_free(mmu_page_header_cache, sp);
-}
-
-/*
- * This is called through call_rcu in order to free TDP page table memory
- * safely with respect to other kernel threads that may be operating on
- * the memory.
- * By only accessing TDP MMU page table memory in an RCU read critical
- * section, and freeing it after a grace period, lockless access to that
- * memory won't use it after it is freed.
- */
-static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
-{
-	struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
-					       rcu_head);
-
-	tdp_mmu_free_sp(sp);
-}
-
 static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 				u64 old_spte, u64 new_spte, int level,
 				bool shared);
 
 static void handle_changed_spte_acc_track(u64 old_spte, u64 new_spte, int level)
 {
-	bool pfn_changed = spte_to_pfn(old_spte) != spte_to_pfn(new_spte);
-
 	if (!is_shadow_present_pte(old_spte) || !is_last_spte(old_spte, level))
 		return;
 
 	if (is_accessed_spte(old_spte) &&
-	    (!is_accessed_spte(new_spte) || pfn_changed))
+	    (!is_shadow_present_pte(new_spte) || !is_accessed_spte(new_spte) ||
+	     spte_to_pfn(old_spte) != spte_to_pfn(new_spte)))
 		kvm_set_pfn_accessed(spte_to_pfn(old_spte));
 }
 
@@ -455,7 +471,7 @@ static void __handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 
 
 	if (was_leaf && is_dirty_spte(old_spte) &&
-	    (!is_dirty_spte(new_spte) || pfn_changed))
+	    (!is_present || !is_dirty_spte(new_spte) || pfn_changed))
 		kvm_set_pfn_dirty(spte_to_pfn(old_spte));
 
 	/*
@@ -479,8 +495,9 @@ static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 }
 
 /*
- * tdp_mmu_set_spte_atomic - Set a TDP MMU SPTE atomically and handle the
- * associated bookkeeping
+ * tdp_mmu_set_spte_atomic_no_dirty_log - Set a TDP MMU SPTE atomically
+ * and handle the associated bookkeeping, but do not mark the page dirty
+ * in KVM's dirty bitmaps.
  *
  * @kvm: kvm instance
  * @iter: a tdp_iter instance currently on the SPTE that should be set
@@ -488,9 +505,9 @@ static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
  * Returns: true if the SPTE was set, false if it was not. If false is returned,
  *	    this function will have no side-effects.
  */
-static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
-					   struct tdp_iter *iter,
-					   u64 new_spte)
+static inline bool tdp_mmu_set_spte_atomic_no_dirty_log(struct kvm *kvm,
+							struct tdp_iter *iter,
+							u64 new_spte)
 {
 	lockdep_assert_held_read(&kvm->mmu_lock);
 
@@ -498,19 +515,32 @@ static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
 	 * Do not change removed SPTEs. Only the thread that froze the SPTE
 	 * may modify it.
 	 */
-	if (iter->old_spte == REMOVED_SPTE)
+	if (is_removed_spte(iter->old_spte))
 		return false;
 
 	if (cmpxchg64(rcu_dereference(iter->sptep), iter->old_spte,
 		      new_spte) != iter->old_spte)
 		return false;
 
-	handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
-			    new_spte, iter->level, true);
+	__handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
+			      new_spte, iter->level, true);
+	handle_changed_spte_acc_track(iter->old_spte, new_spte, iter->level);
 
 	return true;
 }
 
+static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
+					   struct tdp_iter *iter,
+					   u64 new_spte)
+{
+	if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, iter, new_spte))
+		return false;
+
+	handle_changed_spte_dirty_log(kvm, iter->as_id, iter->gfn,
+				      iter->old_spte, new_spte, iter->level);
+	return true;
+}
+
 static inline bool tdp_mmu_zap_spte_atomic(struct kvm *kvm,
 					   struct tdp_iter *iter)
 {
@@ -569,7 +599,7 @@ static inline void __tdp_mmu_set_spte(struct kvm *kvm, struct tdp_iter *iter,
 	 * should be used. If operating under the MMU lock in write mode, the
 	 * use of the removed SPTE should not be necessary.
 	 */
-	WARN_ON(iter->old_spte == REMOVED_SPTE);
+	WARN_ON(is_removed_spte(iter->old_spte));
 
 	WRITE_ONCE(*rcu_dereference(iter->sptep), new_spte);
 
@@ -634,7 +664,8 @@ static inline void tdp_mmu_set_spte_no_dirty_log(struct kvm *kvm,
  * Return false if a yield was not needed.
  */
 static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
-					     struct tdp_iter *iter, bool flush)
+					     struct tdp_iter *iter, bool flush,
+					     bool shared)
 {
 	/* Ensure forward progress has been made before yielding. */
 	if (iter->next_last_level_gfn == iter->yielded_gfn)
@@ -646,7 +677,11 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
 		if (flush)
 			kvm_flush_remote_tlbs(kvm);
 
-		cond_resched_rwlock_write(&kvm->mmu_lock);
+		if (shared)
+			cond_resched_rwlock_read(&kvm->mmu_lock);
+		else
+			cond_resched_rwlock_write(&kvm->mmu_lock);
+
 		rcu_read_lock();
 
 		WARN_ON(iter->gfn > iter->next_last_level_gfn);
@@ -664,24 +699,32 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
  * If can_yield is true, will release the MMU lock and reschedule if the
  * scheduler needs the CPU or there is contention on the MMU lock. If this
  * function cannot yield, it will not release the MMU lock or reschedule and
  * the caller must ensure it does not supply too large a GFN range, or the
- * operation can cause a soft lockup.  Note, in some use cases a flush may be
- * required by prior actions.  Ensure the pending flush is performed prior to
- * yielding.
+ * operation can cause a soft lockup.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU lock in write mode.
  */
 static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-			  gfn_t start, gfn_t end, bool can_yield, bool flush)
+			  gfn_t start, gfn_t end, bool can_yield, bool flush,
+			  bool shared)
 {
 	struct tdp_iter iter;
 
+	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
+
 	rcu_read_lock();
 
 	tdp_root_for_each_pte(iter, root, start, end) {
+retry:
 		if (can_yield &&
-		    tdp_mmu_iter_cond_resched(kvm, &iter, flush)) {
+		    tdp_mmu_iter_cond_resched(kvm, &iter, flush, shared)) {
 			flush = false;
 			continue;
 		}
@@ -699,8 +742,17 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 		    !is_last_spte(iter.old_spte, iter.level))
 			continue;
 
-		tdp_mmu_set_spte(kvm, &iter, 0);
-		flush = true;
+		if (!shared) {
+			tdp_mmu_set_spte(kvm, &iter, 0);
+			flush = true;
+		} else if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 	}
 
 	rcu_read_unlock();
@@ -712,15 +764,21 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU in write mode.
  */
-bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
-				 bool can_yield)
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
+				 gfn_t end, bool can_yield, bool flush,
+				 bool shared)
 {
 	struct kvm_mmu_page *root;
-	bool flush = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root)
-		flush = zap_gfn_range(kvm, root, start, end, can_yield, flush);
+	for_each_tdp_mmu_root_yield_safe(kvm, root, as_id, shared)
+		flush = zap_gfn_range(kvm, root, start, end, can_yield, flush,
+				      shared);
 
 	return flush;
 }
@@ -728,14 +786,116 @@ bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
 void kvm_tdp_mmu_zap_all(struct kvm *kvm)
 {
 	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-	bool flush;
+	bool flush = false;
+	int i;
+
+	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
+		flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, 0, max_gfn,
+						  flush, false);
+
+	if (flush)
+		kvm_flush_remote_tlbs(kvm);
+}
+
+static struct kvm_mmu_page *next_invalidated_root(struct kvm *kvm,
+						  struct kvm_mmu_page *prev_root)
+{
+	struct kvm_mmu_page *next_root;
+
+	if (prev_root)
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &prev_root->link,
+						  typeof(*prev_root), link);
+	else
+		next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						   typeof(*next_root), link);
+
+	while (next_root && !(next_root->role.invalid &&
+			      refcount_read(&next_root->tdp_mmu_root_count)))
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &next_root->link,
+						  typeof(*next_root), link);
+
+	return next_root;
+}
+
+/*
+ * Since kvm_tdp_mmu_zap_all_fast has acquired a reference to each
+ * invalidated root, they will not be freed until this function drops the
+ * reference. Before dropping that reference, tear down the paging
+ * structure so that whichever thread does drop the last reference
+ * only has to do a trivial amount of work. Since the roots are invalid,
+ * no new SPTEs should be created under them.
+ */
+void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm)
+{
+	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+	struct kvm_mmu_page *next_root;
+	struct kvm_mmu_page *root;
+	bool flush = false;
+
+	lockdep_assert_held_read(&kvm->mmu_lock);
+
+	rcu_read_lock();
+
+	root = next_invalidated_root(kvm, NULL);
+
+	while (root) {
+		next_root = next_invalidated_root(kvm, root);
+
+		rcu_read_unlock();
+
+		flush = zap_gfn_range(kvm, root, 0, max_gfn, true, flush,
+				      true);
+
+		/*
+		 * Put the reference acquired in
+		 * kvm_tdp_mmu_invalidate_roots
+		 */
+		kvm_tdp_mmu_put_root(kvm, root, true);
+
+		root = next_root;
+
+		rcu_read_lock();
+	}
+
+	rcu_read_unlock();
 
-	flush = kvm_tdp_mmu_zap_gfn_range(kvm, 0, max_gfn);
 	if (flush)
 		kvm_flush_remote_tlbs(kvm);
 }
 
 /*
+ * Mark each TDP MMU root as invalid so that other threads
+ * will drop their references and allow the root count to
+ * go to 0.
+ *
+ * Also take a reference on all roots so that this thread
+ * can do the bulk of the work required to free the roots
+ * once they are invalidated. Without this reference, a
+ * vCPU thread might drop the last reference to a root and
+ * get stuck with tearing down the entire paging structure.
+ *
+ * Roots which have a zero refcount should be skipped as
+ * they're already being torn down.
+ * Already invalid roots should be referenced again so that
+ * they aren't freed before kvm_tdp_mmu_zap_all_fast is
+ * done with them.
+ *
+ * This has essentially the same effect for the TDP MMU
+ * as updating mmu_valid_gen does for the shadow MMU.
+ */
+void kvm_tdp_mmu_invalidate_all_roots(struct kvm *kvm)
+{
+	struct kvm_mmu_page *root;
+
+	lockdep_assert_held_write(&kvm->mmu_lock);
+	list_for_each_entry(root, &kvm->arch.tdp_mmu_roots, link)
+		if (refcount_inc_not_zero(&root->tdp_mmu_root_count))
+			root->role.invalid = true;
+}
+
+/*
  * Installs a last-level SPTE to handle a TDP page fault.
  * (NPT/EPT violation/misconfiguration)
  */
@@ -777,12 +937,11 @@ static int tdp_mmu_map_handle_target_level(struct kvm_vcpu *vcpu, int write,
 		trace_mark_mmio_spte(rcu_dereference(iter->sptep), iter->gfn,
 				     new_spte);
 		ret = RET_PF_EMULATE;
-	} else
+	} else {
 		trace_kvm_mmu_set_spte(iter->level, iter->gfn,
 				       rcu_dereference(iter->sptep));
+	}
 
-	trace_kvm_mmu_set_spte(iter->level, iter->gfn,
-			       rcu_dereference(iter->sptep));
 	if (!prefault)
 		vcpu->stat.pf_fixed++;
 
@@ -882,199 +1041,139 @@ int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 	return ret;
 }
 
-static __always_inline int
-kvm_tdp_mmu_handle_hva_range(struct kvm *kvm,
-			     unsigned long start,
-			     unsigned long end,
-			     unsigned long data,
-			     int (*handler)(struct kvm *kvm,
-					    struct kvm_memory_slot *slot,
-					    struct kvm_mmu_page *root,
-					    gfn_t start,
-					    gfn_t end,
-					    unsigned long data))
+bool kvm_tdp_mmu_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range,
+				 bool flush)
 {
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
 	struct kvm_mmu_page *root;
-	int ret = 0;
-	int as_id;
-
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		as_id = kvm_mmu_page_as_id(root);
-		slots = __kvm_memslots(kvm, as_id);
-		kvm_for_each_memslot(memslot, slots) {
-			unsigned long hva_start, hva_end;
-			gfn_t gfn_start, gfn_end;
-
-			hva_start = max(start, memslot->userspace_addr);
-			hva_end = min(end, memslot->userspace_addr +
-				      (memslot->npages << PAGE_SHIFT));
-			if (hva_start >= hva_end)
-				continue;
-			/*
-			 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-			 * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-			 */
-			gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-			gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
 
-			ret |= handler(kvm, memslot, root, gfn_start,
-				       gfn_end, data);
-		}
-	}
+	for_each_tdp_mmu_root(kvm, root, range->slot->as_id)
+		flush |= zap_gfn_range(kvm, root, range->start, range->end,
+				       range->may_block, flush, false);
 
-	return ret;
+	return flush;
 }
 
-static int zap_gfn_range_hva_wrapper(struct kvm *kvm,
-				     struct kvm_memory_slot *slot,
-				     struct kvm_mmu_page *root, gfn_t start,
-				     gfn_t end, unsigned long unused)
-{
-	return zap_gfn_range(kvm, root, start, end, false, false);
-}
+typedef bool (*tdp_handler_t)(struct kvm *kvm, struct tdp_iter *iter,
+			      struct kvm_gfn_range *range);
 
-int kvm_tdp_mmu_zap_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end)
+static __always_inline bool kvm_tdp_mmu_handle_gfn(struct kvm *kvm,
+						   struct kvm_gfn_range *range,
+						   tdp_handler_t handler)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-					    zap_gfn_range_hva_wrapper);
+	struct kvm_mmu_page *root;
+	struct tdp_iter iter;
+	bool ret = false;
+
+	rcu_read_lock();
+
+	/*
+	 * Don't support rescheduling, none of the MMU notifiers that funnel
+	 * into this helper allow blocking; it'd be dead, wasteful code.
+	 */
+	for_each_tdp_mmu_root(kvm, root, range->slot->as_id) {
+		tdp_root_for_each_leaf_pte(iter, root, range->start, range->end)
+			ret |= handler(kvm, &iter, range);
+	}
+
+	rcu_read_unlock();
+
+	return ret;
 }
 
 /*
  * Mark the SPTEs range of GFNs [start, end) unaccessed and return non-zero
  * if any of the GFNs in the range have been accessed.
  */
-static int age_gfn_range(struct kvm *kvm, struct kvm_memory_slot *slot,
-			 struct kvm_mmu_page *root, gfn_t start, gfn_t end,
-			 unsigned long unused)
+static bool age_gfn_range(struct kvm *kvm, struct tdp_iter *iter,
+			  struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-	int young = 0;
 	u64 new_spte = 0;
 
-	rcu_read_lock();
+	/* If we have a non-accessed entry we don't need to change the pte. */
+	if (!is_accessed_spte(iter->old_spte))
+		return false;
 
-	tdp_root_for_each_leaf_pte(iter, root, start, end) {
+	new_spte = iter->old_spte;
+
+	if (spte_ad_enabled(new_spte)) {
+		new_spte &= ~shadow_accessed_mask;
+	} else {
 		/*
-		 * If we have a non-accessed entry we don't need to change the
-		 * pte.
+		 * Capture the dirty status of the page, so that it doesn't get
+		 * lost when the SPTE is marked for access tracking.
 		 */
-		if (!is_accessed_spte(iter.old_spte))
-			continue;
-
-		new_spte = iter.old_spte;
-
-		if (spte_ad_enabled(new_spte)) {
-			clear_bit((ffs(shadow_accessed_mask) - 1),
-				  (unsigned long *)&new_spte);
-		} else {
-			/*
-			 * Capture the dirty status of the page, so that it doesn't get
-			 * lost when the SPTE is marked for access tracking.
-			 */
-			if (is_writable_pte(new_spte))
-				kvm_set_pfn_dirty(spte_to_pfn(new_spte));
+		if (is_writable_pte(new_spte))
+			kvm_set_pfn_dirty(spte_to_pfn(new_spte));
 
-			new_spte = mark_spte_for_access_track(new_spte);
-		}
-		new_spte &= ~shadow_dirty_mask;
-
-		tdp_mmu_set_spte_no_acc_track(kvm, &iter, new_spte);
-		young = 1;
-
-		trace_kvm_age_page(iter.gfn, iter.level, slot, young);
+		new_spte = mark_spte_for_access_track(new_spte);
 	}
 
-	rcu_read_unlock();
+	tdp_mmu_set_spte_no_acc_track(kvm, iter, new_spte);
 
-	return young;
+	return true;
 }
 
-int kvm_tdp_mmu_age_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end)
+bool kvm_tdp_mmu_age_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-					    age_gfn_range);
+	return kvm_tdp_mmu_handle_gfn(kvm, range, age_gfn_range);
 }
 
-static int test_age_gfn(struct kvm *kvm, struct kvm_memory_slot *slot,
-			struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-			unsigned long unused2)
+static bool test_age_gfn(struct kvm *kvm, struct tdp_iter *iter,
+			 struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-
-	tdp_root_for_each_leaf_pte(iter, root, gfn, gfn + 1)
-		if (is_accessed_spte(iter.old_spte))
-			return 1;
-
-	return 0;
+	return is_accessed_spte(iter->old_spte);
 }
 
-int kvm_tdp_mmu_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_tdp_mmu_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, hva, hva + 1, 0,
-					    test_age_gfn);
+	return kvm_tdp_mmu_handle_gfn(kvm, range, test_age_gfn);
 }
 
-/*
- * Handle the changed_pte MMU notifier for the TDP MMU.
- * data is a pointer to the new pte_t mapping the HVA specified by the MMU
- * notifier.
- * Returns non-zero if a flush is needed before releasing the MMU lock.
- */
-static int set_tdp_spte(struct kvm *kvm, struct kvm_memory_slot *slot,
-			struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-			unsigned long data)
+static bool set_spte_gfn(struct kvm *kvm, struct tdp_iter *iter,
+			 struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-	pte_t *ptep = (pte_t *)data;
-	kvm_pfn_t new_pfn;
 	u64 new_spte;
-	int need_flush = 0;
-
-	rcu_read_lock();
 
-	WARN_ON(pte_huge(*ptep));
+	/* Huge pages aren't expected to be modified without first being zapped. */
+	WARN_ON(pte_huge(range->pte) || range->start + 1 != range->end);
 
-	new_pfn = pte_pfn(*ptep);
-
-	tdp_root_for_each_pte(iter, root, gfn, gfn + 1) {
-		if (iter.level != PG_LEVEL_4K)
-			continue;
-
-		if (!is_shadow_present_pte(iter.old_spte))
-			break;
-
-		tdp_mmu_set_spte(kvm, &iter, 0);
-
-		kvm_flush_remote_tlbs_with_address(kvm, iter.gfn, 1);
+	if (iter->level != PG_LEVEL_4K ||
+	    !is_shadow_present_pte(iter->old_spte))
+		return false;
 
-		if (!pte_write(*ptep)) {
-			new_spte = kvm_mmu_changed_pte_notifier_make_spte(
-					iter.old_spte, new_pfn);
+	/*
+	 * Note, when changing a read-only SPTE, it's not strictly necessary to
+	 * zero the SPTE before setting the new PFN, but doing so preserves the
+	 * invariant that the PFN of a present * leaf SPTE can never change.
+	 * See __handle_changed_spte().
+	 */
+	tdp_mmu_set_spte(kvm, iter, 0);
 
-			tdp_mmu_set_spte(kvm, &iter, new_spte);
-		}
+	if (!pte_write(range->pte)) {
+		new_spte = kvm_mmu_changed_pte_notifier_make_spte(iter->old_spte,
+								  pte_pfn(range->pte));
 
-		need_flush = 1;
+		tdp_mmu_set_spte(kvm, iter, new_spte);
 	}
 
-	if (need_flush)
-		kvm_flush_remote_tlbs_with_address(kvm, gfn, 1);
-
-	rcu_read_unlock();
-
-	return 0;
+	return true;
 }
 
-int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
-			     pte_t *host_ptep)
+/*
+ * Handle the changed_pte MMU notifier for the TDP MMU.
+ * data is a pointer to the new pte_t mapping the HVA specified by the MMU
+ * notifier.
+ * Returns non-zero if a flush is needed before releasing the MMU lock.
+ */
+bool kvm_tdp_mmu_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, address, address + 1,
-					    (unsigned long)host_ptep,
-					    set_tdp_spte);
+	bool flush = kvm_tdp_mmu_handle_gfn(kvm, range, set_spte_gfn);
+
+	/* FIXME: return 'flush' instead of flushing here. */
+	if (flush)
+		kvm_flush_remote_tlbs_with_address(kvm, range->start, 1);
+
+	return false;
 }
 
 /*
@@ -1095,7 +1194,8 @@ static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
 	for_each_tdp_pte_min_level(iter, root->spt, root->role.level,
 				   min_level, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
 
 		if (!is_shadow_present_pte(iter.old_spte) ||
@@ -1105,7 +1205,15 @@ static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
 		new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
 
-		tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+		if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+							  new_spte)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 		spte_set = true;
 	}
 
@@ -1122,17 +1230,13 @@ bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
 			     int min_level)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
 		spte_set |= wrprot_gfn_range(kvm, root, slot->base_gfn,
 			     slot->base_gfn + slot->npages, min_level);
-	}
 
 	return spte_set;
 }
@@ -1154,7 +1258,8 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 	rcu_read_lock();
 
 	tdp_root_for_each_leaf_pte(iter, root, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
 
 		if (spte_ad_need_write_protect(iter.old_spte)) {
@@ -1169,7 +1274,15 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 				continue;
 		}
 
-		tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+		if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+							  new_spte)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 		spte_set = true;
 	}
 
@@ -1187,17 +1300,13 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 bool kvm_tdp_mmu_clear_dirty_slot(struct kvm *kvm, struct kvm_memory_slot *slot)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
 		spte_set |= clear_dirty_gfn_range(kvm, root, slot->base_gfn,
 				slot->base_gfn + slot->npages);
-	}
 
 	return spte_set;
 }
@@ -1259,37 +1368,32 @@ void kvm_tdp_mmu_clear_dirty_pt_masked(struct kvm *kvm,
 				       bool wrprot)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	for_each_tdp_mmu_root(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
-
+	for_each_tdp_mmu_root(kvm, root, slot->as_id)
 		clear_dirty_pt_masked(kvm, root, gfn, mask, wrprot);
-	}
 }
 
 /*
  * Clear leaf entries which could be replaced by large mappings, for
  * GFNs within the slot.
  */
-static void zap_collapsible_spte_range(struct kvm *kvm,
+static bool zap_collapsible_spte_range(struct kvm *kvm,
 				       struct kvm_mmu_page *root,
-				       struct kvm_memory_slot *slot)
+				       const struct kvm_memory_slot *slot,
+				       bool flush)
 {
 	gfn_t start = slot->base_gfn;
 	gfn_t end = start + slot->npages;
 	struct tdp_iter iter;
 	kvm_pfn_t pfn;
-	bool spte_set = false;
 
 	rcu_read_lock();
 
 	tdp_root_for_each_pte(iter, root, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, spte_set)) {
-			spte_set = false;
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, flush, true)) {
+			flush = false;
 			continue;
 		}
 
@@ -1303,38 +1407,43 @@ static void zap_collapsible_spte_range(struct kvm *kvm,
 							    pfn, PG_LEVEL_NUM))
 			continue;
 
-		tdp_mmu_set_spte(kvm, &iter, 0);
-
-		spte_set = true;
+		if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
+		flush = true;
 	}
 
 	rcu_read_unlock();
-	if (spte_set)
-		kvm_flush_remote_tlbs(kvm);
+
+	return flush;
 }
 
 /*
  * Clear non-leaf entries (and free associated page tables) which could
  * be replaced by large mappings, for GFNs within the slot.
  */
-void kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
-				       struct kvm_memory_slot *slot)
+bool kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
+				       const struct kvm_memory_slot *slot,
+				       bool flush)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
-		zap_collapsible_spte_range(kvm, root, slot);
-	}
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
+		flush = zap_collapsible_spte_range(kvm, root, slot, flush);
+
+	return flush;
 }
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
@@ -1351,7 +1460,7 @@ static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
 			break;
 
 		new_spte = iter.old_spte &
-			~(PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE);
+			~(PT_WRITABLE_MASK | shadow_mmu_writable_mask);
 
 		tdp_mmu_set_spte(kvm, &iter, new_spte);
 		spte_set = true;
@@ -1364,24 +1473,19 @@ static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 bool kvm_tdp_mmu_write_protect_gfn(struct kvm *kvm,
 				   struct kvm_memory_slot *slot, gfn_t gfn)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	for_each_tdp_mmu_root(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
-
+	for_each_tdp_mmu_root(kvm, root, slot->as_id)
 		spte_set |= write_protect_gfn(kvm, root, gfn);
-	}
+
 	return spte_set;
 }