Diffstat (limited to 'arch/x86')
-rw-r--r--  arch/x86/include/asm/cpufeatures.h  |    1
-rw-r--r--  arch/x86/include/asm/kvm_host.h     |   67
-rw-r--r--  arch/x86/include/asm/mem_encrypt.h  |    1
-rw-r--r--  arch/x86/include/asm/svm.h          |    4
-rw-r--r--  arch/x86/include/asm/vmx.h          |    1
-rw-r--r--  arch/x86/include/uapi/asm/vmx.h     |    1
-rw-r--r--  arch/x86/kernel/kvm.c               |  128
-rw-r--r--  arch/x86/kvm/Makefile               |    2
-rw-r--r--  arch/x86/kvm/cpuid.c                |   98
-rw-r--r--  arch/x86/kvm/cpuid.h                |  155
-rw-r--r--  arch/x86/kvm/emulate.c              |   80
-rw-r--r--  arch/x86/kvm/kvm_cache_regs.h       |   19
-rw-r--r--  arch/x86/kvm/lapic.c                |    8
-rw-r--r--  arch/x86/kvm/mmu.h                  |   23
-rw-r--r--  arch/x86/kvm/mmu/mmu.c              |  637
-rw-r--r--  arch/x86/kvm/mmu/mmu_audit.c        |    2
-rw-r--r--  arch/x86/kvm/mmu/mmu_internal.h     |   44
-rw-r--r--  arch/x86/kvm/mmu/paging_tmpl.h      |    3
-rw-r--r--  arch/x86/kvm/mmu/spte.c             |  159
-rw-r--r--  arch/x86/kvm/mmu/spte.h             |  141
-rw-r--r--  arch/x86/kvm/mmu/tdp_mmu.c          |  740
-rw-r--r--  arch/x86/kvm/mmu/tdp_mmu.h          |   51
-rw-r--r--  arch/x86/kvm/reverse_cpuid.h        |  186
-rw-r--r--  arch/x86/kvm/svm/avic.c             |   24
-rw-r--r--  arch/x86/kvm/svm/nested.c           |  573
-rw-r--r--  arch/x86/kvm/svm/sev.c              |  922
-rw-r--r--  arch/x86/kvm/svm/svm.c              | 1107
-rw-r--r--  arch/x86/kvm/svm/svm.h              |   91
-rw-r--r--  arch/x86/kvm/svm/vmenter.S          |   47
-rw-r--r--  arch/x86/kvm/vmx/nested.c           |   83
-rw-r--r--  arch/x86/kvm/vmx/nested.h           |    5
-rw-r--r--  arch/x86/kvm/vmx/sgx.c              |  502
-rw-r--r--  arch/x86/kvm/vmx/sgx.h              |   34
-rw-r--r--  arch/x86/kvm/vmx/vmcs12.c           |    1
-rw-r--r--  arch/x86/kvm/vmx/vmcs12.h           |    4
-rw-r--r--  arch/x86/kvm/vmx/vmx.c              |  432
-rw-r--r--  arch/x86/kvm/vmx/vmx.h              |   39
-rw-r--r--  arch/x86/kvm/vmx/vmx_ops.h          |    4
-rw-r--r--  arch/x86/kvm/x86.c                  |  214
-rw-r--r--  arch/x86/kvm/x86.h                  |   18
-rw-r--r--  arch/x86/mm/mem_encrypt.c           |   10
-rw-r--r--  arch/x86/mm/mem_encrypt_identity.c  |    1
42 files changed, 4212 insertions, 2450 deletions
diff --git a/arch/x86/include/asm/cpufeatures.h b/arch/x86/include/asm/cpufeatures.h
index 3c94316169a3..ac37830ae941 100644
--- a/arch/x86/include/asm/cpufeatures.h
+++ b/arch/x86/include/asm/cpufeatures.h
@@ -340,6 +340,7 @@
 #define X86_FEATURE_AVIC		(15*32+13) /* Virtual Interrupt Controller */
 #define X86_FEATURE_V_VMSAVE_VMLOAD	(15*32+15) /* Virtual VMSAVE VMLOAD */
 #define X86_FEATURE_VGIF		(15*32+16) /* Virtual GIF */
+#define X86_FEATURE_V_SPEC_CTRL		(15*32+20) /* Virtual SPEC_CTRL */
 #define X86_FEATURE_SVME_ADDR_CHK	(15*32+28) /* "" SVME addr check */
 
 /* Intel-defined CPU features, CPUID level 0x00000007:0 (ECX), word 16 */
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 10eca9e8f7f6..cbbcee0a84f9 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -221,12 +221,22 @@ enum x86_intercept_stage;
 #define DR7_FIXED_1	0x00000400
 #define DR7_VOLATILE	0xffff2bff
 
+#define KVM_GUESTDBG_VALID_MASK \
+	(KVM_GUESTDBG_ENABLE | \
+	KVM_GUESTDBG_SINGLESTEP | \
+	KVM_GUESTDBG_USE_HW_BP | \
+	KVM_GUESTDBG_USE_SW_BP | \
+	KVM_GUESTDBG_INJECT_BP | \
+	KVM_GUESTDBG_INJECT_DB)
+
+
 #define PFERR_PRESENT_BIT 0
 #define PFERR_WRITE_BIT 1
 #define PFERR_USER_BIT 2
 #define PFERR_RSVD_BIT 3
 #define PFERR_FETCH_BIT 4
 #define PFERR_PK_BIT 5
+#define PFERR_SGX_BIT 15
 #define PFERR_GUEST_FINAL_BIT 32
 #define PFERR_GUEST_PAGE_BIT 33
 
@@ -236,6 +246,7 @@ enum x86_intercept_stage;
 #define PFERR_RSVD_MASK (1U << PFERR_RSVD_BIT)
 #define PFERR_FETCH_MASK (1U << PFERR_FETCH_BIT)
 #define PFERR_PK_MASK (1U << PFERR_PK_BIT)
+#define PFERR_SGX_MASK (1U << PFERR_SGX_BIT)
 #define PFERR_GUEST_FINAL_MASK (1ULL << PFERR_GUEST_FINAL_BIT)
 #define PFERR_GUEST_PAGE_MASK (1ULL << PFERR_GUEST_PAGE_BIT)
 
@@ -1054,6 +1065,9 @@ struct kvm_arch {
 	u32 user_space_msr_mask;
 	struct kvm_x86_msr_filter __rcu *msr_filter;
 
+	/* Guest can access the SGX PROVISIONKEY. */
+	bool sgx_provisioning_allowed;
+
 	struct kvm_pmu_event_filter __rcu *pmu_event_filter;
 	struct task_struct *nx_lpage_recovery_thread;
 
@@ -1068,25 +1082,36 @@ struct kvm_arch {
 	bool tdp_mmu_enabled;
 
 	/*
-	 * List of struct kvmp_mmu_pages being used as roots.
+	 * List of struct kvm_mmu_pages being used as roots.
 	 * All struct kvm_mmu_pages in the list should have
 	 * tdp_mmu_page set.
-	 * All struct kvm_mmu_pages in the list should have a positive
-	 * root_count except when a thread holds the MMU lock and is removing
-	 * an entry from the list.
+	 *
+	 * For reads, this list is protected by:
+	 *	the MMU lock in read mode + RCU or
+	 *	the MMU lock in write mode
+	 *
+	 * For writes, this list is protected by:
+	 *	the MMU lock in read mode + the tdp_mmu_pages_lock or
+	 *	the MMU lock in write mode
+	 *
+	 * Roots will remain in the list until their tdp_mmu_root_count
+	 * drops to zero, at which point the thread that decremented the
+	 * count to zero should remove the root from the list and clean
+	 * it up, freeing the root after an RCU grace period.
 	 */
 	struct list_head tdp_mmu_roots;
 
 	/*
 	 * List of struct kvmp_mmu_pages not being used as roots.
 	 * All struct kvm_mmu_pages in the list should have
-	 * tdp_mmu_page set and a root_count of 0.
+	 * tdp_mmu_page set and a tdp_mmu_root_count of 0.
 	 */
 	struct list_head tdp_mmu_pages;
 
 	/*
 	 * Protects accesses to the following fields when the MMU lock
 	 * is held in read mode:
+	 *  - tdp_mmu_roots (above)
 	 *  - tdp_mmu_pages (above)
 	 *  - the link field of struct kvm_mmu_pages used by the TDP MMU
 	 *  - lpage_disallowed_mmu_pages
@@ -1143,6 +1168,9 @@ struct kvm_vcpu_stat {
 	u64 req_event;
 	u64 halt_poll_success_ns;
 	u64 halt_poll_fail_ns;
+	u64 nested_run;
+	u64 directed_yield_attempted;
+	u64 directed_yield_successful;
 };
 
 struct x86_instruction_info;
@@ -1269,8 +1297,8 @@ struct kvm_x86_ops {
 	int (*set_identity_map_addr)(struct kvm *kvm, u64 ident_addr);
 	u64 (*get_mt_mask)(struct kvm_vcpu *vcpu, gfn_t gfn, bool is_mmio);
 
-	void (*load_mmu_pgd)(struct kvm_vcpu *vcpu, unsigned long pgd,
-			     int pgd_level);
+	void (*load_mmu_pgd)(struct kvm_vcpu *vcpu, hpa_t root_hpa,
+			     int root_level);
 
 	bool (*has_wbinvd_exit)(void);
 
@@ -1339,6 +1367,7 @@ struct kvm_x86_ops {
 	int (*mem_enc_op)(struct kvm *kvm, void __user *argp);
 	int (*mem_enc_reg_region)(struct kvm *kvm, struct kvm_enc_region *argp);
 	int (*mem_enc_unreg_region)(struct kvm *kvm, struct kvm_enc_region *argp);
+	int (*vm_copy_enc_context_from)(struct kvm *kvm, unsigned int source_fd);
 
 	int (*get_msr_feature)(struct kvm_msr_entry *entry);
 
@@ -1357,6 +1386,7 @@ struct kvm_x86_ops {
 struct kvm_x86_nested_ops {
 	int (*check_events)(struct kvm_vcpu *vcpu);
 	bool (*hv_timer_pending)(struct kvm_vcpu *vcpu);
+	void (*triple_fault)(struct kvm_vcpu *vcpu);
 	int (*get_state)(struct kvm_vcpu *vcpu,
 			 struct kvm_nested_state __user *user_kvm_nested_state,
 			 unsigned user_data_size);
@@ -1428,9 +1458,6 @@ void kvm_mmu_destroy(struct kvm_vcpu *vcpu);
 int kvm_mmu_create(struct kvm_vcpu *vcpu);
 void kvm_mmu_init_vm(struct kvm *kvm);
 void kvm_mmu_uninit_vm(struct kvm *kvm);
-void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
-		u64 dirty_mask, u64 nx_mask, u64 x_mask, u64 p_mask,
-		u64 acc_track_mask, u64 me_mask);
 
 void kvm_mmu_reset_context(struct kvm_vcpu *vcpu);
 void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
@@ -1440,8 +1467,6 @@ void kvm_mmu_zap_collapsible_sptes(struct kvm *kvm,
 				   const struct kvm_memory_slot *memslot);
 void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
 				   struct kvm_memory_slot *memslot);
-void kvm_mmu_slot_largepage_remove_write_access(struct kvm *kvm,
-					struct kvm_memory_slot *memslot);
 void kvm_mmu_zap_all(struct kvm *kvm);
 void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm, u64 gen);
 unsigned long kvm_mmu_calculate_default_mmu_pages(struct kvm *kvm);
@@ -1538,6 +1563,11 @@ int kvm_get_msr(struct kvm_vcpu *vcpu, u32 index, u64 *data);
 int kvm_set_msr(struct kvm_vcpu *vcpu, u32 index, u64 data);
 int kvm_emulate_rdmsr(struct kvm_vcpu *vcpu);
 int kvm_emulate_wrmsr(struct kvm_vcpu *vcpu);
+int kvm_emulate_as_nop(struct kvm_vcpu *vcpu);
+int kvm_emulate_invd(struct kvm_vcpu *vcpu);
+int kvm_emulate_mwait(struct kvm_vcpu *vcpu);
+int kvm_handle_invalid_op(struct kvm_vcpu *vcpu);
+int kvm_emulate_monitor(struct kvm_vcpu *vcpu);
 
 int kvm_fast_pio(struct kvm_vcpu *vcpu, int size, unsigned short port, int in);
 int kvm_emulate_cpuid(struct kvm_vcpu *vcpu);
@@ -1566,14 +1596,14 @@ void kvm_get_dr(struct kvm_vcpu *vcpu, int dr, unsigned long *val);
 unsigned long kvm_get_cr8(struct kvm_vcpu *vcpu);
 void kvm_lmsw(struct kvm_vcpu *vcpu, unsigned long msw);
 void kvm_get_cs_db_l_bits(struct kvm_vcpu *vcpu, int *db, int *l);
-int kvm_set_xcr(struct kvm_vcpu *vcpu, u32 index, u64 xcr);
+int kvm_emulate_xsetbv(struct kvm_vcpu *vcpu);
 
 int kvm_get_msr_common(struct kvm_vcpu *vcpu, struct msr_data *msr);
 int kvm_set_msr_common(struct kvm_vcpu *vcpu, struct msr_data *msr);
 
 unsigned long kvm_get_rflags(struct kvm_vcpu *vcpu);
 void kvm_set_rflags(struct kvm_vcpu *vcpu, unsigned long rflags);
-bool kvm_rdpmc(struct kvm_vcpu *vcpu);
+int kvm_emulate_rdpmc(struct kvm_vcpu *vcpu);
 
 void kvm_queue_exception(struct kvm_vcpu *vcpu, unsigned nr);
 void kvm_queue_exception_e(struct kvm_vcpu *vcpu, unsigned nr, u32 error_code);
@@ -1614,9 +1644,6 @@ void kvm_update_dr7(struct kvm_vcpu *vcpu);
 
 int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn);
 void __kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu);
-int kvm_mmu_load(struct kvm_vcpu *vcpu);
-void kvm_mmu_unload(struct kvm_vcpu *vcpu);
-void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu);
 void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 			ulong roots_to_free);
 gpa_t translate_nested_gpa(struct kvm_vcpu *vcpu, gpa_t gpa, u32 access,
@@ -1735,11 +1762,7 @@ asmlinkage void kvm_spurious_fault(void);
 	_ASM_EXTABLE(666b, 667b)
 
 #define KVM_ARCH_WANT_MMU_NOTIFIER
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-			unsigned flags);
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end);
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva);
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte);
+
 int kvm_cpu_has_injectable_intr(struct kvm_vcpu *v);
 int kvm_cpu_has_interrupt(struct kvm_vcpu *vcpu);
 int kvm_cpu_has_extint(struct kvm_vcpu *v);
diff --git a/arch/x86/include/asm/mem_encrypt.h b/arch/x86/include/asm/mem_encrypt.h
index 31c4df123aa0..9c80c68d75b5 100644
--- a/arch/x86/include/asm/mem_encrypt.h
+++ b/arch/x86/include/asm/mem_encrypt.h
@@ -20,7 +20,6 @@
 
 extern u64 sme_me_mask;
 extern u64 sev_status;
-extern bool sev_enabled;
 
 void sme_encrypt_execute(unsigned long encrypted_kernel_vaddr,
 			 unsigned long decrypted_kernel_vaddr,
diff --git a/arch/x86/include/asm/svm.h b/arch/x86/include/asm/svm.h
index 1c561945b426..772e60efe243 100644
--- a/arch/x86/include/asm/svm.h
+++ b/arch/x86/include/asm/svm.h
@@ -269,7 +269,9 @@ struct vmcb_save_area {
 	 * SEV-ES guests when referenced through the GHCB or for
 	 * saving to the host save area.
 	 */
-	u8 reserved_7[80];
+	u8 reserved_7[72];
+	u32 spec_ctrl;		/* Guest version of SPEC_CTRL at 0x2E0 */
+	u8 reserved_7b[4];
 	u32 pkru;
 	u8 reserved_7a[20];
 	u64 reserved_8;		/* rax already available at 0x01f8 */
diff --git a/arch/x86/include/asm/vmx.h b/arch/x86/include/asm/vmx.h
index 358707f60d99..0ffaa3156a4e 100644
--- a/arch/x86/include/asm/vmx.h
+++ b/arch/x86/include/asm/vmx.h
@@ -373,6 +373,7 @@ enum vmcs_field {
 #define GUEST_INTR_STATE_MOV_SS		0x00000002
 #define GUEST_INTR_STATE_SMI		0x00000004
 #define GUEST_INTR_STATE_NMI		0x00000008
+#define GUEST_INTR_STATE_ENCLAVE_INTR	0x00000010
 
 /* GUEST_ACTIVITY_STATE flags */
 #define GUEST_ACTIVITY_ACTIVE		0
diff --git a/arch/x86/include/uapi/asm/vmx.h b/arch/x86/include/uapi/asm/vmx.h
index b8e650a985e3..946d761adbd3 100644
--- a/arch/x86/include/uapi/asm/vmx.h
+++ b/arch/x86/include/uapi/asm/vmx.h
@@ -27,6 +27,7 @@
 
 
 #define VMX_EXIT_REASONS_FAILED_VMENTRY         0x80000000
+#define VMX_EXIT_REASONS_SGX_ENCLAVE_MODE	0x08000000
 
 #define EXIT_REASON_EXCEPTION_NMI       0
 #define EXIT_REASON_EXTERNAL_INTERRUPT  1
diff --git a/arch/x86/kernel/kvm.c b/arch/x86/kernel/kvm.c
index 5d32fa477a62..d307c22e5c18 100644
--- a/arch/x86/kernel/kvm.c
+++ b/arch/x86/kernel/kvm.c
@@ -451,6 +451,10 @@ static void __init sev_map_percpu_data(void)
 	}
 }
 
+#ifdef CONFIG_SMP
+
+static DEFINE_PER_CPU(cpumask_var_t, __pv_cpu_mask);
+
 static bool pv_tlb_flush_supported(void)
 {
 	return (kvm_para_has_feature(KVM_FEATURE_PV_TLB_FLUSH) &&
@@ -458,10 +462,6 @@ static bool pv_tlb_flush_supported(void)
 		kvm_para_has_feature(KVM_FEATURE_STEAL_TIME));
 }
 
-static DEFINE_PER_CPU(cpumask_var_t, __pv_cpu_mask);
-
-#ifdef CONFIG_SMP
-
 static bool pv_ipi_supported(void)
 {
 	return kvm_para_has_feature(KVM_FEATURE_PV_SEND_IPI);
@@ -574,6 +574,54 @@ static void kvm_smp_send_call_func_ipi(const struct cpumask *mask)
 	}
 }
 
+static void kvm_flush_tlb_multi(const struct cpumask *cpumask,
+			const struct flush_tlb_info *info)
+{
+	u8 state;
+	int cpu;
+	struct kvm_steal_time *src;
+	struct cpumask *flushmask = this_cpu_cpumask_var_ptr(__pv_cpu_mask);
+
+	cpumask_copy(flushmask, cpumask);
+	/*
+	 * We have to call flush only on online vCPUs. And
+	 * queue flush_on_enter for pre-empted vCPUs
+	 */
+	for_each_cpu(cpu, flushmask) {
+		/*
+		 * The local vCPU is never preempted, so we do not explicitly
+		 * skip check for local vCPU - it will never be cleared from
+		 * flushmask.
+		 */
+		src = &per_cpu(steal_time, cpu);
+		state = READ_ONCE(src->preempted);
+		if ((state & KVM_VCPU_PREEMPTED)) {
+			if (try_cmpxchg(&src->preempted, &state,
+					state | KVM_VCPU_FLUSH_TLB))
+				__cpumask_clear_cpu(cpu, flushmask);
+		}
+	}
+
+	native_flush_tlb_multi(flushmask, info);
+}
+
+static __init int kvm_alloc_cpumask(void)
+{
+	int cpu;
+
+	if (!kvm_para_available() || nopv)
+		return 0;
+
+	if (pv_tlb_flush_supported() || pv_ipi_supported())
+		for_each_possible_cpu(cpu) {
+			zalloc_cpumask_var_node(per_cpu_ptr(&__pv_cpu_mask, cpu),
+				GFP_KERNEL, cpu_to_node(cpu));
+		}
+
+	return 0;
+}
+arch_initcall(kvm_alloc_cpumask);
+
 static void __init kvm_smp_prepare_boot_cpu(void)
 {
 	/*
@@ -611,38 +659,8 @@ static int kvm_cpu_down_prepare(unsigned int cpu)
 	local_irq_enable();
 	return 0;
 }
-#endif
-
-static void kvm_flush_tlb_multi(const struct cpumask *cpumask,
-			const struct flush_tlb_info *info)
-{
-	u8 state;
-	int cpu;
-	struct kvm_steal_time *src;
-	struct cpumask *flushmask = this_cpu_cpumask_var_ptr(__pv_cpu_mask);
-
-	cpumask_copy(flushmask, cpumask);
-	/*
-	 * We have to call flush only on online vCPUs. And
-	 * queue flush_on_enter for pre-empted vCPUs
-	 */
-	for_each_cpu(cpu, flushmask) {
-		/*
-		 * The local vCPU is never preempted, so we do not explicitly
-		 * skip check for local vCPU - it will never be cleared from
-		 * flushmask.
-		 */
-		src = &per_cpu(steal_time, cpu);
-		state = READ_ONCE(src->preempted);
-		if ((state & KVM_VCPU_PREEMPTED)) {
-			if (try_cmpxchg(&src->preempted, &state,
-					state | KVM_VCPU_FLUSH_TLB))
-				__cpumask_clear_cpu(cpu, flushmask);
-		}
-	}
 
-	native_flush_tlb_multi(flushmask, info);
-}
+#endif
 
 static void __init kvm_guest_init(void)
 {
@@ -658,12 +676,6 @@ static void __init kvm_guest_init(void)
 		static_call_update(pv_steal_clock, kvm_steal_clock);
 	}
 
-	if (pv_tlb_flush_supported()) {
-		pv_ops.mmu.flush_tlb_multi = kvm_flush_tlb_multi;
-		pv_ops.mmu.tlb_remove_table = tlb_remove_table;
-		pr_info("KVM setup pv remote TLB flush\n");
-	}
-
 	if (kvm_para_has_feature(KVM_FEATURE_PV_EOI))
 		apic_set_eoi_write(kvm_guest_apic_eoi_write);
 
@@ -673,6 +685,12 @@ static void __init kvm_guest_init(void)
 	}
 
 #ifdef CONFIG_SMP
+	if (pv_tlb_flush_supported()) {
+		pv_ops.mmu.flush_tlb_multi = kvm_flush_tlb_multi;
+		pv_ops.mmu.tlb_remove_table = tlb_remove_table;
+		pr_info("KVM setup pv remote TLB flush\n");
+	}
+
 	smp_ops.smp_prepare_boot_cpu = kvm_smp_prepare_boot_cpu;
 	if (pv_sched_yield_supported()) {
 		smp_ops.send_call_func_ipi = kvm_smp_send_call_func_ipi;
@@ -739,7 +757,7 @@ static uint32_t __init kvm_detect(void)
 
 static void __init kvm_apic_init(void)
 {
-#if defined(CONFIG_SMP)
+#ifdef CONFIG_SMP
 	if (pv_ipi_supported())
 		kvm_setup_pv_ipi();
 #endif
@@ -799,32 +817,6 @@ static __init int activate_jump_labels(void)
 }
 arch_initcall(activate_jump_labels);
 
-static __init int kvm_alloc_cpumask(void)
-{
-	int cpu;
-	bool alloc = false;
-
-	if (!kvm_para_available() || nopv)
-		return 0;
-
-	if (pv_tlb_flush_supported())
-		alloc = true;
-
-#if defined(CONFIG_SMP)
-	if (pv_ipi_supported())
-		alloc = true;
-#endif
-
-	if (alloc)
-		for_each_possible_cpu(cpu) {
-			zalloc_cpumask_var_node(per_cpu_ptr(&__pv_cpu_mask, cpu),
-				GFP_KERNEL, cpu_to_node(cpu));
-		}
-
-	return 0;
-}
-arch_initcall(kvm_alloc_cpumask);
-
 #ifdef CONFIG_PARAVIRT_SPINLOCKS
 
 /* Kick a cpu by its apicid. Used to wake up a halted vcpu */
diff --git a/arch/x86/kvm/Makefile b/arch/x86/kvm/Makefile
index eafc4d601f25..c589db5d91b3 100644
--- a/arch/x86/kvm/Makefile
+++ b/arch/x86/kvm/Makefile
@@ -23,6 +23,8 @@ kvm-$(CONFIG_KVM_XEN)	+= xen.o
 
 kvm-intel-y		+= vmx/vmx.o vmx/vmenter.o vmx/pmu_intel.o vmx/vmcs12.o \
 			   vmx/evmcs.o vmx/nested.o vmx/posted_intr.o
+kvm-intel-$(CONFIG_X86_SGX_KVM)	+= vmx/sgx.o
+
 kvm-amd-y		+= svm/svm.o svm/vmenter.o svm/pmu.o svm/nested.o svm/avic.o svm/sev.o
 
 obj-$(CONFIG_KVM)	+= kvm.o
diff --git a/arch/x86/kvm/cpuid.c b/arch/x86/kvm/cpuid.c
index c02466a1410b..19606a341888 100644
--- a/arch/x86/kvm/cpuid.c
+++ b/arch/x86/kvm/cpuid.c
@@ -18,6 +18,7 @@
 #include <asm/processor.h>
 #include <asm/user.h>
 #include <asm/fpu/xstate.h>
+#include <asm/sgx.h>
 #include "cpuid.h"
 #include "lapic.h"
 #include "mmu.h"
@@ -28,7 +29,7 @@
  * Unlike "struct cpuinfo_x86.x86_capability", kvm_cpu_caps doesn't need to be
  * aligned to sizeof(unsigned long) because it's not accessed via bitops.
  */
-u32 kvm_cpu_caps[NCAPINTS] __read_mostly;
+u32 kvm_cpu_caps[NR_KVM_CPU_CAPS] __read_mostly;
 EXPORT_SYMBOL_GPL(kvm_cpu_caps);
 
 static u32 xstate_required_size(u64 xstate_bv, bool compacted)
@@ -53,6 +54,7 @@ static u32 xstate_required_size(u64 xstate_bv, bool compacted)
 }
 
 #define F feature_bit
+#define SF(name) (boot_cpu_has(X86_FEATURE_##name) ? F(name) : 0)
 
 static inline struct kvm_cpuid_entry2 *cpuid_entry2_find(
 	struct kvm_cpuid_entry2 *entries, int nent, u32 function, u32 index)
@@ -170,6 +172,21 @@ static void kvm_vcpu_after_set_cpuid(struct kvm_vcpu *vcpu)
 		vcpu->arch.guest_supported_xcr0 =
 			(best->eax | ((u64)best->edx << 32)) & supported_xcr0;
 
+	/*
+	 * Bits 127:0 of the allowed SECS.ATTRIBUTES (CPUID.0x12.0x1) enumerate
+	 * the supported XSAVE Feature Request Mask (XFRM), i.e. the enclave's
+	 * requested XCR0 value.  The enclave's XFRM must be a subset of XCR0
+	 * at the time of EENTER, thus adjust the allowed XFRM by the guest's
+	 * supported XCR0.  Similar to XCR0 handling, FP and SSE are forced to
+	 * '1' even on CPUs that don't support XSAVE.
+	 */
+	best = kvm_find_cpuid_entry(vcpu, 0x12, 0x1);
+	if (best) {
+		best->ecx &= vcpu->arch.guest_supported_xcr0 & 0xffffffff;
+		best->edx &= vcpu->arch.guest_supported_xcr0 >> 32;
+		best->ecx |= XFEATURE_MASK_FPSSE;
+	}
+
 	kvm_update_pv_runtime(vcpu);
 
 	vcpu->arch.maxphyaddr = cpuid_query_maxphyaddr(vcpu);
@@ -347,13 +364,13 @@ out:
 	return r;
 }
 
-static __always_inline void kvm_cpu_cap_mask(enum cpuid_leafs leaf, u32 mask)
+/* Mask kvm_cpu_caps for @leaf with the raw CPUID capabilities of this CPU. */
+static __always_inline void __kvm_cpu_cap_mask(unsigned int leaf)
 {
 	const struct cpuid_reg cpuid = x86_feature_cpuid(leaf * 32);
 	struct kvm_cpuid_entry2 entry;
 
 	reverse_cpuid_check(leaf);
-	kvm_cpu_caps[leaf] &= mask;
 
 	cpuid_count(cpuid.function, cpuid.index,
 		    &entry.eax, &entry.ebx, &entry.ecx, &entry.edx);
@@ -361,6 +378,27 @@ static __always_inline void kvm_cpu_cap_mask(enum cpuid_leafs leaf, u32 mask)
 	kvm_cpu_caps[leaf] &= *__cpuid_entry_get_reg(&entry, cpuid.reg);
 }
 
+static __always_inline
+void kvm_cpu_cap_init_scattered(enum kvm_only_cpuid_leafs leaf, u32 mask)
+{
+	/* Use kvm_cpu_cap_mask for non-scattered leafs. */
+	BUILD_BUG_ON(leaf < NCAPINTS);
+
+	kvm_cpu_caps[leaf] = mask;
+
+	__kvm_cpu_cap_mask(leaf);
+}
+
+static __always_inline void kvm_cpu_cap_mask(enum cpuid_leafs leaf, u32 mask)
+{
+	/* Use kvm_cpu_cap_init_scattered for scattered leafs. */
+	BUILD_BUG_ON(leaf >= NCAPINTS);
+
+	kvm_cpu_caps[leaf] &= mask;
+
+	__kvm_cpu_cap_mask(leaf);
+}
+
 void kvm_set_cpu_caps(void)
 {
 	unsigned int f_nx = is_efer_nx() ? F(NX) : 0;
@@ -371,12 +409,13 @@ void kvm_set_cpu_caps(void)
 	unsigned int f_gbpages = 0;
 	unsigned int f_lm = 0;
 #endif
+	memset(kvm_cpu_caps, 0, sizeof(kvm_cpu_caps));
 
-	BUILD_BUG_ON(sizeof(kvm_cpu_caps) >
+	BUILD_BUG_ON(sizeof(kvm_cpu_caps) - (NKVMCAPINTS * sizeof(*kvm_cpu_caps)) >
 		     sizeof(boot_cpu_data.x86_capability));
 
 	memcpy(&kvm_cpu_caps, &boot_cpu_data.x86_capability,
-	       sizeof(kvm_cpu_caps));
+	       sizeof(kvm_cpu_caps) - (NKVMCAPINTS * sizeof(*kvm_cpu_caps)));
 
 	kvm_cpu_cap_mask(CPUID_1_ECX,
 		/*
@@ -407,7 +446,7 @@ void kvm_set_cpu_caps(void)
 	);
 
 	kvm_cpu_cap_mask(CPUID_7_0_EBX,
-		F(FSGSBASE) | F(BMI1) | F(HLE) | F(AVX2) | F(SMEP) |
+		F(FSGSBASE) | F(SGX) | F(BMI1) | F(HLE) | F(AVX2) | F(SMEP) |
 		F(BMI2) | F(ERMS) | F(INVPCID) | F(RTM) | 0 /*MPX*/ | F(RDSEED) |
 		F(ADX) | F(SMAP) | F(AVX512IFMA) | F(AVX512F) | F(AVX512PF) |
 		F(AVX512ER) | F(AVX512CD) | F(CLFLUSHOPT) | F(CLWB) | F(AVX512DQ) |
@@ -418,7 +457,8 @@ void kvm_set_cpu_caps(void)
 		F(AVX512VBMI) | F(LA57) | F(PKU) | 0 /*OSPKE*/ | F(RDPID) |
 		F(AVX512_VPOPCNTDQ) | F(UMIP) | F(AVX512_VBMI2) | F(GFNI) |
 		F(VAES) | F(VPCLMULQDQ) | F(AVX512_VNNI) | F(AVX512_BITALG) |
-		F(CLDEMOTE) | F(MOVDIRI) | F(MOVDIR64B) | 0 /*WAITPKG*/
+		F(CLDEMOTE) | F(MOVDIRI) | F(MOVDIR64B) | 0 /*WAITPKG*/ |
+		F(SGX_LC)
 	);
 	/* Set LA57 based on hardware capability. */
 	if (cpuid_ecx(7) & F(LA57))
@@ -457,6 +497,10 @@ void kvm_set_cpu_caps(void)
 		F(XSAVEOPT) | F(XSAVEC) | F(XGETBV1) | F(XSAVES)
 	);
 
+	kvm_cpu_cap_init_scattered(CPUID_12_EAX,
+		SF(SGX1) | SF(SGX2)
+	);
+
 	kvm_cpu_cap_mask(CPUID_8000_0001_ECX,
 		F(LAHF_LM) | F(CMP_LEGACY) | 0 /*SVM*/ | 0 /* ExtApicSpace */ |
 		F(CR8_LEGACY) | F(ABM) | F(SSE4A) | F(MISALIGNSSE) |
@@ -514,6 +558,10 @@ void kvm_set_cpu_caps(void)
 	 */
 	kvm_cpu_cap_mask(CPUID_8000_000A_EDX, 0);
 
+	kvm_cpu_cap_mask(CPUID_8000_001F_EAX,
+		0 /* SME */ | F(SEV) | 0 /* VM_PAGE_FLUSH */ | F(SEV_ES) |
+		F(SME_COHERENT));
+
 	kvm_cpu_cap_mask(CPUID_C000_0001_EDX,
 		F(XSTORE) | F(XSTORE_EN) | F(XCRYPT) | F(XCRYPT_EN) |
 		F(ACE2) | F(ACE2_EN) | F(PHE) | F(PHE_EN) |
@@ -778,6 +826,38 @@ static inline int __do_cpuid_func(struct kvm_cpuid_array *array, u32 function)
 			entry->edx = 0;
 		}
 		break;
+	case 0x12:
+		/* Intel SGX */
+		if (!kvm_cpu_cap_has(X86_FEATURE_SGX)) {
+			entry->eax = entry->ebx = entry->ecx = entry->edx = 0;
+			break;
+		}
+
+		/*
+		 * Index 0: Sub-features, MISCSELECT (a.k.a extended features)
+		 * and max enclave sizes.   The SGX sub-features and MISCSELECT
+		 * are restricted by kernel and KVM capabilities (like most
+		 * feature flags), while enclave size is unrestricted.
+		 */
+		cpuid_entry_override(entry, CPUID_12_EAX);
+		entry->ebx &= SGX_MISC_EXINFO;
+
+		entry = do_host_cpuid(array, function, 1);
+		if (!entry)
+			goto out;
+
+		/*
+		 * Index 1: SECS.ATTRIBUTES.  ATTRIBUTES are restricted a la
+		 * feature flags.  Advertise all supported flags, including
+		 * privileged attributes that require explicit opt-in from
+		 * userspace.  ATTRIBUTES.XFRM is not adjusted as userspace is
+		 * expected to derive it from supported XCR0.
+		 */
+		entry->eax &= SGX_ATTR_DEBUG | SGX_ATTR_MODE64BIT |
+			      SGX_ATTR_PROVISIONKEY | SGX_ATTR_EINITTOKENKEY |
+			      SGX_ATTR_KSS;
+		entry->ebx &= 0;
+		break;
 	/* Intel PT */
 	case 0x14:
 		if (!kvm_cpu_cap_has(X86_FEATURE_INTEL_PT)) {
@@ -869,8 +949,10 @@ static inline int __do_cpuid_func(struct kvm_cpuid_array *array, u32 function)
 		break;
 	/* Support memory encryption cpuid if host supports it */
 	case 0x8000001F:
-		if (!boot_cpu_has(X86_FEATURE_SEV))
+		if (!kvm_cpu_cap_has(X86_FEATURE_SEV))
 			entry->eax = entry->ebx = entry->ecx = entry->edx = 0;
+		else
+			cpuid_entry_override(entry, CPUID_8000_001F_EAX);
 		break;
 	/*Add support for Centaur's CPUID instruction*/
 	case 0xC0000000:
diff --git a/arch/x86/kvm/cpuid.h b/arch/x86/kvm/cpuid.h
index 2a0c5064497f..c99edfff7f82 100644
--- a/arch/x86/kvm/cpuid.h
+++ b/arch/x86/kvm/cpuid.h
@@ -3,11 +3,12 @@
 #define ARCH_X86_KVM_CPUID_H
 
 #include "x86.h"
+#include "reverse_cpuid.h"
 #include <asm/cpu.h>
 #include <asm/processor.h>
 #include <uapi/asm/kvm_para.h>
 
-extern u32 kvm_cpu_caps[NCAPINTS] __read_mostly;
+extern u32 kvm_cpu_caps[NR_KVM_CPU_CAPS] __read_mostly;
 void kvm_set_cpu_caps(void);
 
 void kvm_update_cpuid_runtime(struct kvm_vcpu *vcpu);
@@ -58,144 +59,8 @@ static inline bool page_address_valid(struct kvm_vcpu *vcpu, gpa_t gpa)
 	return kvm_vcpu_is_legal_aligned_gpa(vcpu, gpa, PAGE_SIZE);
 }
 
-struct cpuid_reg {
-	u32 function;
-	u32 index;
-	int reg;
-};
-
-static const struct cpuid_reg reverse_cpuid[] = {
-	[CPUID_1_EDX]         = {         1, 0, CPUID_EDX},
-	[CPUID_8000_0001_EDX] = {0x80000001, 0, CPUID_EDX},
-	[CPUID_8086_0001_EDX] = {0x80860001, 0, CPUID_EDX},
-	[CPUID_1_ECX]         = {         1, 0, CPUID_ECX},
-	[CPUID_C000_0001_EDX] = {0xc0000001, 0, CPUID_EDX},
-	[CPUID_8000_0001_ECX] = {0x80000001, 0, CPUID_ECX},
-	[CPUID_7_0_EBX]       = {         7, 0, CPUID_EBX},
-	[CPUID_D_1_EAX]       = {       0xd, 1, CPUID_EAX},
-	[CPUID_8000_0008_EBX] = {0x80000008, 0, CPUID_EBX},
-	[CPUID_6_EAX]         = {         6, 0, CPUID_EAX},
-	[CPUID_8000_000A_EDX] = {0x8000000a, 0, CPUID_EDX},
-	[CPUID_7_ECX]         = {         7, 0, CPUID_ECX},
-	[CPUID_8000_0007_EBX] = {0x80000007, 0, CPUID_EBX},
-	[CPUID_7_EDX]         = {         7, 0, CPUID_EDX},
-	[CPUID_7_1_EAX]       = {         7, 1, CPUID_EAX},
-};
-
-/*
- * Reverse CPUID and its derivatives can only be used for hardware-defined
- * feature words, i.e. words whose bits directly correspond to a CPUID leaf.
- * Retrieving a feature bit or masking guest CPUID from a Linux-defined word
- * is nonsensical as the bit number/mask is an arbitrary software-defined value
- * and can't be used by KVM to query/control guest capabilities.  And obviously
- * the leaf being queried must have an entry in the lookup table.
- */
-static __always_inline void reverse_cpuid_check(unsigned int x86_leaf)
-{
-	BUILD_BUG_ON(x86_leaf == CPUID_LNX_1);
-	BUILD_BUG_ON(x86_leaf == CPUID_LNX_2);
-	BUILD_BUG_ON(x86_leaf == CPUID_LNX_3);
-	BUILD_BUG_ON(x86_leaf == CPUID_LNX_4);
-	BUILD_BUG_ON(x86_leaf >= ARRAY_SIZE(reverse_cpuid));
-	BUILD_BUG_ON(reverse_cpuid[x86_leaf].function == 0);
-}
-
-/*
- * Retrieve the bit mask from an X86_FEATURE_* definition.  Features contain
- * the hardware defined bit number (stored in bits 4:0) and a software defined
- * "word" (stored in bits 31:5).  The word is used to index into arrays of
- * bit masks that hold the per-cpu feature capabilities, e.g. this_cpu_has().
- */
-static __always_inline u32 __feature_bit(int x86_feature)
-{
-	reverse_cpuid_check(x86_feature / 32);
-	return 1 << (x86_feature & 31);
-}
-
-#define feature_bit(name)  __feature_bit(X86_FEATURE_##name)
-
-static __always_inline struct cpuid_reg x86_feature_cpuid(unsigned int x86_feature)
-{
-	unsigned int x86_leaf = x86_feature / 32;
-
-	reverse_cpuid_check(x86_leaf);
-	return reverse_cpuid[x86_leaf];
-}
-
-static __always_inline u32 *__cpuid_entry_get_reg(struct kvm_cpuid_entry2 *entry,
-						  u32 reg)
-{
-	switch (reg) {
-	case CPUID_EAX:
-		return &entry->eax;
-	case CPUID_EBX:
-		return &entry->ebx;
-	case CPUID_ECX:
-		return &entry->ecx;
-	case CPUID_EDX:
-		return &entry->edx;
-	default:
-		BUILD_BUG();
-		return NULL;
-	}
-}
-
-static __always_inline u32 *cpuid_entry_get_reg(struct kvm_cpuid_entry2 *entry,
-						unsigned int x86_feature)
-{
-	const struct cpuid_reg cpuid = x86_feature_cpuid(x86_feature);
-
-	return __cpuid_entry_get_reg(entry, cpuid.reg);
-}
-
-static __always_inline u32 cpuid_entry_get(struct kvm_cpuid_entry2 *entry,
-					   unsigned int x86_feature)
-{
-	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
-
-	return *reg & __feature_bit(x86_feature);
-}
-
-static __always_inline bool cpuid_entry_has(struct kvm_cpuid_entry2 *entry,
-					    unsigned int x86_feature)
-{
-	return cpuid_entry_get(entry, x86_feature);
-}
-
-static __always_inline void cpuid_entry_clear(struct kvm_cpuid_entry2 *entry,
-					      unsigned int x86_feature)
-{
-	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
-
-	*reg &= ~__feature_bit(x86_feature);
-}
-
-static __always_inline void cpuid_entry_set(struct kvm_cpuid_entry2 *entry,
-					    unsigned int x86_feature)
-{
-	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
-
-	*reg |= __feature_bit(x86_feature);
-}
-
-static __always_inline void cpuid_entry_change(struct kvm_cpuid_entry2 *entry,
-					       unsigned int x86_feature,
-					       bool set)
-{
-	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
-
-	/*
-	 * Open coded instead of using cpuid_entry_{clear,set}() to coerce the
-	 * compiler into using CMOV instead of Jcc when possible.
-	 */
-	if (set)
-		*reg |= __feature_bit(x86_feature);
-	else
-		*reg &= ~__feature_bit(x86_feature);
-}
-
 static __always_inline void cpuid_entry_override(struct kvm_cpuid_entry2 *entry,
-						 enum cpuid_leafs leaf)
+						 unsigned int leaf)
 {
 	u32 *reg = cpuid_entry_get_reg(entry, leaf * 32);
 
@@ -248,6 +113,14 @@ static inline bool guest_cpuid_is_amd_or_hygon(struct kvm_vcpu *vcpu)
 		is_guest_vendor_hygon(best->ebx, best->ecx, best->edx));
 }
 
+static inline bool guest_cpuid_is_intel(struct kvm_vcpu *vcpu)
+{
+	struct kvm_cpuid_entry2 *best;
+
+	best = kvm_find_cpuid_entry(vcpu, 0, 0);
+	return best && is_guest_vendor_intel(best->ebx, best->ecx, best->edx);
+}
+
 static inline int guest_cpuid_family(struct kvm_vcpu *vcpu)
 {
 	struct kvm_cpuid_entry2 *best;
@@ -308,7 +181,7 @@ static inline bool cpuid_fault_enabled(struct kvm_vcpu *vcpu)
 
 static __always_inline void kvm_cpu_cap_clear(unsigned int x86_feature)
 {
-	unsigned int x86_leaf = x86_feature / 32;
+	unsigned int x86_leaf = __feature_leaf(x86_feature);
 
 	reverse_cpuid_check(x86_leaf);
 	kvm_cpu_caps[x86_leaf] &= ~__feature_bit(x86_feature);
@@ -316,7 +189,7 @@ static __always_inline void kvm_cpu_cap_clear(unsigned int x86_feature)
 
 static __always_inline void kvm_cpu_cap_set(unsigned int x86_feature)
 {
-	unsigned int x86_leaf = x86_feature / 32;
+	unsigned int x86_leaf = __feature_leaf(x86_feature);
 
 	reverse_cpuid_check(x86_leaf);
 	kvm_cpu_caps[x86_leaf] |= __feature_bit(x86_feature);
@@ -324,7 +197,7 @@ static __always_inline void kvm_cpu_cap_set(unsigned int x86_feature)
 
 static __always_inline u32 kvm_cpu_cap_get(unsigned int x86_feature)
 {
-	unsigned int x86_leaf = x86_feature / 32;
+	unsigned int x86_leaf = __feature_leaf(x86_feature);
 
 	reverse_cpuid_check(x86_leaf);
 	return kvm_cpu_caps[x86_leaf] & __feature_bit(x86_feature);
diff --git a/arch/x86/kvm/emulate.c b/arch/x86/kvm/emulate.c
index cdd2a2b6550e..77e1c89a95a7 100644
--- a/arch/x86/kvm/emulate.c
+++ b/arch/x86/kvm/emulate.c
@@ -4220,7 +4220,7 @@ static bool valid_cr(int nr)
 	}
 }
 
-static int check_cr_read(struct x86_emulate_ctxt *ctxt)
+static int check_cr_access(struct x86_emulate_ctxt *ctxt)
 {
 	if (!valid_cr(ctxt->modrm_reg))
 		return emulate_ud(ctxt);
@@ -4228,80 +4228,6 @@ static int check_cr_read(struct x86_emulate_ctxt *ctxt)
 	return X86EMUL_CONTINUE;
 }
 
-static int check_cr_write(struct x86_emulate_ctxt *ctxt)
-{
-	u64 new_val = ctxt->src.val64;
-	int cr = ctxt->modrm_reg;
-	u64 efer = 0;
-
-	static u64 cr_reserved_bits[] = {
-		0xffffffff00000000ULL,
-		0, 0, 0, /* CR3 checked later */
-		CR4_RESERVED_BITS,
-		0, 0, 0,
-		CR8_RESERVED_BITS,
-	};
-
-	if (!valid_cr(cr))
-		return emulate_ud(ctxt);
-
-	if (new_val & cr_reserved_bits[cr])
-		return emulate_gp(ctxt, 0);
-
-	switch (cr) {
-	case 0: {
-		u64 cr4;
-		if (((new_val & X86_CR0_PG) && !(new_val & X86_CR0_PE)) ||
-		    ((new_val & X86_CR0_NW) && !(new_val & X86_CR0_CD)))
-			return emulate_gp(ctxt, 0);
-
-		cr4 = ctxt->ops->get_cr(ctxt, 4);
-		ctxt->ops->get_msr(ctxt, MSR_EFER, &efer);
-
-		if ((new_val & X86_CR0_PG) && (efer & EFER_LME) &&
-		    !(cr4 & X86_CR4_PAE))
-			return emulate_gp(ctxt, 0);
-
-		break;
-		}
-	case 3: {
-		u64 rsvd = 0;
-
-		ctxt->ops->get_msr(ctxt, MSR_EFER, &efer);
-		if (efer & EFER_LMA) {
-			u64 maxphyaddr;
-			u32 eax, ebx, ecx, edx;
-
-			eax = 0x80000008;
-			ecx = 0;
-			if (ctxt->ops->get_cpuid(ctxt, &eax, &ebx, &ecx,
-						 &edx, true))
-				maxphyaddr = eax & 0xff;
-			else
-				maxphyaddr = 36;
-			rsvd = rsvd_bits(maxphyaddr, 63);
-			if (ctxt->ops->get_cr(ctxt, 4) & X86_CR4_PCIDE)
-				rsvd &= ~X86_CR3_PCID_NOFLUSH;
-		}
-
-		if (new_val & rsvd)
-			return emulate_gp(ctxt, 0);
-
-		break;
-		}
-	case 4: {
-		ctxt->ops->get_msr(ctxt, MSR_EFER, &efer);
-
-		if ((efer & EFER_LMA) && !(new_val & X86_CR4_PAE))
-			return emulate_gp(ctxt, 0);
-
-		break;
-		}
-	}
-
-	return X86EMUL_CONTINUE;
-}
-
 static int check_dr7_gd(struct x86_emulate_ctxt *ctxt)
 {
 	unsigned long dr7;
@@ -4841,10 +4767,10 @@ static const struct opcode twobyte_table[256] = {
 	D(ImplicitOps | ModRM | SrcMem | NoAccess), /* 8 * reserved NOP */
 	D(ImplicitOps | ModRM | SrcMem | NoAccess), /* NOP + 7 * reserved NOP */
 	/* 0x20 - 0x2F */
-	DIP(ModRM | DstMem | Priv | Op3264 | NoMod, cr_read, check_cr_read),
+	DIP(ModRM | DstMem | Priv | Op3264 | NoMod, cr_read, check_cr_access),
 	DIP(ModRM | DstMem | Priv | Op3264 | NoMod, dr_read, check_dr_read),
 	IIP(ModRM | SrcMem | Priv | Op3264 | NoMod, em_cr_write, cr_write,
-						check_cr_write),
+						check_cr_access),
 	IIP(ModRM | SrcMem | Priv | Op3264 | NoMod, em_dr_write, dr_write,
 						check_dr_write),
 	N, N, N, N,
diff --git a/arch/x86/kvm/kvm_cache_regs.h b/arch/x86/kvm/kvm_cache_regs.h
index 2e11da2f5621..3db5c42c9ecd 100644
--- a/arch/x86/kvm/kvm_cache_regs.h
+++ b/arch/x86/kvm/kvm_cache_regs.h
@@ -62,7 +62,12 @@ static inline void kvm_register_mark_dirty(struct kvm_vcpu *vcpu,
 	__set_bit(reg, (unsigned long *)&vcpu->arch.regs_dirty);
 }
 
-static inline unsigned long kvm_register_read(struct kvm_vcpu *vcpu, int reg)
+/*
+ * The "raw" register helpers are only for cases where the full 64 bits of a
+ * register are read/written irrespective of current vCPU mode.  In other words,
+ * odds are good you shouldn't be using the raw variants.
+ */
+static inline unsigned long kvm_register_read_raw(struct kvm_vcpu *vcpu, int reg)
 {
 	if (WARN_ON_ONCE((unsigned int)reg >= NR_VCPU_REGS))
 		return 0;
@@ -73,8 +78,8 @@ static inline unsigned long kvm_register_read(struct kvm_vcpu *vcpu, int reg)
 	return vcpu->arch.regs[reg];
 }
 
-static inline void kvm_register_write(struct kvm_vcpu *vcpu, int reg,
-				      unsigned long val)
+static inline void kvm_register_write_raw(struct kvm_vcpu *vcpu, int reg,
+					  unsigned long val)
 {
 	if (WARN_ON_ONCE((unsigned int)reg >= NR_VCPU_REGS))
 		return;
@@ -85,22 +90,22 @@ static inline void kvm_register_write(struct kvm_vcpu *vcpu, int reg,
 
 static inline unsigned long kvm_rip_read(struct kvm_vcpu *vcpu)
 {
-	return kvm_register_read(vcpu, VCPU_REGS_RIP);
+	return kvm_register_read_raw(vcpu, VCPU_REGS_RIP);
 }
 
 static inline void kvm_rip_write(struct kvm_vcpu *vcpu, unsigned long val)
 {
-	kvm_register_write(vcpu, VCPU_REGS_RIP, val);
+	kvm_register_write_raw(vcpu, VCPU_REGS_RIP, val);
 }
 
 static inline unsigned long kvm_rsp_read(struct kvm_vcpu *vcpu)
 {
-	return kvm_register_read(vcpu, VCPU_REGS_RSP);
+	return kvm_register_read_raw(vcpu, VCPU_REGS_RSP);
 }
 
 static inline void kvm_rsp_write(struct kvm_vcpu *vcpu, unsigned long val)
 {
-	kvm_register_write(vcpu, VCPU_REGS_RSP, val);
+	kvm_register_write_raw(vcpu, VCPU_REGS_RSP, val);
 }
 
 static inline u64 kvm_pdptr_read(struct kvm_vcpu *vcpu, int index)
diff --git a/arch/x86/kvm/lapic.c b/arch/x86/kvm/lapic.c
index cc369b9ad8f1..152591f9243a 100644
--- a/arch/x86/kvm/lapic.c
+++ b/arch/x86/kvm/lapic.c
@@ -296,6 +296,10 @@ static inline void apic_set_spiv(struct kvm_lapic *apic, u32 val)
 
 		atomic_set_release(&apic->vcpu->kvm->arch.apic_map_dirty, DIRTY);
 	}
+
+	/* Check if there are APF page ready requests pending */
+	if (enabled)
+		kvm_make_request(KVM_REQ_APF_READY, apic->vcpu);
 }
 
 static inline void kvm_apic_set_xapic_id(struct kvm_lapic *apic, u8 id)
@@ -2261,6 +2265,8 @@ void kvm_lapic_set_base(struct kvm_vcpu *vcpu, u64 value)
 		if (value & MSR_IA32_APICBASE_ENABLE) {
 			kvm_apic_set_xapic_id(apic, vcpu->vcpu_id);
 			static_branch_slow_dec_deferred(&apic_hw_disabled);
+			/* Check if there are APF page ready requests pending */
+			kvm_make_request(KVM_REQ_APF_READY, vcpu);
 		} else {
 			static_branch_inc(&apic_hw_disabled.key);
 			atomic_set_release(&apic->vcpu->kvm->arch.apic_map_dirty, DIRTY);
@@ -2869,7 +2875,7 @@ void kvm_apic_accept_events(struct kvm_vcpu *vcpu)
 		return;
 
 	if (is_guest_mode(vcpu)) {
-		r = kvm_x86_ops.nested_ops->check_events(vcpu);
+		r = kvm_check_nested_events(vcpu);
 		if (r < 0)
 			return;
 		/*
diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index c68bfc3e2402..88d0ed5225a4 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -59,7 +59,8 @@ static __always_inline u64 rsvd_bits(int s, int e)
 	return ((2ULL << (e - s)) - 1) << s;
 }
 
-void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 access_mask);
+void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 mmio_mask, u64 access_mask);
+void kvm_mmu_set_ept_masks(bool has_ad_bits, bool has_exec_only);
 
 void
 reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu, struct kvm_mmu *context);
@@ -73,6 +74,10 @@ bool kvm_can_do_async_pf(struct kvm_vcpu *vcpu);
 int kvm_handle_page_fault(struct kvm_vcpu *vcpu, u64 error_code,
 				u64 fault_address, char *insn, int insn_len);
 
+int kvm_mmu_load(struct kvm_vcpu *vcpu);
+void kvm_mmu_unload(struct kvm_vcpu *vcpu);
+void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu);
+
 static inline int kvm_mmu_reload(struct kvm_vcpu *vcpu)
 {
 	if (likely(vcpu->arch.mmu->root_hpa != INVALID_PAGE))
@@ -102,8 +107,8 @@ static inline void kvm_mmu_load_pgd(struct kvm_vcpu *vcpu)
 	if (!VALID_PAGE(root_hpa))
 		return;
 
-	static_call(kvm_x86_load_mmu_pgd)(vcpu, root_hpa | kvm_get_active_pcid(vcpu),
-				 vcpu->arch.mmu->shadow_root_level);
+	static_call(kvm_x86_load_mmu_pgd)(vcpu, root_hpa,
+					  vcpu->arch.mmu->shadow_root_level);
 }
 
 int kvm_tdp_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
@@ -124,7 +129,7 @@ static inline int kvm_mmu_do_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
  * write-protects guest page to sync the guest modification, b) another one is
  * used to sync dirty bitmap when we do KVM_GET_DIRTY_LOG. The differences
  * between these two sorts are:
- * 1) the first case clears SPTE_MMU_WRITEABLE bit.
+ * 1) the first case clears MMU-writable bit.
  * 2) the first case requires flushing tlb immediately avoiding corrupting
  *    shadow page table between all vcpus so it should be in the protection of
  *    mmu-lock. And the another case does not need to flush tlb until returning
@@ -135,17 +140,17 @@ static inline int kvm_mmu_do_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
  * So, there is the problem: the first case can meet the corrupted tlb caused
  * by another case which write-protects pages but without flush tlb
  * immediately. In order to making the first case be aware this problem we let
- * it flush tlb if we try to write-protect a spte whose SPTE_MMU_WRITEABLE bit
- * is set, it works since another case never touches SPTE_MMU_WRITEABLE bit.
+ * it flush tlb if we try to write-protect a spte whose MMU-writable bit
+ * is set, it works since another case never touches MMU-writable bit.
  *
  * Anyway, whenever a spte is updated (only permission and status bits are
- * changed) we need to check whether the spte with SPTE_MMU_WRITEABLE becomes
+ * changed) we need to check whether the spte with MMU-writable becomes
  * readonly, if that happens, we need to flush tlb. Fortunately,
  * mmu_spte_update() has already handled it perfectly.
  *
- * The rules to use SPTE_MMU_WRITEABLE and PT_WRITABLE_MASK:
+ * The rules to use MMU-writable and PT_WRITABLE_MASK:
  * - if we want to see if it has writable tlb entry or if the spte can be
- *   writable on the mmu mapping, check SPTE_MMU_WRITEABLE, this is the most
+ *   writable on the mmu mapping, check MMU-writable, this is the most
  *   case, otherwise
  * - if we fix page fault on the spte or do write-protection by dirty logging,
  *   check PT_WRITABLE_MASK.
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 62b1729277ef..4b3ee244ebe0 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -48,6 +48,7 @@
 #include <asm/memtype.h>
 #include <asm/cmpxchg.h>
 #include <asm/io.h>
+#include <asm/set_memory.h>
 #include <asm/vmx.h>
 #include <asm/kvm_page_track.h>
 #include "trace.h"
@@ -215,10 +216,10 @@ bool is_nx_huge_page_enabled(void)
 static void mark_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, u64 gfn,
 			   unsigned int access)
 {
-	u64 mask = make_mmio_spte(vcpu, gfn, access);
+	u64 spte = make_mmio_spte(vcpu, gfn, access);
 
-	trace_mark_mmio_spte(sptep, gfn, mask);
-	mmu_spte_set(sptep, mask);
+	trace_mark_mmio_spte(sptep, gfn, spte);
+	mmu_spte_set(sptep, spte);
 }
 
 static gfn_t get_mmio_spte_gfn(u64 spte)
@@ -236,17 +237,6 @@ static unsigned get_mmio_spte_access(u64 spte)
 	return spte & shadow_mmio_access_mask;
 }
 
-static bool set_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
-			  kvm_pfn_t pfn, unsigned int access)
-{
-	if (unlikely(is_noslot_pfn(pfn))) {
-		mark_mmio_spte(vcpu, sptep, gfn, access);
-		return true;
-	}
-
-	return false;
-}
-
 static bool check_mmio_spte(struct kvm_vcpu *vcpu, u64 spte)
 {
 	u64 kvm_gen, spte_gen, gen;
@@ -725,8 +715,7 @@ static void kvm_mmu_page_set_gfn(struct kvm_mmu_page *sp, int index, gfn_t gfn)
  * handling slots that are not large page aligned.
  */
 static struct kvm_lpage_info *lpage_info_slot(gfn_t gfn,
-					      struct kvm_memory_slot *slot,
-					      int level)
+		const struct kvm_memory_slot *slot, int level)
 {
 	unsigned long idx;
 
@@ -1118,7 +1107,7 @@ static bool spte_write_protect(u64 *sptep, bool pt_protect)
 	rmap_printk("spte %p %llx\n", sptep, *sptep);
 
 	if (pt_protect)
-		spte &= ~SPTE_MMU_WRITEABLE;
+		spte &= ~shadow_mmu_writable_mask;
 	spte = spte & ~PT_WRITABLE_MASK;
 
 	return mmu_spte_update(sptep, spte);
@@ -1308,26 +1297,25 @@ static bool kvm_zap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
 	return flush;
 }
 
-static int kvm_unmap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			   struct kvm_memory_slot *slot, gfn_t gfn, int level,
-			   unsigned long data)
+static bool kvm_unmap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			    struct kvm_memory_slot *slot, gfn_t gfn, int level,
+			    pte_t unused)
 {
 	return kvm_zap_rmapp(kvm, rmap_head, slot);
 }
 
-static int kvm_set_pte_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			     struct kvm_memory_slot *slot, gfn_t gfn, int level,
-			     unsigned long data)
+static bool kvm_set_pte_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			      struct kvm_memory_slot *slot, gfn_t gfn, int level,
+			      pte_t pte)
 {
 	u64 *sptep;
 	struct rmap_iterator iter;
 	int need_flush = 0;
 	u64 new_spte;
-	pte_t *ptep = (pte_t *)data;
 	kvm_pfn_t new_pfn;
 
-	WARN_ON(pte_huge(*ptep));
-	new_pfn = pte_pfn(*ptep);
+	WARN_ON(pte_huge(pte));
+	new_pfn = pte_pfn(pte);
 
 restart:
 	for_each_rmap_spte(rmap_head, &iter, sptep) {
@@ -1336,7 +1324,7 @@ restart:
 
 		need_flush = 1;
 
-		if (pte_write(*ptep)) {
+		if (pte_write(pte)) {
 			pte_list_remove(rmap_head, sptep);
 			goto restart;
 		} else {
@@ -1424,93 +1412,52 @@ static void slot_rmap_walk_next(struct slot_rmap_walk_iterator *iterator)
 	     slot_rmap_walk_okay(_iter_);				\
 	     slot_rmap_walk_next(_iter_))
 
-static __always_inline int
-kvm_handle_hva_range(struct kvm *kvm,
-		     unsigned long start,
-		     unsigned long end,
-		     unsigned long data,
-		     int (*handler)(struct kvm *kvm,
-				    struct kvm_rmap_head *rmap_head,
-				    struct kvm_memory_slot *slot,
-				    gfn_t gfn,
-				    int level,
-				    unsigned long data))
+typedef bool (*rmap_handler_t)(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			       struct kvm_memory_slot *slot, gfn_t gfn,
+			       int level, pte_t pte);
+
+static __always_inline bool kvm_handle_gfn_range(struct kvm *kvm,
+						 struct kvm_gfn_range *range,
+						 rmap_handler_t handler)
 {
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
 	struct slot_rmap_walk_iterator iterator;
-	int ret = 0;
-	int i;
-
-	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
-		slots = __kvm_memslots(kvm, i);
-		kvm_for_each_memslot(memslot, slots) {
-			unsigned long hva_start, hva_end;
-			gfn_t gfn_start, gfn_end;
+	bool ret = false;
 
-			hva_start = max(start, memslot->userspace_addr);
-			hva_end = min(end, memslot->userspace_addr +
-				      (memslot->npages << PAGE_SHIFT));
-			if (hva_start >= hva_end)
-				continue;
-			/*
-			 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-			 * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-			 */
-			gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-			gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
-
-			for_each_slot_rmap_range(memslot, PG_LEVEL_4K,
-						 KVM_MAX_HUGEPAGE_LEVEL,
-						 gfn_start, gfn_end - 1,
-						 &iterator)
-				ret |= handler(kvm, iterator.rmap, memslot,
-					       iterator.gfn, iterator.level, data);
-		}
-	}
+	for_each_slot_rmap_range(range->slot, PG_LEVEL_4K, KVM_MAX_HUGEPAGE_LEVEL,
+				 range->start, range->end - 1, &iterator)
+		ret |= handler(kvm, iterator.rmap, range->slot, iterator.gfn,
+			       iterator.level, range->pte);
 
 	return ret;
 }
 
-static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
-			  unsigned long data,
-			  int (*handler)(struct kvm *kvm,
-					 struct kvm_rmap_head *rmap_head,
-					 struct kvm_memory_slot *slot,
-					 gfn_t gfn, int level,
-					 unsigned long data))
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_handle_hva_range(kvm, hva, hva + 1, data, handler);
-}
-
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-			unsigned flags)
-{
-	int r;
+	bool flush;
 
-	r = kvm_handle_hva_range(kvm, start, end, 0, kvm_unmap_rmapp);
+	flush = kvm_handle_gfn_range(kvm, range, kvm_unmap_rmapp);
 
 	if (is_tdp_mmu_enabled(kvm))
-		r |= kvm_tdp_mmu_zap_hva_range(kvm, start, end);
+		flush |= kvm_tdp_mmu_unmap_gfn_range(kvm, range, flush);
 
-	return r;
+	return flush;
 }
 
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	int r;
+	bool flush;
 
-	r = kvm_handle_hva(kvm, hva, (unsigned long)&pte, kvm_set_pte_rmapp);
+	flush = kvm_handle_gfn_range(kvm, range, kvm_set_pte_rmapp);
 
 	if (is_tdp_mmu_enabled(kvm))
-		r |= kvm_tdp_mmu_set_spte_hva(kvm, hva, &pte);
+		flush |= kvm_tdp_mmu_set_spte_gfn(kvm, range);
 
-	return r;
+	return flush;
 }
 
-static int kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			 struct kvm_memory_slot *slot, gfn_t gfn, int level,
-			 unsigned long data)
+static bool kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			  struct kvm_memory_slot *slot, gfn_t gfn, int level,
+			  pte_t unused)
 {
 	u64 *sptep;
 	struct rmap_iterator iter;
@@ -1519,13 +1466,12 @@ static int kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
 	for_each_rmap_spte(rmap_head, &iter, sptep)
 		young |= mmu_spte_age(sptep);
 
-	trace_kvm_age_page(gfn, level, slot, young);
 	return young;
 }
 
-static int kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			      struct kvm_memory_slot *slot, gfn_t gfn,
-			      int level, unsigned long data)
+static bool kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			       struct kvm_memory_slot *slot, gfn_t gfn,
+			       int level, pte_t unused)
 {
 	u64 *sptep;
 	struct rmap_iterator iter;
@@ -1547,29 +1493,31 @@ static void rmap_recycle(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 
 	rmap_head = gfn_to_rmap(vcpu->kvm, gfn, sp);
 
-	kvm_unmap_rmapp(vcpu->kvm, rmap_head, NULL, gfn, sp->role.level, 0);
+	kvm_unmap_rmapp(vcpu->kvm, rmap_head, NULL, gfn, sp->role.level, __pte(0));
 	kvm_flush_remote_tlbs_with_address(vcpu->kvm, sp->gfn,
 			KVM_PAGES_PER_HPAGE(sp->role.level));
 }
 
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	int young = false;
+	bool young;
+
+	young = kvm_handle_gfn_range(kvm, range, kvm_age_rmapp);
 
-	young = kvm_handle_hva_range(kvm, start, end, 0, kvm_age_rmapp);
 	if (is_tdp_mmu_enabled(kvm))
-		young |= kvm_tdp_mmu_age_hva_range(kvm, start, end);
+		young |= kvm_tdp_mmu_age_gfn_range(kvm, range);
 
 	return young;
 }
 
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	int young = false;
+	bool young;
+
+	young = kvm_handle_gfn_range(kvm, range, kvm_test_age_rmapp);
 
-	young = kvm_handle_hva(kvm, hva, 0, kvm_test_age_rmapp);
 	if (is_tdp_mmu_enabled(kvm))
-		young |= kvm_tdp_mmu_test_age_hva(kvm, hva);
+		young |= kvm_tdp_mmu_test_age_gfn(kvm, range);
 
 	return young;
 }
@@ -2421,6 +2369,15 @@ static int make_mmu_pages_available(struct kvm_vcpu *vcpu)
 
 	kvm_mmu_zap_oldest_mmu_pages(vcpu->kvm, KVM_REFILL_PAGES - avail);
 
+	/*
+	 * Note, this check is intentionally soft, it only guarantees that one
+	 * page is available, while the caller may end up allocating as many as
+	 * four pages, e.g. for PAE roots or for 5-level paging.  Temporarily
+	 * exceeding the (arbitrary by default) limit will not harm the host,
+	 * being too aggressive may unnecessarily kill the guest, and getting an
+	 * exact count is far more trouble than it's worth, especially in the
+	 * page fault paths.
+	 */
 	if (!kvm_mmu_available_pages(vcpu->kvm))
 		return -ENOSPC;
 	return 0;
@@ -2561,9 +2518,6 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 	struct kvm_mmu_page *sp;
 	int ret;
 
-	if (set_mmio_spte(vcpu, sptep, gfn, pfn, pte_access))
-		return 0;
-
 	sp = sptep_to_sp(sptep);
 
 	ret = make_spte(vcpu, pte_access, level, gfn, pfn, *sptep, speculative,
@@ -2593,6 +2547,11 @@ static int mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 	pgprintk("%s: spte %llx write_fault %d gfn %llx\n", __func__,
 		 *sptep, write_fault, gfn);
 
+	if (unlikely(is_noslot_pfn(pfn))) {
+		mark_mmio_spte(vcpu, sptep, gfn, pte_access);
+		return RET_PF_EMULATE;
+	}
+
 	if (is_shadow_present_pte(*sptep)) {
 		/*
 		 * If we overwrite a PTE page pointer with a 2MB PMD, unlink
@@ -2626,9 +2585,6 @@ static int mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 		kvm_flush_remote_tlbs_with_address(vcpu->kvm, gfn,
 				KVM_PAGES_PER_HPAGE(level));
 
-	if (unlikely(is_mmio_spte(*sptep)))
-		ret = RET_PF_EMULATE;
-
 	/*
 	 * The fault is fully spurious if and only if the new SPTE and old SPTE
 	 * are identical, and emulation is not required.
@@ -2745,7 +2701,7 @@ static void direct_pte_prefetch(struct kvm_vcpu *vcpu, u64 *sptep)
 }
 
 static int host_pfn_mapping_level(struct kvm *kvm, gfn_t gfn, kvm_pfn_t pfn,
-				  struct kvm_memory_slot *slot)
+				  const struct kvm_memory_slot *slot)
 {
 	unsigned long hva;
 	pte_t *pte;
@@ -2771,8 +2727,9 @@ static int host_pfn_mapping_level(struct kvm *kvm, gfn_t gfn, kvm_pfn_t pfn,
 	return level;
 }
 
-int kvm_mmu_max_mapping_level(struct kvm *kvm, struct kvm_memory_slot *slot,
-			      gfn_t gfn, kvm_pfn_t pfn, int max_level)
+int kvm_mmu_max_mapping_level(struct kvm *kvm,
+			      const struct kvm_memory_slot *slot, gfn_t gfn,
+			      kvm_pfn_t pfn, int max_level)
 {
 	struct kvm_lpage_info *linfo;
 
@@ -2946,9 +2903,19 @@ static bool handle_abnormal_pfn(struct kvm_vcpu *vcpu, gva_t gva, gfn_t gfn,
 		return true;
 	}
 
-	if (unlikely(is_noslot_pfn(pfn)))
+	if (unlikely(is_noslot_pfn(pfn))) {
 		vcpu_cache_mmio_info(vcpu, gva, gfn,
 				     access & shadow_mmio_access_mask);
+		/*
+		 * If MMIO caching is disabled, emulate immediately without
+		 * touching the shadow page tables as attempting to install an
+		 * MMIO SPTE will just be an expensive nop.
+		 */
+		if (unlikely(!shadow_mmio_value)) {
+			*ret_val = RET_PF_EMULATE;
+			return true;
+		}
+	}
 
 	return false;
 }
@@ -3061,6 +3028,9 @@ static int fast_page_fault(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa,
 			if (!is_shadow_present_pte(spte))
 				break;
 
+		if (!is_shadow_present_pte(spte))
+			break;
+
 		sp = sptep_to_sp(iterator.sptep);
 		if (!is_last_spte(spte, sp->role.level))
 			break;
@@ -3150,12 +3120,10 @@ static void mmu_free_root_page(struct kvm *kvm, hpa_t *root_hpa,
 
 	sp = to_shadow_page(*root_hpa & PT64_BASE_ADDR_MASK);
 
-	if (kvm_mmu_put_root(kvm, sp)) {
-		if (is_tdp_mmu_page(sp))
-			kvm_tdp_mmu_free_root(kvm, sp);
-		else if (sp->role.invalid)
-			kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
-	}
+	if (is_tdp_mmu_page(sp))
+		kvm_tdp_mmu_put_root(kvm, sp, false);
+	else if (!--sp->root_count && sp->role.invalid)
+		kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
 
 	*root_hpa = INVALID_PAGE;
 }
@@ -3193,14 +3161,17 @@ void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 		if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
 		    (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
 			mmu_free_root_page(kvm, &mmu->root_hpa, &invalid_list);
-		} else {
-			for (i = 0; i < 4; ++i)
-				if (mmu->pae_root[i] != 0)
-					mmu_free_root_page(kvm,
-							   &mmu->pae_root[i],
-							   &invalid_list);
-			mmu->root_hpa = INVALID_PAGE;
+		} else if (mmu->pae_root) {
+			for (i = 0; i < 4; ++i) {
+				if (!IS_VALID_PAE_ROOT(mmu->pae_root[i]))
+					continue;
+
+				mmu_free_root_page(kvm, &mmu->pae_root[i],
+						   &invalid_list);
+				mmu->pae_root[i] = INVALID_PAE_ROOT;
+			}
 		}
+		mmu->root_hpa = INVALID_PAGE;
 		mmu->root_pgd = 0;
 	}
 
@@ -3226,155 +3197,208 @@ static hpa_t mmu_alloc_root(struct kvm_vcpu *vcpu, gfn_t gfn, gva_t gva,
 {
 	struct kvm_mmu_page *sp;
 
-	write_lock(&vcpu->kvm->mmu_lock);
-
-	if (make_mmu_pages_available(vcpu)) {
-		write_unlock(&vcpu->kvm->mmu_lock);
-		return INVALID_PAGE;
-	}
 	sp = kvm_mmu_get_page(vcpu, gfn, gva, level, direct, ACC_ALL);
 	++sp->root_count;
 
-	write_unlock(&vcpu->kvm->mmu_lock);
 	return __pa(sp->spt);
 }
 
 static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
 {
-	u8 shadow_root_level = vcpu->arch.mmu->shadow_root_level;
+	struct kvm_mmu *mmu = vcpu->arch.mmu;
+	u8 shadow_root_level = mmu->shadow_root_level;
 	hpa_t root;
 	unsigned i;
+	int r;
+
+	write_lock(&vcpu->kvm->mmu_lock);
+	r = make_mmu_pages_available(vcpu);
+	if (r < 0)
+		goto out_unlock;
 
 	if (is_tdp_mmu_enabled(vcpu->kvm)) {
 		root = kvm_tdp_mmu_get_vcpu_root_hpa(vcpu);
-
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->root_hpa = root;
+		mmu->root_hpa = root;
 	} else if (shadow_root_level >= PT64_ROOT_4LEVEL) {
-		root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level,
-				      true);
-
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->root_hpa = root;
+		root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level, true);
+		mmu->root_hpa = root;
 	} else if (shadow_root_level == PT32E_ROOT_LEVEL) {
+		if (WARN_ON_ONCE(!mmu->pae_root)) {
+			r = -EIO;
+			goto out_unlock;
+		}
+
 		for (i = 0; i < 4; ++i) {
-			MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
+			WARN_ON_ONCE(IS_VALID_PAE_ROOT(mmu->pae_root[i]));
 
 			root = mmu_alloc_root(vcpu, i << (30 - PAGE_SHIFT),
 					      i << 30, PT32_ROOT_LEVEL, true);
-			if (!VALID_PAGE(root))
-				return -ENOSPC;
-			vcpu->arch.mmu->pae_root[i] = root | PT_PRESENT_MASK;
+			mmu->pae_root[i] = root | PT_PRESENT_MASK |
+					   shadow_me_mask;
 		}
-		vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
-	} else
-		BUG();
+		mmu->root_hpa = __pa(mmu->pae_root);
+	} else {
+		WARN_ONCE(1, "Bad TDP root level = %d\n", shadow_root_level);
+		r = -EIO;
+		goto out_unlock;
+	}
 
 	/* root_pgd is ignored for direct MMUs. */
-	vcpu->arch.mmu->root_pgd = 0;
-
-	return 0;
+	mmu->root_pgd = 0;
+out_unlock:
+	write_unlock(&vcpu->kvm->mmu_lock);
+	return r;
 }
 
 static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
 {
-	u64 pdptr, pm_mask;
+	struct kvm_mmu *mmu = vcpu->arch.mmu;
+	u64 pdptrs[4], pm_mask;
 	gfn_t root_gfn, root_pgd;
 	hpa_t root;
-	int i;
+	unsigned i;
+	int r;
 
-	root_pgd = vcpu->arch.mmu->get_guest_pgd(vcpu);
+	root_pgd = mmu->get_guest_pgd(vcpu);
 	root_gfn = root_pgd >> PAGE_SHIFT;
 
 	if (mmu_check_root(vcpu, root_gfn))
 		return 1;
 
 	/*
+	 * On SVM, reading PDPTRs might access guest memory, which might fault
+	 * and thus might sleep.  Grab the PDPTRs before acquiring mmu_lock.
+	 */
+	if (mmu->root_level == PT32E_ROOT_LEVEL) {
+		for (i = 0; i < 4; ++i) {
+			pdptrs[i] = mmu->get_pdptr(vcpu, i);
+			if (!(pdptrs[i] & PT_PRESENT_MASK))
+				continue;
+
+			if (mmu_check_root(vcpu, pdptrs[i] >> PAGE_SHIFT))
+				return 1;
+		}
+	}
+
+	write_lock(&vcpu->kvm->mmu_lock);
+	r = make_mmu_pages_available(vcpu);
+	if (r < 0)
+		goto out_unlock;
+
+	/*
 	 * Do we shadow a long mode page table? If so we need to
 	 * write-protect the guests page table root.
 	 */
-	if (vcpu->arch.mmu->root_level >= PT64_ROOT_4LEVEL) {
-		MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->root_hpa));
-
+	if (mmu->root_level >= PT64_ROOT_4LEVEL) {
 		root = mmu_alloc_root(vcpu, root_gfn, 0,
-				      vcpu->arch.mmu->shadow_root_level, false);
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->root_hpa = root;
+				      mmu->shadow_root_level, false);
+		mmu->root_hpa = root;
 		goto set_root_pgd;
 	}
 
+	if (WARN_ON_ONCE(!mmu->pae_root)) {
+		r = -EIO;
+		goto out_unlock;
+	}
+
 	/*
 	 * We shadow a 32 bit page table. This may be a legacy 2-level
 	 * or a PAE 3-level page table. In either case we need to be aware that
 	 * the shadow page table may be a PAE or a long mode page table.
 	 */
-	pm_mask = PT_PRESENT_MASK;
-	if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL)
+	pm_mask = PT_PRESENT_MASK | shadow_me_mask;
+	if (mmu->shadow_root_level == PT64_ROOT_4LEVEL) {
 		pm_mask |= PT_ACCESSED_MASK | PT_WRITABLE_MASK | PT_USER_MASK;
 
+		if (WARN_ON_ONCE(!mmu->lm_root)) {
+			r = -EIO;
+			goto out_unlock;
+		}
+
+		mmu->lm_root[0] = __pa(mmu->pae_root) | pm_mask;
+	}
+
 	for (i = 0; i < 4; ++i) {
-		MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
-		if (vcpu->arch.mmu->root_level == PT32E_ROOT_LEVEL) {
-			pdptr = vcpu->arch.mmu->get_pdptr(vcpu, i);
-			if (!(pdptr & PT_PRESENT_MASK)) {
-				vcpu->arch.mmu->pae_root[i] = 0;
+		WARN_ON_ONCE(IS_VALID_PAE_ROOT(mmu->pae_root[i]));
+
+		if (mmu->root_level == PT32E_ROOT_LEVEL) {
+			if (!(pdptrs[i] & PT_PRESENT_MASK)) {
+				mmu->pae_root[i] = INVALID_PAE_ROOT;
 				continue;
 			}
-			root_gfn = pdptr >> PAGE_SHIFT;
-			if (mmu_check_root(vcpu, root_gfn))
-				return 1;
+			root_gfn = pdptrs[i] >> PAGE_SHIFT;
 		}
 
 		root = mmu_alloc_root(vcpu, root_gfn, i << 30,
 				      PT32_ROOT_LEVEL, false);
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->pae_root[i] = root | pm_mask;
+		mmu->pae_root[i] = root | pm_mask;
 	}
-	vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
+
+	if (mmu->shadow_root_level == PT64_ROOT_4LEVEL)
+		mmu->root_hpa = __pa(mmu->lm_root);
+	else
+		mmu->root_hpa = __pa(mmu->pae_root);
+
+set_root_pgd:
+	mmu->root_pgd = root_pgd;
+out_unlock:
+	write_unlock(&vcpu->kvm->mmu_lock);
+
+	return r;
+}
+
+static int mmu_alloc_special_roots(struct kvm_vcpu *vcpu)
+{
+	struct kvm_mmu *mmu = vcpu->arch.mmu;
+	u64 *lm_root, *pae_root;
 
 	/*
-	 * If we shadow a 32 bit page table with a long mode page
-	 * table we enter this path.
+	 * When shadowing 32-bit or PAE NPT with 64-bit NPT, the PML4 and PDP
+	 * tables are allocated and initialized at root creation as there is no
+	 * equivalent level in the guest's NPT to shadow.  Allocate the tables
+	 * on demand, as running a 32-bit L1 VMM on 64-bit KVM is very rare.
 	 */
-	if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL) {
-		if (vcpu->arch.mmu->lm_root == NULL) {
-			/*
-			 * The additional page necessary for this is only
-			 * allocated on demand.
-			 */
+	if (mmu->direct_map || mmu->root_level >= PT64_ROOT_4LEVEL ||
+	    mmu->shadow_root_level < PT64_ROOT_4LEVEL)
+		return 0;
 
-			u64 *lm_root;
+	/*
+	 * This mess only works with 4-level paging and needs to be updated to
+	 * work with 5-level paging.
+	 */
+	if (WARN_ON_ONCE(mmu->shadow_root_level != PT64_ROOT_4LEVEL))
+		return -EIO;
 
-			lm_root = (void*)get_zeroed_page(GFP_KERNEL_ACCOUNT);
-			if (lm_root == NULL)
-				return 1;
+	if (mmu->pae_root && mmu->lm_root)
+		return 0;
 
-			lm_root[0] = __pa(vcpu->arch.mmu->pae_root) | pm_mask;
+	/*
+	 * The special roots should always be allocated in concert.  Yell and
+	 * bail if KVM ends up in a state where only one of the roots is valid.
+	 */
+	if (WARN_ON_ONCE(!tdp_enabled || mmu->pae_root || mmu->lm_root))
+		return -EIO;
 
-			vcpu->arch.mmu->lm_root = lm_root;
-		}
+	/*
+	 * Unlike 32-bit NPT, the PDP table doesn't need to be in low mem, and
+	 * doesn't need to be decrypted.
+	 */
+	pae_root = (void *)get_zeroed_page(GFP_KERNEL_ACCOUNT);
+	if (!pae_root)
+		return -ENOMEM;
 
-		vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->lm_root);
+	lm_root = (void *)get_zeroed_page(GFP_KERNEL_ACCOUNT);
+	if (!lm_root) {
+		free_page((unsigned long)pae_root);
+		return -ENOMEM;
 	}
 
-set_root_pgd:
-	vcpu->arch.mmu->root_pgd = root_pgd;
+	mmu->pae_root = pae_root;
+	mmu->lm_root = lm_root;
 
 	return 0;
 }
 
-static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
-{
-	if (vcpu->arch.mmu->direct_map)
-		return mmu_alloc_direct_roots(vcpu);
-	else
-		return mmu_alloc_shadow_roots(vcpu);
-}
-
 void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
 {
 	int i;
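The hunk above snapshots the guest PDPTRs into a local array before taking mmu_lock precisely because mmu->get_pdptr() may touch guest memory and sleep. The general shape of that "read sleepable state first, then take the non-sleepable lock" pattern, as a sketch with hypothetical helpers read_pdptr_may_sleep() and install_pae_root() standing in for the real callbacks:

	u64 pdptrs[4];
	unsigned i;

	/* May fault on guest memory and sleep; must run before mmu_lock. */
	for (i = 0; i < 4; i++)
		pdptrs[i] = read_pdptr_may_sleep(vcpu, i);

	write_lock(&vcpu->kvm->mmu_lock);	/* no sleeping from here on */
	for (i = 0; i < 4; i++)
		install_pae_root(vcpu, i, pdptrs[i]);
	write_unlock(&vcpu->kvm->mmu_lock);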
@@ -3422,7 +3446,7 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
 	for (i = 0; i < 4; ++i) {
 		hpa_t root = vcpu->arch.mmu->pae_root[i];
 
-		if (root && VALID_PAGE(root)) {
+		if (IS_VALID_PAE_ROOT(root)) {
 			root &= PT64_BASE_ADDR_MASK;
 			sp = to_shadow_page(root);
 			mmu_sync_children(vcpu, sp);
@@ -3554,11 +3578,12 @@ static bool get_mmio_spte(struct kvm_vcpu *vcpu, u64 addr, u64 *sptep)
 			    __is_rsvd_bits_set(rsvd_check, sptes[level], level);
 
 	if (reserved) {
-		pr_err("%s: detect reserved bits on spte, addr 0x%llx, dump hierarchy:\n",
+		pr_err("%s: reserved bits set on MMU-present spte, addr 0x%llx, hierarchy:\n",
 		       __func__, addr);
 		for (level = root; level >= leaf; level--)
-			pr_err("------ spte 0x%llx level %d.\n",
-			       sptes[level], level);
+			pr_err("------ spte = 0x%llx level = %d, rsvd bits = 0x%llx\n",
+			       sptes[level], level,
+			       rsvd_check->rsvd_bits_mask[(sptes[level] >> 7) & 1][level-1]);
 	}
 
 	return reserved;
@@ -3653,6 +3678,14 @@ static bool try_async_pf(struct kvm_vcpu *vcpu, bool prefault, gfn_t gfn,
 	struct kvm_memory_slot *slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
 	bool async;
 
+	/*
+	 * Retry the page fault if the gfn hit a memslot that is being deleted
+	 * or moved.  This ensures any existing SPTEs for the old memslot will
+	 * be zapped before KVM inserts a new MMIO SPTE for the gfn.
+	 */
+	if (slot && (slot->flags & KVM_MEMSLOT_INVALID))
+		return true;
+
 	/* Don't expose private memslots to L2. */
 	if (is_guest_mode(vcpu) && !kvm_is_visible_memslot(slot)) {
 		*pfn = KVM_PFN_NOSLOT;
@@ -4615,12 +4648,17 @@ void kvm_init_shadow_npt_mmu(struct kvm_vcpu *vcpu, u32 cr0, u32 cr4, u32 efer,
 	struct kvm_mmu *context = &vcpu->arch.guest_mmu;
 	union kvm_mmu_role new_role = kvm_calc_shadow_npt_root_page_role(vcpu);
 
-	context->shadow_root_level = new_role.base.level;
-
 	__kvm_mmu_new_pgd(vcpu, nested_cr3, new_role.base, false, false);
 
-	if (new_role.as_u64 != context->mmu_role.as_u64)
+	if (new_role.as_u64 != context->mmu_role.as_u64) {
 		shadow_mmu_init_context(vcpu, context, cr0, cr4, efer, new_role);
+
+		/*
+		 * Override the level set by the common init helper, nested TDP
+		 * always uses the host's TDP configuration.
+		 */
+		context->shadow_root_level = new_role.base.level;
+	}
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_npt_mmu);
 
@@ -4802,16 +4840,23 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
 	r = mmu_topup_memory_caches(vcpu, !vcpu->arch.mmu->direct_map);
 	if (r)
 		goto out;
-	r = mmu_alloc_roots(vcpu);
-	kvm_mmu_sync_roots(vcpu);
+	r = mmu_alloc_special_roots(vcpu);
+	if (r)
+		goto out;
+	if (vcpu->arch.mmu->direct_map)
+		r = mmu_alloc_direct_roots(vcpu);
+	else
+		r = mmu_alloc_shadow_roots(vcpu);
 	if (r)
 		goto out;
+
+	kvm_mmu_sync_roots(vcpu);
+
 	kvm_mmu_load_pgd(vcpu);
 	static_call(kvm_x86_tlb_flush_current)(vcpu);
 out:
 	return r;
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_load);
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
@@ -4820,7 +4865,6 @@ void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 	kvm_mmu_free_roots(vcpu, &vcpu->arch.guest_mmu, KVM_MMU_ROOTS_ALL);
 	WARN_ON(VALID_PAGE(vcpu->arch.guest_mmu.root_hpa));
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_unload);
 
 static bool need_remote_flush(u64 old, u64 new)
 {
@@ -5169,10 +5213,10 @@ typedef bool (*slot_level_handler) (struct kvm *kvm, struct kvm_rmap_head *rmap_
 static __always_inline bool
 slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
 			slot_level_handler fn, int start_level, int end_level,
-			gfn_t start_gfn, gfn_t end_gfn, bool lock_flush_tlb)
+			gfn_t start_gfn, gfn_t end_gfn, bool flush_on_yield,
+			bool flush)
 {
 	struct slot_rmap_walk_iterator iterator;
-	bool flush = false;
 
 	for_each_slot_rmap_range(memslot, start_level, end_level, start_gfn,
 			end_gfn, &iterator) {
@@ -5180,7 +5224,7 @@ slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
 			flush |= fn(kvm, iterator.rmap, memslot);
 
 		if (need_resched() || rwlock_needbreak(&kvm->mmu_lock)) {
-			if (flush && lock_flush_tlb) {
+			if (flush && flush_on_yield) {
 				kvm_flush_remote_tlbs_with_address(kvm,
 						start_gfn,
 						iterator.gfn - start_gfn + 1);
@@ -5190,36 +5234,32 @@ slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
 		}
 	}
 
-	if (flush && lock_flush_tlb) {
-		kvm_flush_remote_tlbs_with_address(kvm, start_gfn,
-						   end_gfn - start_gfn + 1);
-		flush = false;
-	}
-
 	return flush;
 }
 
 static __always_inline bool
 slot_handle_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
 		  slot_level_handler fn, int start_level, int end_level,
-		  bool lock_flush_tlb)
+		  bool flush_on_yield)
 {
 	return slot_handle_level_range(kvm, memslot, fn, start_level,
 			end_level, memslot->base_gfn,
 			memslot->base_gfn + memslot->npages - 1,
-			lock_flush_tlb);
+			flush_on_yield, false);
 }
 
 static __always_inline bool
 slot_handle_leaf(struct kvm *kvm, struct kvm_memory_slot *memslot,
-		 slot_level_handler fn, bool lock_flush_tlb)
+		 slot_level_handler fn, bool flush_on_yield)
 {
 	return slot_handle_level(kvm, memslot, fn, PG_LEVEL_4K,
-				 PG_LEVEL_4K, lock_flush_tlb);
+				 PG_LEVEL_4K, flush_on_yield);
 }
 
 static void free_mmu_pages(struct kvm_mmu *mmu)
 {
+	if (!tdp_enabled && mmu->pae_root)
+		set_memory_encrypted((unsigned long)mmu->pae_root, 1);
 	free_page((unsigned long)mmu->pae_root);
 	free_page((unsigned long)mmu->lm_root);
 }
@@ -5240,9 +5280,11 @@ static int __kvm_mmu_create(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
 	 * while the PDP table is a per-vCPU construct that's allocated at MMU
 	 * creation.  When emulating 32-bit mode, cr3 is only 32 bits even on
 	 * x86_64.  Therefore we need to allocate the PDP table in the first
-	 * 4GB of memory, which happens to fit the DMA32 zone.  Except for
-	 * SVM's 32-bit NPT support, TDP paging doesn't use PAE paging and can
-	 * skip allocating the PDP table.
+	 * 4GB of memory, which happens to fit the DMA32 zone.  TDP paging
+	 * generally doesn't use PAE paging and can skip allocating the PDP
+	 * table.  The main exception, handled here, is SVM's 32-bit NPT.  The
+	 * other exception is for shadowing L1's 32-bit or PAE NPT on 64-bit
+	 * KVM; that horror is handled on-demand by mmu_alloc_special_roots().
 	 */
 	if (tdp_enabled && kvm_mmu_get_tdp_level(vcpu) > PT32E_ROOT_LEVEL)
 		return 0;
@@ -5252,8 +5294,22 @@ static int __kvm_mmu_create(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
 		return -ENOMEM;
 
 	mmu->pae_root = page_address(page);
+
+	/*
+	 * CR3 is only 32 bits when PAE paging is used, thus it's impossible to
+	 * get the CPU to treat the PDPTEs as encrypted.  Decrypt the page so
+	 * that KVM's writes and the CPU's reads get along.  Note, this is
+	 * only necessary when using shadow paging, as 64-bit NPT can get at
+	 * the C-bit even when shadowing 32-bit NPT, and SME isn't supported
+	 * by 32-bit kernels (when KVM itself uses 32-bit NPT).
+	 */
+	if (!tdp_enabled)
+		set_memory_decrypted((unsigned long)mmu->pae_root, 1);
+	else
+		WARN_ON_ONCE(shadow_me_mask);
+
 	for (i = 0; i < 4; ++i)
-		mmu->pae_root[i] = INVALID_PAGE;
+		mmu->pae_root[i] = INVALID_PAE_ROOT;
 
 	return 0;
 }
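Together with the free_mmu_pages() hunk earlier, these changes pair the decrypt at allocation with a re-encrypt before the PAE root page is freed. A condensed sketch of that lifecycle under shadow paging; the allocation flags are an assumption based on the DMA32 comment above, and error handling is elided:

	/* Allocation: the PDP table must sit below 4GB and be host-readable. */
	page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_DMA32);
	mmu->pae_root = page_address(page);
	if (!tdp_enabled)
		set_memory_decrypted((unsigned long)mmu->pae_root, 1);

	/* Teardown: restore the encrypted mapping before freeing the page. */
	if (!tdp_enabled && mmu->pae_root)
		set_memory_encrypted((unsigned long)mmu->pae_root, 1);
	free_page((unsigned long)mmu->pae_root);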
@@ -5365,6 +5421,15 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
 	 */
 	kvm->arch.mmu_valid_gen = kvm->arch.mmu_valid_gen ? 0 : 1;
 
+	/*
+	 * In order to ensure all threads see this change when handling the
+	 * MMU reload signal, this must happen in the same critical section
+	 * as kvm_reload_remote_mmus, and before kvm_zap_obsolete_pages as
+	 * kvm_zap_obsolete_pages could drop the MMU lock and yield.
+	 */
+	if (is_tdp_mmu_enabled(kvm))
+		kvm_tdp_mmu_invalidate_all_roots(kvm);
+
 	/*
 	 * Notify all vcpus to reload its shadow page table and flush TLB.
 	 * Then all vcpus will switch to new shadow page table with the new
@@ -5377,10 +5442,13 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
 
 	kvm_zap_obsolete_pages(kvm);
 
-	if (is_tdp_mmu_enabled(kvm))
-		kvm_tdp_mmu_zap_all(kvm);
-
 	write_unlock(&kvm->mmu_lock);
+
+	if (is_tdp_mmu_enabled(kvm)) {
+		read_lock(&kvm->mmu_lock);
+		kvm_tdp_mmu_zap_invalidated_roots(kvm);
+		read_unlock(&kvm->mmu_lock);
+	}
 }
 
 static bool kvm_has_zapped_obsolete_pages(struct kvm *kvm)
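The two hunks above split the fast zap into two phases: roots are only marked invalid while mmu_lock is held for write, and the invalidated TDP MMU roots are actually zapped afterwards with mmu_lock held for read so vCPUs can make progress in parallel. A condensed sketch of the resulting flow, using the same helpers as the hunks (details and the legacy-MMU bookkeeping are trimmed):

	write_lock(&kvm->mmu_lock);
	kvm->arch.mmu_valid_gen = kvm->arch.mmu_valid_gen ? 0 : 1;
	if (is_tdp_mmu_enabled(kvm))
		kvm_tdp_mmu_invalidate_all_roots(kvm);	/* mark roots invalid */
	kvm_reload_remote_mmus(kvm);			/* vCPUs pick up new roots */
	kvm_zap_obsolete_pages(kvm);			/* may yield mmu_lock */
	write_unlock(&kvm->mmu_lock);

	if (is_tdp_mmu_enabled(kvm)) {
		read_lock(&kvm->mmu_lock);
		kvm_tdp_mmu_zap_invalidated_roots(kvm);
		read_unlock(&kvm->mmu_lock);
	}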
@@ -5420,7 +5488,7 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
 	struct kvm_memslots *slots;
 	struct kvm_memory_slot *memslot;
 	int i;
-	bool flush;
+	bool flush = false;
 
 	write_lock(&kvm->mmu_lock);
 	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
@@ -5433,20 +5501,31 @@ void kvm_zap_gfn_range(struct kvm *kvm, gfn_t gfn_start, gfn_t gfn_end)
 			if (start >= end)
 				continue;
 
-			slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
-						PG_LEVEL_4K,
-						KVM_MAX_HUGEPAGE_LEVEL,
-						start, end - 1, true);
+			flush = slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
+							PG_LEVEL_4K,
+							KVM_MAX_HUGEPAGE_LEVEL,
+							start, end - 1, true, flush);
 		}
 	}
 
+	if (flush)
+		kvm_flush_remote_tlbs_with_address(kvm, gfn_start, gfn_end);
+
+	write_unlock(&kvm->mmu_lock);
+
 	if (is_tdp_mmu_enabled(kvm)) {
-		flush = kvm_tdp_mmu_zap_gfn_range(kvm, gfn_start, gfn_end);
+		flush = false;
+
+		read_lock(&kvm->mmu_lock);
+		for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
+			flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, gfn_start,
+							  gfn_end, flush, true);
 		if (flush)
-			kvm_flush_remote_tlbs(kvm);
-	}
+			kvm_flush_remote_tlbs_with_address(kvm, gfn_start,
+							   gfn_end);
 
-	write_unlock(&kvm->mmu_lock);
+		read_unlock(&kvm->mmu_lock);
+	}
 }
 
 static bool slot_rmap_write_protect(struct kvm *kvm,
@@ -5465,10 +5544,14 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
 	write_lock(&kvm->mmu_lock);
 	flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
 				start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
-	if (is_tdp_mmu_enabled(kvm))
-		flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, PG_LEVEL_4K);
 	write_unlock(&kvm->mmu_lock);
 
+	if (is_tdp_mmu_enabled(kvm)) {
+		read_lock(&kvm->mmu_lock);
+		flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, start_level);
+		read_unlock(&kvm->mmu_lock);
+	}
+
 	/*
 	 * We can flush all the TLBs out of the mmu lock without TLB
 	 * corruption since we just change the spte from writable to
@@ -5476,9 +5559,9 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
 	 * spte from present to present (changing the spte from present
 	 * to nonpresent will flush all the TLBs immediately), in other
 	 * words, the only case we care is mmu_spte_update() where we
-	 * have checked SPTE_HOST_WRITEABLE | SPTE_MMU_WRITEABLE
-	 * instead of PT_WRITABLE_MASK, that means it does not depend
-	 * on PT_WRITABLE_MASK anymore.
+	 * have checked Host-writable | MMU-writable instead of
+	 * PT_WRITABLE_MASK, that means it does not depend on PT_WRITABLE_MASK
+	 * anymore.
 	 */
 	if (flush)
 		kvm_arch_flush_remote_tlbs_memslot(kvm, memslot);
@@ -5529,21 +5612,32 @@ void kvm_mmu_zap_collapsible_sptes(struct kvm *kvm,
 {
 	/* FIXME: const-ify all uses of struct kvm_memory_slot.  */
 	struct kvm_memory_slot *slot = (struct kvm_memory_slot *)memslot;
+	bool flush;
 
 	write_lock(&kvm->mmu_lock);
-	slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
+	flush = slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
 
-	if (is_tdp_mmu_enabled(kvm))
-		kvm_tdp_mmu_zap_collapsible_sptes(kvm, slot);
+	if (flush)
+		kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
 	write_unlock(&kvm->mmu_lock);
+
+	if (is_tdp_mmu_enabled(kvm)) {
+		flush = false;
+
+		read_lock(&kvm->mmu_lock);
+		flush = kvm_tdp_mmu_zap_collapsible_sptes(kvm, slot, flush);
+		if (flush)
+			kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
+		read_unlock(&kvm->mmu_lock);
+	}
 }
 
 void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
-					struct kvm_memory_slot *memslot)
+					const struct kvm_memory_slot *memslot)
 {
 	/*
 	 * All current use cases for flushing the TLBs for a specific memslot
-	 * are related to dirty logging, and do the TLB flush out of mmu_lock.
+	 * are related to dirty logging, and many do the TLB flush out of mmu_lock.
 	 * The interaction between the various operations on memslot must be
 	 * serialized by slots_locks to ensure the TLB flush from one operation
 	 * is observed by any other operation on the same memslot.
@@ -5560,10 +5654,14 @@ void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
 
 	write_lock(&kvm->mmu_lock);
 	flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty, false);
-	if (is_tdp_mmu_enabled(kvm))
-		flush |= kvm_tdp_mmu_clear_dirty_slot(kvm, memslot);
 	write_unlock(&kvm->mmu_lock);
 
+	if (is_tdp_mmu_enabled(kvm)) {
+		read_lock(&kvm->mmu_lock);
+		flush |= kvm_tdp_mmu_clear_dirty_slot(kvm, memslot);
+		read_unlock(&kvm->mmu_lock);
+	}
+
 	/*
 	 * It's also safe to flush TLBs out of mmu lock here as currently this
 	 * function is only used for dirty logging, in which case flushing TLB
@@ -5701,25 +5799,6 @@ static void mmu_destroy_caches(void)
 	kmem_cache_destroy(mmu_page_header_cache);
 }
 
-static void kvm_set_mmio_spte_mask(void)
-{
-	u64 mask;
-
-	/*
-	 * Set a reserved PA bit in MMIO SPTEs to generate page faults with
-	 * PFEC.RSVD=1 on MMIO accesses.  64-bit PTEs (PAE, x86-64, and EPT
-	 * paging) support a maximum of 52 bits of PA, i.e. if the CPU supports
-	 * 52-bit physical addresses then there are no reserved PA bits in the
-	 * PTEs and so the reserved PA approach must be disabled.
-	 */
-	if (shadow_phys_bits < 52)
-		mask = BIT_ULL(51) | PT_PRESENT_MASK;
-	else
-		mask = 0;
-
-	kvm_mmu_set_mmio_spte_mask(mask, ACC_WRITE_MASK | ACC_USER_MASK);
-}
-
 static bool get_nx_auto_mode(void)
 {
 	/* Return true when CPU has the bug, and mitigations are ON */
@@ -5785,8 +5864,6 @@ int kvm_mmu_module_init(void)
 
 	kvm_mmu_reset_all_pte_masks();
 
-	kvm_set_mmio_spte_mask();
-
 	pte_list_desc_cache = kmem_cache_create("pte_list_desc",
 					    sizeof(struct pte_list_desc),
 					    0, SLAB_ACCOUNT, NULL);
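Several hunks above (kvm_zap_gfn_range(), slot_handle_level_range() and friends) switch to a convention where the pending-flush state is threaded through the range helpers and a single TLB flush is issued at the end. The shape of that convention, sketched with an illustrative zap_one_range() rather than the real callbacks:

	bool flush = false;
	int i;

	for (i = 0; i < nr_ranges; i++)
		flush = zap_one_range(kvm, &range[i], flush);	/* ORs in its result */

	if (flush)
		kvm_flush_remote_tlbs_with_address(kvm, gfn_start, gfn_end);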
diff --git a/arch/x86/kvm/mmu/mmu_audit.c b/arch/x86/kvm/mmu/mmu_audit.c
index ced15fd58fde..cedc17b2f60e 100644
--- a/arch/x86/kvm/mmu/mmu_audit.c
+++ b/arch/x86/kvm/mmu/mmu_audit.c
@@ -70,7 +70,7 @@ static void mmu_spte_walk(struct kvm_vcpu *vcpu, inspect_spte_fn fn)
 	for (i = 0; i < 4; ++i) {
 		hpa_t root = vcpu->arch.mmu->pae_root[i];
 
-		if (root && VALID_PAGE(root)) {
+		if (IS_VALID_PAE_ROOT(root)) {
 			root &= PT64_BASE_ADDR_MASK;
 			sp = to_shadow_page(root);
 			__mmu_spte_walk(vcpu, sp, fn, 2);
diff --git a/arch/x86/kvm/mmu/mmu_internal.h b/arch/x86/kvm/mmu/mmu_internal.h
index 360983865398..d64ccb417c60 100644
--- a/arch/x86/kvm/mmu/mmu_internal.h
+++ b/arch/x86/kvm/mmu/mmu_internal.h
@@ -20,6 +20,16 @@ extern bool dbg;
 #define MMU_WARN_ON(x) do { } while (0)
 #endif
 
+/*
+ * Unlike regular MMU roots, PAE "roots", a.k.a. PDPTEs/PDPTRs, have a PRESENT
+ * bit, and thus are guaranteed to be non-zero when valid.  And, when a guest
+ * PDPTR is !PRESENT, its corresponding PAE root cannot be set to INVALID_PAGE,
+ * as the CPU would treat that as PRESENT PDPTR with reserved bits set.  Use
+ * '0' instead of INVALID_PAGE to indicate an invalid PAE root.
+ */
+#define INVALID_PAE_ROOT	0
+#define IS_VALID_PAE_ROOT(x)	(!!(x))
+
 struct kvm_mmu_page {
 	struct list_head link;
 	struct hlist_node hash_link;
@@ -40,7 +50,11 @@ struct kvm_mmu_page {
 	u64 *spt;
 	/* hold the gfn of each spte inside spt */
 	gfn_t *gfns;
-	int root_count;          /* Currently serving as active root */
+	/* Currently serving as active root */
+	union {
+		int root_count;
+		refcount_t tdp_mmu_root_count;
+	};
 	unsigned int unsync_children;
 	struct kvm_rmap_head parent_ptes; /* rmap pointers to parent sptes */
 	DECLARE_BITMAP(unsync_child_bitmap, 512);
@@ -78,9 +92,14 @@ static inline struct kvm_mmu_page *sptep_to_sp(u64 *sptep)
 	return to_shadow_page(__pa(sptep));
 }
 
+static inline int kvm_mmu_role_as_id(union kvm_mmu_page_role role)
+{
+	return role.smm ? 1 : 0;
+}
+
 static inline int kvm_mmu_page_as_id(struct kvm_mmu_page *sp)
 {
-	return sp->role.smm ? 1 : 0;
+	return kvm_mmu_role_as_id(sp->role);
 }
 
 static inline bool kvm_vcpu_ad_need_write_protect(struct kvm_vcpu *vcpu)
@@ -108,22 +127,6 @@ bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
 void kvm_flush_remote_tlbs_with_address(struct kvm *kvm,
 					u64 start_gfn, u64 pages);
 
-static inline void kvm_mmu_get_root(struct kvm *kvm, struct kvm_mmu_page *sp)
-{
-	BUG_ON(!sp->root_count);
-	lockdep_assert_held(&kvm->mmu_lock);
-
-	++sp->root_count;
-}
-
-static inline bool kvm_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *sp)
-{
-	lockdep_assert_held(&kvm->mmu_lock);
-	--sp->root_count;
-
-	return !sp->root_count;
-}
-
 /*
  * Return values of handle_mmio_page_fault, mmu.page_fault, and fast_page_fault().
  *
@@ -146,8 +149,9 @@ enum {
 #define SET_SPTE_NEED_REMOTE_TLB_FLUSH	BIT(1)
 #define SET_SPTE_SPURIOUS		BIT(2)
 
-int kvm_mmu_max_mapping_level(struct kvm *kvm, struct kvm_memory_slot *slot,
-			      gfn_t gfn, kvm_pfn_t pfn, int max_level);
+int kvm_mmu_max_mapping_level(struct kvm *kvm,
+			      const struct kvm_memory_slot *slot, gfn_t gfn,
+			      kvm_pfn_t pfn, int max_level);
 int kvm_mmu_hugepage_adjust(struct kvm_vcpu *vcpu, gfn_t gfn,
 			    int max_level, kvm_pfn_t *pfnp,
 			    bool huge_page_disallowed, int *req_level);
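The INVALID_PAE_ROOT comment above is easy to sanity-check: KVM's generic INVALID_PAGE is an all-ones value, which a PAE-aware CPU would decode as a present PDPTE with reserved bits set, whereas '0' simply has the present bit clear. A standalone sketch of that reasoning, with the constants restated locally so it compiles on its own:

	#include <stdbool.h>
	#include <stdint.h>

	#define PT_PRESENT_MASK		(1ULL << 0)
	#define INVALID_PAGE		(~0ULL)		/* KVM's generic invalid HPA */
	#define INVALID_PAE_ROOT	0ULL

	static bool pdpte_looks_present(uint64_t pdpte)
	{
		return pdpte & PT_PRESENT_MASK;
	}

	/*
	 * pdpte_looks_present(INVALID_PAGE) is true (with reserved bits set on
	 * top), while pdpte_looks_present(INVALID_PAE_ROOT) is false.
	 */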
diff --git a/arch/x86/kvm/mmu/paging_tmpl.h b/arch/x86/kvm/mmu/paging_tmpl.h
index 55d7b473ac44..70b7e44e3035 100644
--- a/arch/x86/kvm/mmu/paging_tmpl.h
+++ b/arch/x86/kvm/mmu/paging_tmpl.h
@@ -503,6 +503,7 @@ error:
 #endif
 	walker->fault.address = addr;
 	walker->fault.nested_page_fault = mmu != vcpu->arch.walk_mmu;
+	walker->fault.async_page_fault = false;
 
 	trace_kvm_mmu_walker_error(walker->fault.error_code);
 	return 0;
@@ -1084,7 +1085,7 @@ static int FNAME(sync_page)(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
 
 		nr_present++;
 
-		host_writable = sp->spt[i] & SPTE_HOST_WRITEABLE;
+		host_writable = sp->spt[i] & shadow_host_writable_mask;
 
 		set_spte_ret |= set_spte(vcpu, &sp->spt[i],
 					 pte_access, PG_LEVEL_4K,
diff --git a/arch/x86/kvm/mmu/spte.c b/arch/x86/kvm/mmu/spte.c
index ef55f0bc4ccf..66d43cec0c31 100644
--- a/arch/x86/kvm/mmu/spte.c
+++ b/arch/x86/kvm/mmu/spte.c
@@ -16,13 +16,20 @@
 #include "spte.h"
 
 #include <asm/e820/api.h>
+#include <asm/vmx.h>
 
+static bool __read_mostly enable_mmio_caching = true;
+module_param_named(mmio_caching, enable_mmio_caching, bool, 0444);
+
+u64 __read_mostly shadow_host_writable_mask;
+u64 __read_mostly shadow_mmu_writable_mask;
 u64 __read_mostly shadow_nx_mask;
 u64 __read_mostly shadow_x_mask; /* mutual exclusive with nx_mask */
 u64 __read_mostly shadow_user_mask;
 u64 __read_mostly shadow_accessed_mask;
 u64 __read_mostly shadow_dirty_mask;
 u64 __read_mostly shadow_mmio_value;
+u64 __read_mostly shadow_mmio_mask;
 u64 __read_mostly shadow_mmio_access_mask;
 u64 __read_mostly shadow_present_mask;
 u64 __read_mostly shadow_me_mask;
@@ -38,7 +45,6 @@ static u64 generation_mmio_spte_mask(u64 gen)
 	u64 mask;
 
 	WARN_ON(gen & ~MMIO_SPTE_GEN_MASK);
-	BUILD_BUG_ON((MMIO_SPTE_GEN_HIGH_MASK | MMIO_SPTE_GEN_LOW_MASK) & SPTE_SPECIAL_MASK);
 
 	mask = (gen << MMIO_SPTE_GEN_LOW_SHIFT) & MMIO_SPTE_GEN_LOW_MASK;
 	mask |= (gen << MMIO_SPTE_GEN_HIGH_SHIFT) & MMIO_SPTE_GEN_HIGH_MASK;
@@ -48,16 +54,18 @@ static u64 generation_mmio_spte_mask(u64 gen)
 u64 make_mmio_spte(struct kvm_vcpu *vcpu, u64 gfn, unsigned int access)
 {
 	u64 gen = kvm_vcpu_memslots(vcpu)->generation & MMIO_SPTE_GEN_MASK;
-	u64 mask = generation_mmio_spte_mask(gen);
+	u64 spte = generation_mmio_spte_mask(gen);
 	u64 gpa = gfn << PAGE_SHIFT;
 
+	WARN_ON_ONCE(!shadow_mmio_value);
+
 	access &= shadow_mmio_access_mask;
-	mask |= shadow_mmio_value | access;
-	mask |= gpa | shadow_nonpresent_or_rsvd_mask;
-	mask |= (gpa & shadow_nonpresent_or_rsvd_mask)
+	spte |= shadow_mmio_value | access;
+	spte |= gpa | shadow_nonpresent_or_rsvd_mask;
+	spte |= (gpa & shadow_nonpresent_or_rsvd_mask)
 		<< SHADOW_NONPRESENT_OR_RSVD_MASK_LEN;
 
-	return mask;
+	return spte;
 }
 
 static bool kvm_is_mmio_pfn(kvm_pfn_t pfn)
@@ -86,13 +94,20 @@ int make_spte(struct kvm_vcpu *vcpu, unsigned int pte_access, int level,
 		     bool can_unsync, bool host_writable, bool ad_disabled,
 		     u64 *new_spte)
 {
-	u64 spte = 0;
+	u64 spte = SPTE_MMU_PRESENT_MASK;
 	int ret = 0;
 
 	if (ad_disabled)
-		spte |= SPTE_AD_DISABLED_MASK;
+		spte |= SPTE_TDP_AD_DISABLED_MASK;
 	else if (kvm_vcpu_ad_need_write_protect(vcpu))
-		spte |= SPTE_AD_WRPROT_ONLY_MASK;
+		spte |= SPTE_TDP_AD_WRPROT_ONLY_MASK;
+
+	/*
+	 * Bits 62:52 of PAE SPTEs are reserved.  WARN if said bits are set
+	 * when PAE paging may be employed (shadow paging or any 32-bit KVM).
+	 */
+	WARN_ON_ONCE((!tdp_enabled || !IS_ENABLED(CONFIG_X86_64)) &&
+		     (spte & SPTE_TDP_AD_MASK));
 
 	/*
 	 * For the EPT case, shadow_present_mask is 0 if hardware
@@ -124,7 +139,7 @@ int make_spte(struct kvm_vcpu *vcpu, unsigned int pte_access, int level,
 			kvm_is_mmio_pfn(pfn));
 
 	if (host_writable)
-		spte |= SPTE_HOST_WRITEABLE;
+		spte |= shadow_host_writable_mask;
 	else
 		pte_access &= ~ACC_WRITE_MASK;
 
@@ -134,7 +149,7 @@ int make_spte(struct kvm_vcpu *vcpu, unsigned int pte_access, int level,
 	spte |= (u64)pfn << PAGE_SHIFT;
 
 	if (pte_access & ACC_WRITE_MASK) {
-		spte |= PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE;
+		spte |= PT_WRITABLE_MASK | shadow_mmu_writable_mask;
 
 		/*
 		 * Optimization: for pte sync, if spte was writable the hash
@@ -150,7 +165,7 @@ int make_spte(struct kvm_vcpu *vcpu, unsigned int pte_access, int level,
 				 __func__, gfn);
 			ret |= SET_SPTE_WRITE_PROTECTED_PT;
 			pte_access &= ~ACC_WRITE_MASK;
-			spte &= ~(PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE);
+			spte &= ~(PT_WRITABLE_MASK | shadow_mmu_writable_mask);
 		}
 	}
 
@@ -161,19 +176,20 @@ int make_spte(struct kvm_vcpu *vcpu, unsigned int pte_access, int level,
 		spte = mark_spte_for_access_track(spte);
 
 out:
+	WARN_ON(is_mmio_spte(spte));
 	*new_spte = spte;
 	return ret;
 }
 
 u64 make_nonleaf_spte(u64 *child_pt, bool ad_disabled)
 {
-	u64 spte;
+	u64 spte = SPTE_MMU_PRESENT_MASK;
 
-	spte = __pa(child_pt) | shadow_present_mask | PT_WRITABLE_MASK |
-	       shadow_user_mask | shadow_x_mask | shadow_me_mask;
+	spte |= __pa(child_pt) | shadow_present_mask | PT_WRITABLE_MASK |
+		shadow_user_mask | shadow_x_mask | shadow_me_mask;
 
 	if (ad_disabled)
-		spte |= SPTE_AD_DISABLED_MASK;
+		spte |= SPTE_TDP_AD_DISABLED_MASK;
 	else
 		spte |= shadow_accessed_mask;
 
@@ -188,7 +204,7 @@ u64 kvm_mmu_changed_pte_notifier_make_spte(u64 old_spte, kvm_pfn_t new_pfn)
 	new_spte |= (u64)new_pfn << PAGE_SHIFT;
 
 	new_spte &= ~PT_WRITABLE_MASK;
-	new_spte &= ~SPTE_HOST_WRITEABLE;
+	new_spte &= ~shadow_host_writable_mask;
 
 	new_spte = mark_spte_for_access_track(new_spte);
 
@@ -242,53 +258,68 @@ u64 mark_spte_for_access_track(u64 spte)
 	return spte;
 }
 
-void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 access_mask)
+void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 mmio_mask, u64 access_mask)
 {
 	BUG_ON((u64)(unsigned)access_mask != access_mask);
-	WARN_ON(mmio_value & (shadow_nonpresent_or_rsvd_mask << SHADOW_NONPRESENT_OR_RSVD_MASK_LEN));
 	WARN_ON(mmio_value & shadow_nonpresent_or_rsvd_lower_gfn_mask);
-	shadow_mmio_value = mmio_value | SPTE_MMIO_MASK;
+
+	if (!enable_mmio_caching)
+		mmio_value = 0;
+
+	/*
+	 * Disable MMIO caching if the MMIO value collides with the bits that
+	 * are used to hold the relocated GFN when the L1TF mitigation is
+	 * enabled.  This should never fire as there is no known hardware that
+	 * can trigger this condition, e.g. SME/SEV CPUs that require a custom
+	 * MMIO value are not susceptible to L1TF.
+	 */
+	if (WARN_ON(mmio_value & (shadow_nonpresent_or_rsvd_mask <<
+				  SHADOW_NONPRESENT_OR_RSVD_MASK_LEN)))
+		mmio_value = 0;
+
+	/*
+	 * The masked MMIO value must obviously match itself and a removed SPTE
+	 * must not get a false positive.  Removed SPTEs and MMIO SPTEs should
+	 * never collide as MMIO must set some RWX bits, and removed SPTEs must
+	 * not set any RWX bits.
+	 */
+	if (WARN_ON((mmio_value & mmio_mask) != mmio_value) ||
+	    WARN_ON(mmio_value && (REMOVED_SPTE & mmio_mask) == mmio_value))
+		mmio_value = 0;
+
+	shadow_mmio_value = mmio_value;
+	shadow_mmio_mask  = mmio_mask;
 	shadow_mmio_access_mask = access_mask;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mmio_spte_mask);
 
-/*
- * Sets the shadow PTE masks used by the MMU.
- *
- * Assumptions:
- *  - Setting either @accessed_mask or @dirty_mask requires setting both
- *  - At least one of @accessed_mask or @acc_track_mask must be set
- */
-void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
-		u64 dirty_mask, u64 nx_mask, u64 x_mask, u64 p_mask,
-		u64 acc_track_mask, u64 me_mask)
+void kvm_mmu_set_ept_masks(bool has_ad_bits, bool has_exec_only)
 {
-	BUG_ON(!dirty_mask != !accessed_mask);
-	BUG_ON(!accessed_mask && !acc_track_mask);
-	BUG_ON(acc_track_mask & SPTE_SPECIAL_MASK);
-
-	shadow_user_mask = user_mask;
-	shadow_accessed_mask = accessed_mask;
-	shadow_dirty_mask = dirty_mask;
-	shadow_nx_mask = nx_mask;
-	shadow_x_mask = x_mask;
-	shadow_present_mask = p_mask;
-	shadow_acc_track_mask = acc_track_mask;
-	shadow_me_mask = me_mask;
+	shadow_user_mask	= VMX_EPT_READABLE_MASK;
+	shadow_accessed_mask	= has_ad_bits ? VMX_EPT_ACCESS_BIT : 0ull;
+	shadow_dirty_mask	= has_ad_bits ? VMX_EPT_DIRTY_BIT : 0ull;
+	shadow_nx_mask		= 0ull;
+	shadow_x_mask		= VMX_EPT_EXECUTABLE_MASK;
+	shadow_present_mask	= has_exec_only ? 0ull : VMX_EPT_READABLE_MASK;
+	shadow_acc_track_mask	= VMX_EPT_RWX_MASK;
+	shadow_me_mask		= 0ull;
+
+	shadow_host_writable_mask = EPT_SPTE_HOST_WRITABLE;
+	shadow_mmu_writable_mask  = EPT_SPTE_MMU_WRITABLE;
+
+	/*
+	 * EPT Misconfigurations are generated if the value of bits 2:0
+	 * of an EPT paging-structure entry is 110b (write/execute).
+	 */
+	kvm_mmu_set_mmio_spte_mask(VMX_EPT_MISCONFIG_WX_VALUE,
+				   VMX_EPT_RWX_MASK, 0);
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_set_mask_ptes);
+EXPORT_SYMBOL_GPL(kvm_mmu_set_ept_masks);
 
 void kvm_mmu_reset_all_pte_masks(void)
 {
 	u8 low_phys_bits;
-
-	shadow_user_mask = 0;
-	shadow_accessed_mask = 0;
-	shadow_dirty_mask = 0;
-	shadow_nx_mask = 0;
-	shadow_x_mask = 0;
-	shadow_present_mask = 0;
-	shadow_acc_track_mask = 0;
+	u64 mask;
 
 	shadow_phys_bits = kvm_get_shadow_phys_bits();
 
@@ -315,4 +346,30 @@ void kvm_mmu_reset_all_pte_masks(void)
 
 	shadow_nonpresent_or_rsvd_lower_gfn_mask =
 		GENMASK_ULL(low_phys_bits - 1, PAGE_SHIFT);
+
+	shadow_user_mask	= PT_USER_MASK;
+	shadow_accessed_mask	= PT_ACCESSED_MASK;
+	shadow_dirty_mask	= PT_DIRTY_MASK;
+	shadow_nx_mask		= PT64_NX_MASK;
+	shadow_x_mask		= 0;
+	shadow_present_mask	= PT_PRESENT_MASK;
+	shadow_acc_track_mask	= 0;
+	shadow_me_mask		= sme_me_mask;
+
+	shadow_host_writable_mask = DEFAULT_SPTE_HOST_WRITEABLE;
+	shadow_mmu_writable_mask  = DEFAULT_SPTE_MMU_WRITEABLE;
+
+	/*
+	 * Set a reserved PA bit in MMIO SPTEs to generate page faults with
+	 * PFEC.RSVD=1 on MMIO accesses.  64-bit PTEs (PAE, x86-64, and EPT
+	 * paging) support a maximum of 52 bits of PA, i.e. if the CPU supports
+	 * 52-bit physical addresses then there are no reserved PA bits in the
+	 * PTEs and so the reserved PA approach must be disabled.
+	 */
+	if (shadow_phys_bits < 52)
+		mask = BIT_ULL(51) | PT_PRESENT_MASK;
+	else
+		mask = 0;
+
+	kvm_mmu_set_mmio_spte_mask(mask, mask, ACC_WRITE_MASK | ACC_USER_MASK);
 }
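The sanity checks added to kvm_mmu_set_mmio_spte_mask() boil down to two properties: the MMIO value must be a subset of the MMIO mask, and a removed SPTE must never match the MMIO pattern. A standalone restatement of those checks (REMOVED_SPTE's 0x5a0 value is taken from the spte.h hunk below):

	#include <stdbool.h>
	#include <stdint.h>

	#define REMOVED_SPTE	0x5a0ULL

	static bool mmio_config_is_sane(uint64_t value, uint64_t mask)
	{
		if ((value & mask) != value)
			return false;	/* value has bits outside the mask */
		if (value && (REMOVED_SPTE & mask) == value)
			return false;	/* removed SPTEs would fake MMIO */
		return true;
	}

The EPT misconfig configuration (value = WX, mask = RWX) passes both checks, while reusing 0x5a0 as the MMIO value with an all-ones mask would trip the second one.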
diff --git a/arch/x86/kvm/mmu/spte.h b/arch/x86/kvm/mmu/spte.h
index 6de3950fd704..bca0ba11cccf 100644
--- a/arch/x86/kvm/mmu/spte.h
+++ b/arch/x86/kvm/mmu/spte.h
@@ -5,18 +5,33 @@
 
 #include "mmu_internal.h"
 
-#define PT_FIRST_AVAIL_BITS_SHIFT 10
-#define PT64_SECOND_AVAIL_BITS_SHIFT 54
+/*
+ * A MMU present SPTE is backed by actual memory and may or may not be present
+ * in hardware.  E.g. MMIO SPTEs are not considered present.  Use bit 11, as it
+ * is ignored by all flavors of SPTEs and checking a low bit often generates
+ * better code than for a high bit, e.g. 56+.  MMU present checks are pervasive
+ * enough that the improved code generation is noticeable in KVM's footprint.
+ */
+#define SPTE_MMU_PRESENT_MASK		BIT_ULL(11)
 
 /*
- * The mask used to denote special SPTEs, which can be either MMIO SPTEs or
- * Access Tracking SPTEs.
+ * TDP SPTES (more specifically, EPT SPTEs) may not have A/D bits, and may also
+ * be restricted to using write-protection (for L2 when CPU dirty logging, i.e.
+ * PML, is enabled).  Use bits 52 and 53 to hold the type of A/D tracking that
+ * must be employed for a given TDP SPTE.
+ *
+ * Note, the "enabled" mask must be '0', as bits 62:52 are _reserved_ for PAE
+ * paging, including NPT PAE.  This scheme works because legacy shadow paging
+ * is guaranteed to have A/D bits and write-protection is forced only for
+ * TDP with CPU dirty logging (PML).  If NPT ever gains PML-like support, it
+ * must be restricted to 64-bit KVM.
  */
-#define SPTE_SPECIAL_MASK (3ULL << 52)
-#define SPTE_AD_ENABLED_MASK (0ULL << 52)
-#define SPTE_AD_DISABLED_MASK (1ULL << 52)
-#define SPTE_AD_WRPROT_ONLY_MASK (2ULL << 52)
-#define SPTE_MMIO_MASK (3ULL << 52)
+#define SPTE_TDP_AD_SHIFT		52
+#define SPTE_TDP_AD_MASK		(3ULL << SPTE_TDP_AD_SHIFT)
+#define SPTE_TDP_AD_ENABLED_MASK	(0ULL << SPTE_TDP_AD_SHIFT)
+#define SPTE_TDP_AD_DISABLED_MASK	(1ULL << SPTE_TDP_AD_SHIFT)
+#define SPTE_TDP_AD_WRPROT_ONLY_MASK	(2ULL << SPTE_TDP_AD_SHIFT)
+static_assert(SPTE_TDP_AD_ENABLED_MASK == 0);
 
 #ifdef CONFIG_DYNAMIC_PHYSICAL_MASK
 #define PT64_BASE_ADDR_MASK (physical_mask & ~(u64)(PAGE_SIZE-1))
@@ -51,16 +66,46 @@
 	(((address) >> PT64_LEVEL_SHIFT(level)) & ((1 << PT64_LEVEL_BITS) - 1))
 #define SHADOW_PT_INDEX(addr, level) PT64_INDEX(addr, level)
 
+/* Bits 9 and 10 are ignored by all non-EPT PTEs. */
+#define DEFAULT_SPTE_HOST_WRITEABLE	BIT_ULL(9)
+#define DEFAULT_SPTE_MMU_WRITEABLE	BIT_ULL(10)
+
+/*
+ * The mask/shift to use for saving the original R/X bits when marking the PTE
+ * as not-present for access tracking purposes. We do not save the W bit as the
+ * PTEs being access tracked also need to be dirty tracked, so the W bit will be
+ * restored only when a write is attempted to the page.  This mask obviously
+ * must not overlap the A/D type mask.
+ */
+#define SHADOW_ACC_TRACK_SAVED_BITS_MASK (PT64_EPT_READABLE_MASK | \
+					  PT64_EPT_EXECUTABLE_MASK)
+#define SHADOW_ACC_TRACK_SAVED_BITS_SHIFT 54
+#define SHADOW_ACC_TRACK_SAVED_MASK	(SHADOW_ACC_TRACK_SAVED_BITS_MASK << \
+					 SHADOW_ACC_TRACK_SAVED_BITS_SHIFT)
+static_assert(!(SPTE_TDP_AD_MASK & SHADOW_ACC_TRACK_SAVED_MASK));
+
+/*
+ * Low ignored bits are at a premium for EPT, use high ignored bits, taking care
+ * to not overlap the A/D type mask or the saved access bits of access-tracked
+ * SPTEs when A/D bits are disabled.
+ */
+#define EPT_SPTE_HOST_WRITABLE		BIT_ULL(57)
+#define EPT_SPTE_MMU_WRITABLE		BIT_ULL(58)
 
-#define SPTE_HOST_WRITEABLE	(1ULL << PT_FIRST_AVAIL_BITS_SHIFT)
-#define SPTE_MMU_WRITEABLE	(1ULL << (PT_FIRST_AVAIL_BITS_SHIFT + 1))
+static_assert(!(EPT_SPTE_HOST_WRITABLE & SPTE_TDP_AD_MASK));
+static_assert(!(EPT_SPTE_MMU_WRITABLE & SPTE_TDP_AD_MASK));
+static_assert(!(EPT_SPTE_HOST_WRITABLE & SHADOW_ACC_TRACK_SAVED_MASK));
+static_assert(!(EPT_SPTE_MMU_WRITABLE & SHADOW_ACC_TRACK_SAVED_MASK));
+
+/* Defined only to keep the above static asserts readable. */
+#undef SHADOW_ACC_TRACK_SAVED_MASK
 
 /*
- * Due to limited space in PTEs, the MMIO generation is a 18 bit subset of
+ * Due to limited space in PTEs, the MMIO generation is a 19 bit subset of
  * the memslots generation and is derived as follows:
  *
- * Bits 0-8 of the MMIO generation are propagated to spte bits 3-11
- * Bits 9-17 of the MMIO generation are propagated to spte bits 54-62
+ * Bits 0-7 of the MMIO generation are propagated to spte bits 3-10
+ * Bits 8-18 of the MMIO generation are propagated to spte bits 52-62
  *
  * The KVM_MEMSLOT_GEN_UPDATE_IN_PROGRESS flag is intentionally not included in
  * the MMIO generation number, as doing so would require stealing a bit from
@@ -71,39 +116,44 @@
  */
 
 #define MMIO_SPTE_GEN_LOW_START		3
-#define MMIO_SPTE_GEN_LOW_END		11
+#define MMIO_SPTE_GEN_LOW_END		10
 
-#define MMIO_SPTE_GEN_HIGH_START	PT64_SECOND_AVAIL_BITS_SHIFT
+#define MMIO_SPTE_GEN_HIGH_START	52
 #define MMIO_SPTE_GEN_HIGH_END		62
 
 #define MMIO_SPTE_GEN_LOW_MASK		GENMASK_ULL(MMIO_SPTE_GEN_LOW_END, \
 						    MMIO_SPTE_GEN_LOW_START)
 #define MMIO_SPTE_GEN_HIGH_MASK		GENMASK_ULL(MMIO_SPTE_GEN_HIGH_END, \
 						    MMIO_SPTE_GEN_HIGH_START)
+static_assert(!(SPTE_MMU_PRESENT_MASK &
+		(MMIO_SPTE_GEN_LOW_MASK | MMIO_SPTE_GEN_HIGH_MASK)));
 
 #define MMIO_SPTE_GEN_LOW_BITS		(MMIO_SPTE_GEN_LOW_END - MMIO_SPTE_GEN_LOW_START + 1)
 #define MMIO_SPTE_GEN_HIGH_BITS		(MMIO_SPTE_GEN_HIGH_END - MMIO_SPTE_GEN_HIGH_START + 1)
 
 /* remember to adjust the comment above as well if you change these */
-static_assert(MMIO_SPTE_GEN_LOW_BITS == 9 && MMIO_SPTE_GEN_HIGH_BITS == 9);
+static_assert(MMIO_SPTE_GEN_LOW_BITS == 8 && MMIO_SPTE_GEN_HIGH_BITS == 11);
 
 #define MMIO_SPTE_GEN_LOW_SHIFT		(MMIO_SPTE_GEN_LOW_START - 0)
 #define MMIO_SPTE_GEN_HIGH_SHIFT	(MMIO_SPTE_GEN_HIGH_START - MMIO_SPTE_GEN_LOW_BITS)
 
 #define MMIO_SPTE_GEN_MASK		GENMASK_ULL(MMIO_SPTE_GEN_LOW_BITS + MMIO_SPTE_GEN_HIGH_BITS - 1, 0)
 
+extern u64 __read_mostly shadow_host_writable_mask;
+extern u64 __read_mostly shadow_mmu_writable_mask;
 extern u64 __read_mostly shadow_nx_mask;
 extern u64 __read_mostly shadow_x_mask; /* mutual exclusive with nx_mask */
 extern u64 __read_mostly shadow_user_mask;
 extern u64 __read_mostly shadow_accessed_mask;
 extern u64 __read_mostly shadow_dirty_mask;
 extern u64 __read_mostly shadow_mmio_value;
+extern u64 __read_mostly shadow_mmio_mask;
 extern u64 __read_mostly shadow_mmio_access_mask;
 extern u64 __read_mostly shadow_present_mask;
 extern u64 __read_mostly shadow_me_mask;
 
 /*
- * SPTEs used by MMUs without A/D bits are marked with SPTE_AD_DISABLED_MASK;
+ * SPTEs in MMUs without A/D bits are marked with SPTE_TDP_AD_DISABLED_MASK;
  * shadow_acc_track_mask is the set of bits to be cleared in non-accessed
  * pages.
  */
@@ -121,28 +171,21 @@ extern u64 __read_mostly shadow_nonpresent_or_rsvd_mask;
 #define SHADOW_NONPRESENT_OR_RSVD_MASK_LEN 5
 
 /*
- * The mask/shift to use for saving the original R/X bits when marking the PTE
- * as not-present for access tracking purposes. We do not save the W bit as the
- * PTEs being access tracked also need to be dirty tracked, so the W bit will be
- * restored only when a write is attempted to the page.
- */
-#define SHADOW_ACC_TRACK_SAVED_BITS_MASK (PT64_EPT_READABLE_MASK | \
-					  PT64_EPT_EXECUTABLE_MASK)
-#define SHADOW_ACC_TRACK_SAVED_BITS_SHIFT PT64_SECOND_AVAIL_BITS_SHIFT
-
-/*
  * If a thread running without exclusive control of the MMU lock must perform a
  * multi-part operation on an SPTE, it can set the SPTE to REMOVED_SPTE as a
  * non-present intermediate value. Other threads which encounter this value
  * should not modify the SPTE.
  *
- * This constant works because it is considered non-present on both AMD and
- * Intel CPUs and does not create a L1TF vulnerability because the pfn section
- * is zeroed out.
+ * Use a semi-arbitrary value that doesn't set RWX bits, i.e. is not-present on
+ * both AMD and Intel CPUs, and doesn't set PFN bits, i.e. doesn't create an L1TF
+ * vulnerability.  Use only low bits to avoid 64-bit immediates.
  *
  * Only used by the TDP MMU.
  */
-#define REMOVED_SPTE (1ull << 59)
+#define REMOVED_SPTE	0x5a0ULL
+
+/* Removed SPTEs must not be misconstrued as shadow present PTEs. */
+static_assert(!(REMOVED_SPTE & SPTE_MMU_PRESENT_MASK));
 
 static inline bool is_removed_spte(u64 spte)
 {
@@ -167,7 +210,13 @@ extern u8 __read_mostly shadow_phys_bits;
 
 static inline bool is_mmio_spte(u64 spte)
 {
-	return (spte & SPTE_SPECIAL_MASK) == SPTE_MMIO_MASK;
+	return (spte & shadow_mmio_mask) == shadow_mmio_value &&
+	       likely(shadow_mmio_value);
+}
+
+static inline bool is_shadow_present_pte(u64 pte)
+{
+	return !!(pte & SPTE_MMU_PRESENT_MASK);
 }
 
 static inline bool sp_ad_disabled(struct kvm_mmu_page *sp)
@@ -177,25 +226,30 @@ static inline bool sp_ad_disabled(struct kvm_mmu_page *sp)
 
 static inline bool spte_ad_enabled(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
-	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_DISABLED_MASK;
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
+	return (spte & SPTE_TDP_AD_MASK) != SPTE_TDP_AD_DISABLED_MASK;
 }
 
 static inline bool spte_ad_need_write_protect(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
-	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_ENABLED_MASK;
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
+	/*
+	 * This is benign for non-TDP SPTEs as SPTE_TDP_AD_ENABLED_MASK is '0',
+	 * and non-TDP SPTEs will never set these bits.  Optimize for 64-bit
+	 * TDP and do the A/D type check unconditionally.
+	 */
+	return (spte & SPTE_TDP_AD_MASK) != SPTE_TDP_AD_ENABLED_MASK;
 }
 
 static inline u64 spte_shadow_accessed_mask(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
 	return spte_ad_enabled(spte) ? shadow_accessed_mask : 0;
 }
 
 static inline u64 spte_shadow_dirty_mask(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
 	return spte_ad_enabled(spte) ? shadow_dirty_mask : 0;
 }
 
@@ -204,11 +258,6 @@ static inline bool is_access_track_spte(u64 spte)
 	return !spte_ad_enabled(spte) && (spte & shadow_acc_track_mask) == 0;
 }
 
-static inline bool is_shadow_present_pte(u64 pte)
-{
-	return (pte != 0) && !is_mmio_spte(pte) && !is_removed_spte(pte);
-}
-
 static inline bool is_large_pte(u64 pte)
 {
 	return pte & PT_PAGE_SIZE_MASK;
@@ -246,8 +295,8 @@ static inline bool is_dirty_spte(u64 spte)
 
 static inline bool spte_can_locklessly_be_made_writable(u64 spte)
 {
-	return (spte & (SPTE_HOST_WRITEABLE | SPTE_MMU_WRITEABLE)) ==
-		(SPTE_HOST_WRITEABLE | SPTE_MMU_WRITEABLE);
+	return (spte & shadow_host_writable_mask) &&
+	       (spte & shadow_mmu_writable_mask);
 }
 
 static inline u64 get_mmio_spte_generation(u64 spte)
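The generation layout described above (memslot generation bits 0-7 land in SPTE bits 3-10, bits 8-18 in SPTE bits 52-62) is easiest to see as a standalone encode/decode pair. This is a sketch with the constants restated locally, not the kernel's make_mmio_spte()/get_mmio_spte_generation():

	#include <stdint.h>

	#define GEN_LOW_START	3
	#define GEN_LOW_BITS	8
	#define GEN_HIGH_START	52
	#define GEN_HIGH_BITS	11
	#define GEN_LOW_MASK	(((1ULL << GEN_LOW_BITS) - 1) << GEN_LOW_START)
	#define GEN_HIGH_MASK	(((1ULL << GEN_HIGH_BITS) - 1) << GEN_HIGH_START)

	static uint64_t gen_to_spte_bits(uint64_t gen)
	{
		return ((gen << GEN_LOW_START) & GEN_LOW_MASK) |
		       (((gen >> GEN_LOW_BITS) << GEN_HIGH_START) & GEN_HIGH_MASK);
	}

	static uint64_t spte_bits_to_gen(uint64_t spte)
	{
		return ((spte & GEN_LOW_MASK) >> GEN_LOW_START) |
		       (((spte & GEN_HIGH_MASK) >> GEN_HIGH_START) << GEN_LOW_BITS);
	}

Round-tripping any 19-bit generation through the two helpers is lossless, and neither half touches bit 11, matching the static_assert against SPTE_MMU_PRESENT_MASK above.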
diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index 34207b874886..88f69a6cc492 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -27,6 +27,15 @@ void kvm_mmu_init_tdp_mmu(struct kvm *kvm)
 	INIT_LIST_HEAD(&kvm->arch.tdp_mmu_pages);
 }
 
+static __always_inline void kvm_lockdep_assert_mmu_lock_held(struct kvm *kvm,
+							     bool shared)
+{
+	if (shared)
+		lockdep_assert_held_read(&kvm->mmu_lock);
+	else
+		lockdep_assert_held_write(&kvm->mmu_lock);
+}
+
 void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
 {
 	if (!kvm->arch.tdp_mmu_enabled)
@@ -41,32 +50,85 @@ void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
 	rcu_barrier();
 }
 
-static void tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root)
+static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+			  gfn_t start, gfn_t end, bool can_yield, bool flush,
+			  bool shared);
+
+static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
 {
-	if (kvm_mmu_put_root(kvm, root))
-		kvm_tdp_mmu_free_root(kvm, root);
+	free_page((unsigned long)sp->spt);
+	kmem_cache_free(mmu_page_header_cache, sp);
 }
 
-static inline bool tdp_mmu_next_root_valid(struct kvm *kvm,
-					   struct kvm_mmu_page *root)
+/*
+ * This is called through call_rcu in order to free TDP page table memory
+ * safely with respect to other kernel threads that may be operating on
+ * the memory.
+ * By only accessing TDP MMU page table memory in an RCU read critical
+ * section, and freeing it after a grace period, lockless access to that
+ * memory won't use it after it is freed.
+ */
+static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
 {
-	lockdep_assert_held_write(&kvm->mmu_lock);
+	struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
+					       rcu_head);
 
-	if (list_entry_is_head(root, &kvm->arch.tdp_mmu_roots, link))
-		return false;
+	tdp_mmu_free_sp(sp);
+}
 
-	kvm_mmu_get_root(kvm, root);
-	return true;
+void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
+			  bool shared)
+{
+	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+
+	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
 
+	if (!refcount_dec_and_test(&root->tdp_mmu_root_count))
+		return;
+
+	WARN_ON(!root->tdp_mmu_page);
+
+	spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+	list_del_rcu(&root->link);
+	spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
+
+	zap_gfn_range(kvm, root, 0, max_gfn, false, false, shared);
+
+	call_rcu(&root->rcu_head, tdp_mmu_free_sp_rcu_callback);
 }
 
-static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
-						     struct kvm_mmu_page *root)
+/*
+ * Finds the next valid root after @prev_root (or the first valid root if
+ * @prev_root is NULL), takes a reference on it, and returns that next root.
+ * If @prev_root is not NULL, this thread should have already taken a
+ * reference on it, and that reference will be dropped.  If no valid root is
+ * found, this function returns NULL.
+ */
+static struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
+					      struct kvm_mmu_page *prev_root,
+					      bool shared)
 {
 	struct kvm_mmu_page *next_root;
 
-	next_root = list_next_entry(root, link);
-	tdp_mmu_put_root(kvm, root);
+	rcu_read_lock();
+
+	if (prev_root)
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &prev_root->link,
+						  typeof(*prev_root), link);
+	else
+		next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						   typeof(*next_root), link);
+
+	while (next_root && !kvm_tdp_mmu_get_root(kvm, next_root))
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+				&next_root->link, typeof(*next_root), link);
+
+	rcu_read_unlock();
+
+	if (prev_root)
+		kvm_tdp_mmu_put_root(kvm, prev_root, shared);
+
 	return next_root;
 }
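tdp_mmu_next_root() above only returns a root on which kvm_tdp_mmu_get_root() managed to take a reference, i.e. a root whose count had not already dropped to zero. That is the shape of the usual "increment only if non-zero" primitive (presumably refcount_inc_not_zero() on the tdp_mmu_root_count added in mmu_internal.h); a userspace-flavored C11 sketch of it, for illustration only:

	#include <stdatomic.h>
	#include <stdbool.h>

	/* Returns true and takes a reference iff the count was still non-zero. */
	static bool get_ref_if_live(_Atomic int *refcount)
	{
		int old = atomic_load(refcount);

		do {
			if (old == 0)
				return false;	/* object is already being torn down */
		} while (!atomic_compare_exchange_weak(refcount, &old, old + 1));

		return true;
	}

If the attempt fails, the loop above simply moves on to the next entry in the RCU-protected roots list.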
 
@@ -75,35 +137,24 @@ static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
  * This makes it safe to release the MMU lock and yield within the loop, but
  * if exiting the loop early, the caller must drop the reference to the most
  * recent root. (Unless keeping a live reference is desirable.)
+ *
+ * If shared is set, this function is operating under the MMU lock in read
+ * mode. In the unlikely event that this thread must free a root, the lock
+ * will be temporarily dropped and reacquired in write mode.
  */
-#define for_each_tdp_mmu_root_yield_safe(_kvm, _root)				\
-	for (_root = list_first_entry(&_kvm->arch.tdp_mmu_roots,	\
-				      typeof(*_root), link);		\
-	     tdp_mmu_next_root_valid(_kvm, _root);			\
-	     _root = tdp_mmu_next_root(_kvm, _root))
-
-#define for_each_tdp_mmu_root(_kvm, _root)				\
-	list_for_each_entry(_root, &_kvm->arch.tdp_mmu_roots, link)
-
-static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-			  gfn_t start, gfn_t end, bool can_yield, bool flush);
-
-void kvm_tdp_mmu_free_root(struct kvm *kvm, struct kvm_mmu_page *root)
-{
-	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-
-	lockdep_assert_held_write(&kvm->mmu_lock);
-
-	WARN_ON(root->root_count);
-	WARN_ON(!root->tdp_mmu_page);
-
-	list_del(&root->link);
-
-	zap_gfn_range(kvm, root, 0, max_gfn, false, false);
-
-	free_page((unsigned long)root->spt);
-	kmem_cache_free(mmu_page_header_cache, root);
-}
+#define for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared)	\
+	for (_root = tdp_mmu_next_root(_kvm, NULL, _shared);		\
+	     _root;							\
+	     _root = tdp_mmu_next_root(_kvm, _root, _shared))		\
+		if (kvm_mmu_page_as_id(_root) != _as_id) {		\
+		} else
+
+#define for_each_tdp_mmu_root(_kvm, _root, _as_id)				\
+	list_for_each_entry_rcu(_root, &_kvm->arch.tdp_mmu_roots, link,		\
+				lockdep_is_held_type(&kvm->mmu_lock, 0) ||	\
+				lockdep_is_held(&kvm->arch.tdp_mmu_pages_lock))	\
+		if (kvm_mmu_page_as_id(_root) != _as_id) {		\
+		} else
 
 static union kvm_mmu_page_role page_role_for_level(struct kvm_vcpu *vcpu,
 						   int level)
@@ -137,81 +188,46 @@ static struct kvm_mmu_page *alloc_tdp_mmu_page(struct kvm_vcpu *vcpu, gfn_t gfn,
 	return sp;
 }
 
-static struct kvm_mmu_page *get_tdp_mmu_vcpu_root(struct kvm_vcpu *vcpu)
+hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
 {
 	union kvm_mmu_page_role role;
 	struct kvm *kvm = vcpu->kvm;
 	struct kvm_mmu_page *root;
 
-	role = page_role_for_level(vcpu, vcpu->arch.mmu->shadow_root_level);
+	lockdep_assert_held_write(&kvm->mmu_lock);
 
-	write_lock(&kvm->mmu_lock);
+	role = page_role_for_level(vcpu, vcpu->arch.mmu->shadow_root_level);
 
 	/* Check for an existing root before allocating a new one. */
-	for_each_tdp_mmu_root(kvm, root) {
-		if (root->role.word == role.word) {
-			kvm_mmu_get_root(kvm, root);
-			write_unlock(&kvm->mmu_lock);
-			return root;
-		}
+	for_each_tdp_mmu_root(kvm, root, kvm_mmu_role_as_id(role)) {
+		if (root->role.word == role.word &&
+		    kvm_tdp_mmu_get_root(kvm, root))
+			goto out;
 	}
 
 	root = alloc_tdp_mmu_page(vcpu, 0, vcpu->arch.mmu->shadow_root_level);
-	root->root_count = 1;
-
-	list_add(&root->link, &kvm->arch.tdp_mmu_roots);
-
-	write_unlock(&kvm->mmu_lock);
-
-	return root;
-}
-
-hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
-{
-	struct kvm_mmu_page *root;
+	refcount_set(&root->tdp_mmu_root_count, 1);
 
-	root = get_tdp_mmu_vcpu_root(vcpu);
-	if (!root)
-		return INVALID_PAGE;
+	spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+	list_add_rcu(&root->link, &kvm->arch.tdp_mmu_roots);
+	spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
 
+out:
 	return __pa(root->spt);
 }
 
-static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
-{
-	free_page((unsigned long)sp->spt);
-	kmem_cache_free(mmu_page_header_cache, sp);
-}
-
-/*
- * This is called through call_rcu in order to free TDP page table memory
- * safely with respect to other kernel threads that may be operating on
- * the memory.
- * By only accessing TDP MMU page table memory in an RCU read critical
- * section, and freeing it after a grace period, lockless access to that
- * memory won't use it after it is freed.
- */
-static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
-{
-	struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
-					       rcu_head);
-
-	tdp_mmu_free_sp(sp);
-}
-
 static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 				u64 old_spte, u64 new_spte, int level,
 				bool shared);
 
 static void handle_changed_spte_acc_track(u64 old_spte, u64 new_spte, int level)
 {
-	bool pfn_changed = spte_to_pfn(old_spte) != spte_to_pfn(new_spte);
-
 	if (!is_shadow_present_pte(old_spte) || !is_last_spte(old_spte, level))
 		return;
 
 	if (is_accessed_spte(old_spte) &&
-	    (!is_accessed_spte(new_spte) || pfn_changed))
+	    (!is_shadow_present_pte(new_spte) || !is_accessed_spte(new_spte) ||
+	     spte_to_pfn(old_spte) != spte_to_pfn(new_spte)))
 		kvm_set_pfn_accessed(spte_to_pfn(old_spte));
 }
 
@@ -455,7 +471,7 @@ static void __handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 
 
 	if (was_leaf && is_dirty_spte(old_spte) &&
-	    (!is_dirty_spte(new_spte) || pfn_changed))
+	    (!is_present || !is_dirty_spte(new_spte) || pfn_changed))
 		kvm_set_pfn_dirty(spte_to_pfn(old_spte));
 
 	/*
@@ -479,8 +495,9 @@ static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 }
 
 /*
- * tdp_mmu_set_spte_atomic - Set a TDP MMU SPTE atomically and handle the
- * associated bookkeeping
+ * tdp_mmu_set_spte_atomic_no_dirty_log - Set a TDP MMU SPTE atomically
+ * and handle the associated bookkeeping, but do not mark the page dirty
+ * in KVM's dirty bitmaps.
  *
  * @kvm: kvm instance
  * @iter: a tdp_iter instance currently on the SPTE that should be set
@@ -488,9 +505,9 @@ static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
  * Returns: true if the SPTE was set, false if it was not. If false is returned,
  *	    this function will have no side-effects.
  */
-static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
-					   struct tdp_iter *iter,
-					   u64 new_spte)
+static inline bool tdp_mmu_set_spte_atomic_no_dirty_log(struct kvm *kvm,
+							struct tdp_iter *iter,
+							u64 new_spte)
 {
 	lockdep_assert_held_read(&kvm->mmu_lock);
 
@@ -498,19 +515,32 @@ static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
 	 * Do not change removed SPTEs. Only the thread that froze the SPTE
 	 * may modify it.
 	 */
-	if (iter->old_spte == REMOVED_SPTE)
+	if (is_removed_spte(iter->old_spte))
 		return false;
 
 	if (cmpxchg64(rcu_dereference(iter->sptep), iter->old_spte,
 		      new_spte) != iter->old_spte)
 		return false;
 
-	handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
-			    new_spte, iter->level, true);
+	__handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
+			      new_spte, iter->level, true);
+	handle_changed_spte_acc_track(iter->old_spte, new_spte, iter->level);
 
 	return true;
 }
 
+static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
+					   struct tdp_iter *iter,
+					   u64 new_spte)
+{
+	if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, iter, new_spte))
+		return false;
+
+	handle_changed_spte_dirty_log(kvm, iter->as_id, iter->gfn,
+				      iter->old_spte, new_spte, iter->level);
+	return true;
+}
+
 static inline bool tdp_mmu_zap_spte_atomic(struct kvm *kvm,
 					   struct tdp_iter *iter)
 {
@@ -569,7 +599,7 @@ static inline void __tdp_mmu_set_spte(struct kvm *kvm, struct tdp_iter *iter,
 	 * should be used. If operating under the MMU lock in write mode, the
 	 * use of the removed SPTE should not be necessary.
 	 */
-	WARN_ON(iter->old_spte == REMOVED_SPTE);
+	WARN_ON(is_removed_spte(iter->old_spte));
 
 	WRITE_ONCE(*rcu_dereference(iter->sptep), new_spte);
 
@@ -634,7 +664,8 @@ static inline void tdp_mmu_set_spte_no_dirty_log(struct kvm *kvm,
  * Return false if a yield was not needed.
  */
 static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
-					     struct tdp_iter *iter, bool flush)
+					     struct tdp_iter *iter, bool flush,
+					     bool shared)
 {
 	/* Ensure forward progress has been made before yielding. */
 	if (iter->next_last_level_gfn == iter->yielded_gfn)
@@ -646,7 +677,11 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
 		if (flush)
 			kvm_flush_remote_tlbs(kvm);
 
-		cond_resched_rwlock_write(&kvm->mmu_lock);
+		if (shared)
+			cond_resched_rwlock_read(&kvm->mmu_lock);
+		else
+			cond_resched_rwlock_write(&kvm->mmu_lock);
+
 		rcu_read_lock();
 
 		WARN_ON(iter->gfn > iter->next_last_level_gfn);
@@ -664,24 +699,32 @@ static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
  * If can_yield is true, will release the MMU lock and reschedule if the
  * scheduler needs the CPU or there is contention on the MMU lock. If this
  * function cannot yield, it will not release the MMU lock or reschedule and
  * the caller must ensure it does not supply too large a GFN range, or the
- * operation can cause a soft lockup.  Note, in some use cases a flush may be
- * required by prior actions.  Ensure the pending flush is performed prior to
- * yielding.
+ * operation can cause a soft lockup.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU lock in write mode.
  */
 static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-			  gfn_t start, gfn_t end, bool can_yield, bool flush)
+			  gfn_t start, gfn_t end, bool can_yield, bool flush,
+			  bool shared)
 {
 	struct tdp_iter iter;
 
+	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
+
 	rcu_read_lock();
 
 	tdp_root_for_each_pte(iter, root, start, end) {
+retry:
 		if (can_yield &&
-		    tdp_mmu_iter_cond_resched(kvm, &iter, flush)) {
+		    tdp_mmu_iter_cond_resched(kvm, &iter, flush, shared)) {
 			flush = false;
 			continue;
 		}
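
    zap_gfn_range() now documents its locking contract through the shared flag
    and asserts it via kvm_lockdep_assert_mmu_lock_held(), whose definition is
    outside this hunk. A plausible shape for that helper (an assumption, not
    shown in this patch) is:

    /* Assumed helper: assert read vs. write ownership of the MMU lock. */
    static inline void kvm_lockdep_assert_mmu_lock_held(struct kvm *kvm,
    						    bool shared)
    {
    	if (shared)
    		lockdep_assert_held_read(&kvm->mmu_lock);
    	else
    		lockdep_assert_held_write(&kvm->mmu_lock);
    }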
@@ -699,8 +742,17 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 		    !is_last_spte(iter.old_spte, iter.level))
 			continue;
 
-		tdp_mmu_set_spte(kvm, &iter, 0);
-		flush = true;
+		if (!shared) {
+			tdp_mmu_set_spte(kvm, &iter, 0);
+			flush = true;
+		} else if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 	}
 
 	rcu_read_unlock();
@@ -712,15 +764,21 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU lock in write mode.
  */
-bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
-				 bool can_yield)
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
+				 gfn_t end, bool can_yield, bool flush,
+				 bool shared)
 {
 	struct kvm_mmu_page *root;
-	bool flush = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root)
-		flush = zap_gfn_range(kvm, root, start, end, can_yield, flush);
+	for_each_tdp_mmu_root_yield_safe(kvm, root, as_id, shared)
+		flush = zap_gfn_range(kvm, root, start, end, can_yield, flush,
+				      shared);
 
 	return flush;
 }
@@ -728,14 +786,116 @@ bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
 void kvm_tdp_mmu_zap_all(struct kvm *kvm)
 {
 	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-	bool flush;
+	bool flush = false;
+	int i;
+
+	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
+		flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, 0, max_gfn,
+						  flush, false);
+
+	if (flush)
+		kvm_flush_remote_tlbs(kvm);
+}
+
+static struct kvm_mmu_page *next_invalidated_root(struct kvm *kvm,
+						  struct kvm_mmu_page *prev_root)
+{
+	struct kvm_mmu_page *next_root;
+
+	if (prev_root)
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &prev_root->link,
+						  typeof(*prev_root), link);
+	else
+		next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						   typeof(*next_root), link);
+
+	while (next_root && !(next_root->role.invalid &&
+			      refcount_read(&next_root->tdp_mmu_root_count)))
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &next_root->link,
+						  typeof(*next_root), link);
+
+	return next_root;
+}
+
+/*
+ * Since kvm_tdp_mmu_invalidate_all_roots() has acquired a reference to each
+ * invalidated root, they will not be freed until this function drops the
+ * reference. Before dropping that reference, tear down the paging
+ * structure so that whichever thread does drop the last reference
+ * only has to do a trivial amount of work. Since the roots are invalid,
+ * no new SPTEs should be created under them.
+ */
+void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm)
+{
+	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+	struct kvm_mmu_page *next_root;
+	struct kvm_mmu_page *root;
+	bool flush = false;
+
+	lockdep_assert_held_read(&kvm->mmu_lock);
+
+	rcu_read_lock();
+
+	root = next_invalidated_root(kvm, NULL);
+
+	while (root) {
+		next_root = next_invalidated_root(kvm, root);
+
+		rcu_read_unlock();
+
+		flush = zap_gfn_range(kvm, root, 0, max_gfn, true, flush,
+				      true);
+
+		/*
+		 * Put the reference acquired in
+		 * kvm_tdp_mmu_invalidate_all_roots().
+		 */
+		kvm_tdp_mmu_put_root(kvm, root, true);
+
+		root = next_root;
+
+		rcu_read_lock();
+	}
+
+	rcu_read_unlock();
 
-	flush = kvm_tdp_mmu_zap_gfn_range(kvm, 0, max_gfn);
 	if (flush)
 		kvm_flush_remote_tlbs(kvm);
 }
 
 /*
+ * Mark each TDP MMU root as invalid so that other threads
+ * will drop their references and allow the root count to
+ * go to 0.
+ *
+ * Also take a reference on all roots so that this thread
+ * can do the bulk of the work required to free the roots
+ * once they are invalidated. Without this reference, a
+ * vCPU thread might drop the last reference to a root and
+ * get stuck with tearing down the entire paging structure.
+ *
+ * Roots which have a zero refcount should be skipped as
+ * they're already being torn down.
+ * Already invalid roots should be referenced again so that
+ * they aren't freed before kvm_tdp_mmu_zap_invalidated_roots() is
+ * done with them.
+ *
+ * This has essentially the same effect for the TDP MMU
+ * as updating mmu_valid_gen does for the shadow MMU.
+ */
+void kvm_tdp_mmu_invalidate_all_roots(struct kvm *kvm)
+{
+	struct kvm_mmu_page *root;
+
+	lockdep_assert_held_write(&kvm->mmu_lock);
+	list_for_each_entry(root, &kvm->arch.tdp_mmu_roots, link)
+		if (refcount_inc_not_zero(&root->tdp_mmu_root_count))
+			root->role.invalid = true;
+}
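
    Taken together, kvm_tdp_mmu_invalidate_all_roots() and
    kvm_tdp_mmu_zap_invalidated_roots() split a bulk zap into a short exclusive
    phase and a long shared phase. A hedged sketch of how a caller might
    sequence the two (the caller shown here is illustrative; the real call
    sites live in mmu.c, outside this diff):

    /* Illustrative caller; not part of this patch. */
    static void example_fast_zap_all(struct kvm *kvm)
    {
    	/* Short exclusive phase: mark every root invalid and pin it. */
    	write_lock(&kvm->mmu_lock);
    	kvm_tdp_mmu_invalidate_all_roots(kvm);
    	write_unlock(&kvm->mmu_lock);

    	/* Long shared phase: tear the pinned roots down, yielding as needed. */
    	read_lock(&kvm->mmu_lock);
    	kvm_tdp_mmu_zap_invalidated_roots(kvm);
    	read_unlock(&kvm->mmu_lock);
    }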
+
+/*
  * Installs a last-level SPTE to handle a TDP page fault.
  * (NPT/EPT violation/misconfiguration)
  */
@@ -777,12 +937,11 @@ static int tdp_mmu_map_handle_target_level(struct kvm_vcpu *vcpu, int write,
 		trace_mark_mmio_spte(rcu_dereference(iter->sptep), iter->gfn,
 				     new_spte);
 		ret = RET_PF_EMULATE;
-	} else
+	} else {
 		trace_kvm_mmu_set_spte(iter->level, iter->gfn,
 				       rcu_dereference(iter->sptep));
+	}
 
-	trace_kvm_mmu_set_spte(iter->level, iter->gfn,
-			       rcu_dereference(iter->sptep));
 	if (!prefault)
 		vcpu->stat.pf_fixed++;
 
@@ -882,199 +1041,139 @@ int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 	return ret;
 }
 
-static __always_inline int
-kvm_tdp_mmu_handle_hva_range(struct kvm *kvm,
-			     unsigned long start,
-			     unsigned long end,
-			     unsigned long data,
-			     int (*handler)(struct kvm *kvm,
-					    struct kvm_memory_slot *slot,
-					    struct kvm_mmu_page *root,
-					    gfn_t start,
-					    gfn_t end,
-					    unsigned long data))
+bool kvm_tdp_mmu_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range,
+				 bool flush)
 {
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
 	struct kvm_mmu_page *root;
-	int ret = 0;
-	int as_id;
-
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		as_id = kvm_mmu_page_as_id(root);
-		slots = __kvm_memslots(kvm, as_id);
-		kvm_for_each_memslot(memslot, slots) {
-			unsigned long hva_start, hva_end;
-			gfn_t gfn_start, gfn_end;
-
-			hva_start = max(start, memslot->userspace_addr);
-			hva_end = min(end, memslot->userspace_addr +
-				      (memslot->npages << PAGE_SHIFT));
-			if (hva_start >= hva_end)
-				continue;
-			/*
-			 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-			 * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-			 */
-			gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-			gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
 
-			ret |= handler(kvm, memslot, root, gfn_start,
-				       gfn_end, data);
-		}
-	}
+	for_each_tdp_mmu_root(kvm, root, range->slot->as_id)
+		flush |= zap_gfn_range(kvm, root, range->start, range->end,
+				       range->may_block, flush, false);
 
-	return ret;
+	return flush;
 }
 
-static int zap_gfn_range_hva_wrapper(struct kvm *kvm,
-				     struct kvm_memory_slot *slot,
-				     struct kvm_mmu_page *root, gfn_t start,
-				     gfn_t end, unsigned long unused)
-{
-	return zap_gfn_range(kvm, root, start, end, false, false);
-}
+typedef bool (*tdp_handler_t)(struct kvm *kvm, struct tdp_iter *iter,
+			      struct kvm_gfn_range *range);
 
-int kvm_tdp_mmu_zap_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end)
+static __always_inline bool kvm_tdp_mmu_handle_gfn(struct kvm *kvm,
+						   struct kvm_gfn_range *range,
+						   tdp_handler_t handler)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-					    zap_gfn_range_hva_wrapper);
+	struct kvm_mmu_page *root;
+	struct tdp_iter iter;
+	bool ret = false;
+
+	rcu_read_lock();
+
+	/*
+	 * Don't support rescheduling; none of the MMU notifiers that funnel
+	 * into this helper allow blocking; it'd be dead, wasteful code.
+	 */
+	for_each_tdp_mmu_root(kvm, root, range->slot->as_id) {
+		tdp_root_for_each_leaf_pte(iter, root, range->start, range->end)
+			ret |= handler(kvm, &iter, range);
+	}
+
+	rcu_read_unlock();
+
+	return ret;
 }
 
 /*
  * Mark the SPTEs range of GFNs [start, end) unaccessed and return non-zero
  * if any of the GFNs in the range have been accessed.
  */
-static int age_gfn_range(struct kvm *kvm, struct kvm_memory_slot *slot,
-			 struct kvm_mmu_page *root, gfn_t start, gfn_t end,
-			 unsigned long unused)
+static bool age_gfn_range(struct kvm *kvm, struct tdp_iter *iter,
+			  struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-	int young = 0;
 	u64 new_spte = 0;
 
-	rcu_read_lock();
+	/* If we have a non-accessed entry we don't need to change the pte. */
+	if (!is_accessed_spte(iter->old_spte))
+		return false;
 
-	tdp_root_for_each_leaf_pte(iter, root, start, end) {
+	new_spte = iter->old_spte;
+
+	if (spte_ad_enabled(new_spte)) {
+		new_spte &= ~shadow_accessed_mask;
+	} else {
 		/*
-		 * If we have a non-accessed entry we don't need to change the
-		 * pte.
+		 * Capture the dirty status of the page, so that it doesn't get
+		 * lost when the SPTE is marked for access tracking.
 		 */
-		if (!is_accessed_spte(iter.old_spte))
-			continue;
-
-		new_spte = iter.old_spte;
-
-		if (spte_ad_enabled(new_spte)) {
-			clear_bit((ffs(shadow_accessed_mask) - 1),
-				  (unsigned long *)&new_spte);
-		} else {
-			/*
-			 * Capture the dirty status of the page, so that it doesn't get
-			 * lost when the SPTE is marked for access tracking.
-			 */
-			if (is_writable_pte(new_spte))
-				kvm_set_pfn_dirty(spte_to_pfn(new_spte));
+		if (is_writable_pte(new_spte))
+			kvm_set_pfn_dirty(spte_to_pfn(new_spte));
 
-			new_spte = mark_spte_for_access_track(new_spte);
-		}
-		new_spte &= ~shadow_dirty_mask;
-
-		tdp_mmu_set_spte_no_acc_track(kvm, &iter, new_spte);
-		young = 1;
-
-		trace_kvm_age_page(iter.gfn, iter.level, slot, young);
+		new_spte = mark_spte_for_access_track(new_spte);
 	}
 
-	rcu_read_unlock();
+	tdp_mmu_set_spte_no_acc_track(kvm, iter, new_spte);
 
-	return young;
+	return true;
 }
 
-int kvm_tdp_mmu_age_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end)
+bool kvm_tdp_mmu_age_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-					    age_gfn_range);
+	return kvm_tdp_mmu_handle_gfn(kvm, range, age_gfn_range);
 }
 
-static int test_age_gfn(struct kvm *kvm, struct kvm_memory_slot *slot,
-			struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-			unsigned long unused2)
+static bool test_age_gfn(struct kvm *kvm, struct tdp_iter *iter,
+			 struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-
-	tdp_root_for_each_leaf_pte(iter, root, gfn, gfn + 1)
-		if (is_accessed_spte(iter.old_spte))
-			return 1;
-
-	return 0;
+	return is_accessed_spte(iter->old_spte);
 }
 
-int kvm_tdp_mmu_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_tdp_mmu_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, hva, hva + 1, 0,
-					    test_age_gfn);
+	return kvm_tdp_mmu_handle_gfn(kvm, range, test_age_gfn);
 }
 
-/*
- * Handle the changed_pte MMU notifier for the TDP MMU.
- * data is a pointer to the new pte_t mapping the HVA specified by the MMU
- * notifier.
- * Returns non-zero if a flush is needed before releasing the MMU lock.
- */
-static int set_tdp_spte(struct kvm *kvm, struct kvm_memory_slot *slot,
-			struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-			unsigned long data)
+static bool set_spte_gfn(struct kvm *kvm, struct tdp_iter *iter,
+			 struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-	pte_t *ptep = (pte_t *)data;
-	kvm_pfn_t new_pfn;
 	u64 new_spte;
-	int need_flush = 0;
-
-	rcu_read_lock();
 
-	WARN_ON(pte_huge(*ptep));
+	/* Huge pages aren't expected to be modified without first being zapped. */
+	WARN_ON(pte_huge(range->pte) || range->start + 1 != range->end);
 
-	new_pfn = pte_pfn(*ptep);
-
-	tdp_root_for_each_pte(iter, root, gfn, gfn + 1) {
-		if (iter.level != PG_LEVEL_4K)
-			continue;
-
-		if (!is_shadow_present_pte(iter.old_spte))
-			break;
-
-		tdp_mmu_set_spte(kvm, &iter, 0);
-
-		kvm_flush_remote_tlbs_with_address(kvm, iter.gfn, 1);
+	if (iter->level != PG_LEVEL_4K ||
+	    !is_shadow_present_pte(iter->old_spte))
+		return false;
 
-		if (!pte_write(*ptep)) {
-			new_spte = kvm_mmu_changed_pte_notifier_make_spte(
-					iter.old_spte, new_pfn);
+	/*
+	 * Note, when changing a read-only SPTE, it's not strictly necessary to
+	 * zero the SPTE before setting the new PFN, but doing so preserves the
+	 * invariant that the PFN of a present leaf SPTE can never change.
+	 * See __handle_changed_spte().
+	 */
+	tdp_mmu_set_spte(kvm, iter, 0);
 
-			tdp_mmu_set_spte(kvm, &iter, new_spte);
-		}
+	if (!pte_write(range->pte)) {
+		new_spte = kvm_mmu_changed_pte_notifier_make_spte(iter->old_spte,
+								  pte_pfn(range->pte));
 
-		need_flush = 1;
+		tdp_mmu_set_spte(kvm, iter, new_spte);
 	}
 
-	if (need_flush)
-		kvm_flush_remote_tlbs_with_address(kvm, gfn, 1);
-
-	rcu_read_unlock();
-
-	return 0;
+	return true;
 }
 
-int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
-			     pte_t *host_ptep)
+/*
+ * Handle the changed_pte MMU notifier for the TDP MMU.
+ * The new host PTE mapping the GFN is provided in range->pte.
+ * Any required TLB flush is performed here, so false is returned to the
+ * caller (see the FIXME in the function body).
+ */
+bool kvm_tdp_mmu_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, address, address + 1,
-					    (unsigned long)host_ptep,
-					    set_tdp_spte);
+	bool flush = kvm_tdp_mmu_handle_gfn(kvm, range, set_spte_gfn);
+
+	/* FIXME: return 'flush' instead of flushing here. */
+	if (flush)
+		kvm_flush_remote_tlbs_with_address(kvm, range->start, 1);
+
+	return false;
 }
 
 /*
@@ -1095,7 +1194,8 @@ static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
 	for_each_tdp_pte_min_level(iter, root->spt, root->role.level,
 				   min_level, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
 
 		if (!is_shadow_present_pte(iter.old_spte) ||
@@ -1105,7 +1205,15 @@ static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 
 		new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
 
-		tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+		if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+							  new_spte)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 		spte_set = true;
 	}
 
@@ -1122,17 +1230,13 @@ bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
 			     int min_level)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
 		spte_set |= wrprot_gfn_range(kvm, root, slot->base_gfn,
 			     slot->base_gfn + slot->npages, min_level);
-	}
 
 	return spte_set;
 }
@@ -1154,7 +1258,8 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 	rcu_read_lock();
 
 	tdp_root_for_each_leaf_pte(iter, root, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
 
 		if (spte_ad_need_write_protect(iter.old_spte)) {
@@ -1169,7 +1274,15 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 				continue;
 		}
 
-		tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+		if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+							  new_spte)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 		spte_set = true;
 	}
 
@@ -1187,17 +1300,13 @@ static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 bool kvm_tdp_mmu_clear_dirty_slot(struct kvm *kvm, struct kvm_memory_slot *slot)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
 		spte_set |= clear_dirty_gfn_range(kvm, root, slot->base_gfn,
 				slot->base_gfn + slot->npages);
-	}
 
 	return spte_set;
 }
@@ -1259,37 +1368,32 @@ void kvm_tdp_mmu_clear_dirty_pt_masked(struct kvm *kvm,
 				       bool wrprot)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	for_each_tdp_mmu_root(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
-
+	for_each_tdp_mmu_root(kvm, root, slot->as_id)
 		clear_dirty_pt_masked(kvm, root, gfn, mask, wrprot);
-	}
 }
 
 /*
  * Clear leaf entries which could be replaced by large mappings, for
  * GFNs within the slot.
  */
-static void zap_collapsible_spte_range(struct kvm *kvm,
+static bool zap_collapsible_spte_range(struct kvm *kvm,
 				       struct kvm_mmu_page *root,
-				       struct kvm_memory_slot *slot)
+				       const struct kvm_memory_slot *slot,
+				       bool flush)
 {
 	gfn_t start = slot->base_gfn;
 	gfn_t end = start + slot->npages;
 	struct tdp_iter iter;
 	kvm_pfn_t pfn;
-	bool spte_set = false;
 
 	rcu_read_lock();
 
 	tdp_root_for_each_pte(iter, root, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, spte_set)) {
-			spte_set = false;
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, flush, true)) {
+			flush = false;
 			continue;
 		}
 
@@ -1303,38 +1407,43 @@ static void zap_collapsible_spte_range(struct kvm *kvm,
 							    pfn, PG_LEVEL_NUM))
 			continue;
 
-		tdp_mmu_set_spte(kvm, &iter, 0);
-
-		spte_set = true;
+		if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
+		flush = true;
 	}
 
 	rcu_read_unlock();
-	if (spte_set)
-		kvm_flush_remote_tlbs(kvm);
+
+	return flush;
 }
 
 /*
  * Clear non-leaf entries (and free associated page tables) which could
  * be replaced by large mappings, for GFNs within the slot.
  */
-void kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
-				       struct kvm_memory_slot *slot)
+bool kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
+				       const struct kvm_memory_slot *slot,
+				       bool flush)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
-		zap_collapsible_spte_range(kvm, root, slot);
-	}
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
+		flush = zap_collapsible_spte_range(kvm, root, slot, flush);
+
+	return flush;
 }
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
@@ -1351,7 +1460,7 @@ static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
 			break;
 
 		new_spte = iter.old_spte &
-			~(PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE);
+			~(PT_WRITABLE_MASK | shadow_mmu_writable_mask);
 
 		tdp_mmu_set_spte(kvm, &iter, new_spte);
 		spte_set = true;
@@ -1364,24 +1473,19 @@ static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 bool kvm_tdp_mmu_write_protect_gfn(struct kvm *kvm,
 				   struct kvm_memory_slot *slot, gfn_t gfn)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	for_each_tdp_mmu_root(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
-
+	for_each_tdp_mmu_root(kvm, root, slot->as_id)
 		spte_set |= write_protect_gfn(kvm, root, gfn);
-	}
+
 	return spte_set;
 }
 
diff --git a/arch/x86/kvm/mmu/tdp_mmu.h b/arch/x86/kvm/mmu/tdp_mmu.h
index 31096ece9b14..5fdf63090451 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.h
+++ b/arch/x86/kvm/mmu/tdp_mmu.h
@@ -6,14 +6,28 @@
 #include <linux/kvm_host.h>
 
 hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu);
-void kvm_tdp_mmu_free_root(struct kvm *kvm, struct kvm_mmu_page *root);
 
-bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
-				 bool can_yield);
-static inline bool kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start,
-					     gfn_t end)
+__must_check static inline bool kvm_tdp_mmu_get_root(struct kvm *kvm,
+						     struct kvm_mmu_page *root)
 {
-	return __kvm_tdp_mmu_zap_gfn_range(kvm, start, end, true);
+	if (root->role.invalid)
+		return false;
+
+	return refcount_inc_not_zero(&root->tdp_mmu_root_count);
+}
+
+void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
+			  bool shared);
+
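
    kvm_tdp_mmu_get_root() only pins a root that is still valid and live, and
    every successful get must be balanced by a kvm_tdp_mmu_put_root(). A hedged
    usage sketch (the walk itself is a placeholder, not code from this patch):

    /* Illustrative get/put pairing for a walker running under the read lock. */
    static void example_walk_root(struct kvm *kvm, struct kvm_mmu_page *root)
    {
    	if (!kvm_tdp_mmu_get_root(kvm, root))
    		return;	/* invalid or already being torn down */

    	/* ... walk the root's paging structure under the MMU lock ... */

    	kvm_tdp_mmu_put_root(kvm, root, true /* shared */);
    }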
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
+				 gfn_t end, bool can_yield, bool flush,
+				 bool shared);
+static inline bool kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id,
+					     gfn_t start, gfn_t end, bool flush,
+					     bool shared)
+{
+	return __kvm_tdp_mmu_zap_gfn_range(kvm, as_id, start, end, true, flush,
+					   shared);
 }
 static inline bool kvm_tdp_mmu_zap_sp(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
@@ -29,23 +43,23 @@ static inline bool kvm_tdp_mmu_zap_sp(struct kvm *kvm, struct kvm_mmu_page *sp)
 	 * of the shadow page's gfn range and stop iterating before yielding.
 	 */
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	return __kvm_tdp_mmu_zap_gfn_range(kvm, sp->gfn, end, false);
+	return __kvm_tdp_mmu_zap_gfn_range(kvm, kvm_mmu_page_as_id(sp),
+					   sp->gfn, end, false, false, false);
 }
+
 void kvm_tdp_mmu_zap_all(struct kvm *kvm);
+void kvm_tdp_mmu_invalidate_all_roots(struct kvm *kvm);
+void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm);
 
 int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 		    int map_writable, int max_level, kvm_pfn_t pfn,
 		    bool prefault);
 
-int kvm_tdp_mmu_zap_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end);
-
-int kvm_tdp_mmu_age_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end);
-int kvm_tdp_mmu_test_age_hva(struct kvm *kvm, unsigned long hva);
-
-int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
-			     pte_t *host_ptep);
+bool kvm_tdp_mmu_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range,
+				 bool flush);
+bool kvm_tdp_mmu_age_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range);
+bool kvm_tdp_mmu_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range);
+bool kvm_tdp_mmu_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range);
 
 bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
 			     int min_level);
@@ -55,8 +69,9 @@ void kvm_tdp_mmu_clear_dirty_pt_masked(struct kvm *kvm,
 				       struct kvm_memory_slot *slot,
 				       gfn_t gfn, unsigned long mask,
 				       bool wrprot);
-void kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
-				       struct kvm_memory_slot *slot);
+bool kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
+				       const struct kvm_memory_slot *slot,
+				       bool flush);
 
 bool kvm_tdp_mmu_write_protect_gfn(struct kvm *kvm,
 				   struct kvm_memory_slot *slot, gfn_t gfn);
diff --git a/arch/x86/kvm/reverse_cpuid.h b/arch/x86/kvm/reverse_cpuid.h
new file mode 100644
index 000000000000..a19d473d0184
--- /dev/null
+++ b/arch/x86/kvm/reverse_cpuid.h
@@ -0,0 +1,186 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef ARCH_X86_KVM_REVERSE_CPUID_H
+#define ARCH_X86_KVM_REVERSE_CPUID_H
+
+#include <uapi/asm/kvm.h>
+#include <asm/cpufeature.h>
+#include <asm/cpufeatures.h>
+
+/*
+ * Hardware-defined CPUID leafs that are scattered in the kernel, but need to
+ * be directly used by KVM.  Note, these word values conflict with the kernel's
+ * "bug" caps, but KVM doesn't use those.
+ */
+enum kvm_only_cpuid_leafs {
+	CPUID_12_EAX	 = NCAPINTS,
+	NR_KVM_CPU_CAPS,
+
+	NKVMCAPINTS = NR_KVM_CPU_CAPS - NCAPINTS,
+};
+
+#define KVM_X86_FEATURE(w, f)		((w)*32 + (f))
+
+/* Intel-defined SGX sub-features, CPUID level 0x12 (EAX). */
+#define KVM_X86_FEATURE_SGX1		KVM_X86_FEATURE(CPUID_12_EAX, 0)
+#define KVM_X86_FEATURE_SGX2		KVM_X86_FEATURE(CPUID_12_EAX, 1)
+
+struct cpuid_reg {
+	u32 function;
+	u32 index;
+	int reg;
+};
+
+static const struct cpuid_reg reverse_cpuid[] = {
+	[CPUID_1_EDX]         = {         1, 0, CPUID_EDX},
+	[CPUID_8000_0001_EDX] = {0x80000001, 0, CPUID_EDX},
+	[CPUID_8086_0001_EDX] = {0x80860001, 0, CPUID_EDX},
+	[CPUID_1_ECX]         = {         1, 0, CPUID_ECX},
+	[CPUID_C000_0001_EDX] = {0xc0000001, 0, CPUID_EDX},
+	[CPUID_8000_0001_ECX] = {0x80000001, 0, CPUID_ECX},
+	[CPUID_7_0_EBX]       = {         7, 0, CPUID_EBX},
+	[CPUID_D_1_EAX]       = {       0xd, 1, CPUID_EAX},
+	[CPUID_8000_0008_EBX] = {0x80000008, 0, CPUID_EBX},
+	[CPUID_6_EAX]         = {         6, 0, CPUID_EAX},
+	[CPUID_8000_000A_EDX] = {0x8000000a, 0, CPUID_EDX},
+	[CPUID_7_ECX]         = {         7, 0, CPUID_ECX},
+	[CPUID_8000_0007_EBX] = {0x80000007, 0, CPUID_EBX},
+	[CPUID_7_EDX]         = {         7, 0, CPUID_EDX},
+	[CPUID_7_1_EAX]       = {         7, 1, CPUID_EAX},
+	[CPUID_12_EAX]        = {0x00000012, 0, CPUID_EAX},
+	[CPUID_8000_001F_EAX] = {0x8000001f, 0, CPUID_EAX},
+};
+
+/*
+ * Reverse CPUID and its derivatives can only be used for hardware-defined
+ * feature words, i.e. words whose bits directly correspond to a CPUID leaf.
+ * Retrieving a feature bit or masking guest CPUID from a Linux-defined word
+ * is nonsensical as the bit number/mask is an arbitrary software-defined value
+ * and can't be used by KVM to query/control guest capabilities.  And obviously
+ * the leaf being queried must have an entry in the lookup table.
+ */
+static __always_inline void reverse_cpuid_check(unsigned int x86_leaf)
+{
+	BUILD_BUG_ON(x86_leaf == CPUID_LNX_1);
+	BUILD_BUG_ON(x86_leaf == CPUID_LNX_2);
+	BUILD_BUG_ON(x86_leaf == CPUID_LNX_3);
+	BUILD_BUG_ON(x86_leaf == CPUID_LNX_4);
+	BUILD_BUG_ON(x86_leaf >= ARRAY_SIZE(reverse_cpuid));
+	BUILD_BUG_ON(reverse_cpuid[x86_leaf].function == 0);
+}
+
+/*
+ * Translate feature bits that are scattered in the kernel's cpufeatures word
+ * into KVM feature words that align with hardware's definitions.
+ */
+static __always_inline u32 __feature_translate(int x86_feature)
+{
+	if (x86_feature == X86_FEATURE_SGX1)
+		return KVM_X86_FEATURE_SGX1;
+	else if (x86_feature == X86_FEATURE_SGX2)
+		return KVM_X86_FEATURE_SGX2;
+
+	return x86_feature;
+}
+
+static __always_inline u32 __feature_leaf(int x86_feature)
+{
+	return __feature_translate(x86_feature) / 32;
+}
+
+/*
+ * Retrieve the bit mask from an X86_FEATURE_* definition.  Features contain
+ * the hardware defined bit number (stored in bits 4:0) and a software defined
+ * "word" (stored in bits 31:5).  The word is used to index into arrays of
+ * bit masks that hold the per-cpu feature capabilities, e.g. this_cpu_has().
+ */
+static __always_inline u32 __feature_bit(int x86_feature)
+{
+	x86_feature = __feature_translate(x86_feature);
+
+	reverse_cpuid_check(x86_feature / 32);
+	return 1 << (x86_feature & 31);
+}
+
+#define feature_bit(name)  __feature_bit(X86_FEATURE_##name)
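
    As the comment above notes, an X86_FEATURE_* value packs the word index
    into bits 31:5 and the bit number into bits 4:0. A tiny stand-alone example
    of that arithmetic (runnable user-space C; the feature value is made up):

    #include <stdio.h>

    #define EXAMPLE_FEATURE(w, f)	((w) * 32 + (f))

    int main(void)
    {
    	unsigned int feature = EXAMPLE_FEATURE(19, 0);

    	/* Word index lives in bits 31:5, the bit number in bits 4:0. */
    	printf("word=%u bit=%u mask=0x%x\n",
    	       feature / 32, feature & 31, 1u << (feature & 31));
    	return 0;
    }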
+
+static __always_inline struct cpuid_reg x86_feature_cpuid(unsigned int x86_feature)
+{
+	unsigned int x86_leaf = __feature_leaf(x86_feature);
+
+	reverse_cpuid_check(x86_leaf);
+	return reverse_cpuid[x86_leaf];
+}
+
+static __always_inline u32 *__cpuid_entry_get_reg(struct kvm_cpuid_entry2 *entry,
+						  u32 reg)
+{
+	switch (reg) {
+	case CPUID_EAX:
+		return &entry->eax;
+	case CPUID_EBX:
+		return &entry->ebx;
+	case CPUID_ECX:
+		return &entry->ecx;
+	case CPUID_EDX:
+		return &entry->edx;
+	default:
+		BUILD_BUG();
+		return NULL;
+	}
+}
+
+static __always_inline u32 *cpuid_entry_get_reg(struct kvm_cpuid_entry2 *entry,
+						unsigned int x86_feature)
+{
+	const struct cpuid_reg cpuid = x86_feature_cpuid(x86_feature);
+
+	return __cpuid_entry_get_reg(entry, cpuid.reg);
+}
+
+static __always_inline u32 cpuid_entry_get(struct kvm_cpuid_entry2 *entry,
+					   unsigned int x86_feature)
+{
+	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
+
+	return *reg & __feature_bit(x86_feature);
+}
+
+static __always_inline bool cpuid_entry_has(struct kvm_cpuid_entry2 *entry,
+					    unsigned int x86_feature)
+{
+	return cpuid_entry_get(entry, x86_feature);
+}
+
+static __always_inline void cpuid_entry_clear(struct kvm_cpuid_entry2 *entry,
+					      unsigned int x86_feature)
+{
+	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
+
+	*reg &= ~__feature_bit(x86_feature);
+}
+
+static __always_inline void cpuid_entry_set(struct kvm_cpuid_entry2 *entry,
+					    unsigned int x86_feature)
+{
+	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
+
+	*reg |= __feature_bit(x86_feature);
+}
+
+static __always_inline void cpuid_entry_change(struct kvm_cpuid_entry2 *entry,
+					       unsigned int x86_feature,
+					       bool set)
+{
+	u32 *reg = cpuid_entry_get_reg(entry, x86_feature);
+
+	/*
+	 * Open coded instead of using cpuid_entry_{clear,set}() to coerce the
+	 * compiler into using CMOV instead of Jcc when possible.
+	 */
+	if (set)
+		*reg |= __feature_bit(x86_feature);
+	else
+		*reg &= ~__feature_bit(x86_feature);
+}
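
    The open-coded if/else in cpuid_entry_change() is meant to let the compiler
    lower the update to a conditional move rather than a branch. The same
    pattern in isolation, independent of any KVM types (a sketch, not taken
    from this patch):

    #include <stdbool.h>
    #include <stdint.h>

    /* Conditionally set or clear a mask; compilers can emit CMOV for this. */
    static inline void change_bit32(uint32_t *reg, uint32_t mask, bool set)
    {
    	if (set)
    		*reg |= mask;
    	else
    		*reg &= ~mask;
    }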
+
+#endif /* ARCH_X86_KVM_REVERSE_CPUID_H */
diff --git a/arch/x86/kvm/svm/avic.c b/arch/x86/kvm/svm/avic.c
index 3e55674098be..712b4e0de481 100644
--- a/arch/x86/kvm/svm/avic.c
+++ b/arch/x86/kvm/svm/avic.c
@@ -270,7 +270,7 @@ static int avic_init_backing_page(struct kvm_vcpu *vcpu)
 	if (id >= AVIC_MAX_PHYSICAL_ID_COUNT)
 		return -EINVAL;
 
-	if (!svm->vcpu.arch.apic->regs)
+	if (!vcpu->arch.apic->regs)
 		return -EINVAL;
 
 	if (kvm_apicv_activated(vcpu->kvm)) {
@@ -281,7 +281,7 @@ static int avic_init_backing_page(struct kvm_vcpu *vcpu)
 			return ret;
 	}
 
-	svm->avic_backing_page = virt_to_page(svm->vcpu.arch.apic->regs);
+	svm->avic_backing_page = virt_to_page(vcpu->arch.apic->regs);
 
 	/* Setting AVIC backing page address in the phy APIC ID table */
 	entry = avic_get_physical_id_entry(vcpu, id);
@@ -315,15 +315,16 @@ static void avic_kick_target_vcpus(struct kvm *kvm, struct kvm_lapic *source,
 	}
 }
 
-int avic_incomplete_ipi_interception(struct vcpu_svm *svm)
+int avic_incomplete_ipi_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u32 icrh = svm->vmcb->control.exit_info_1 >> 32;
 	u32 icrl = svm->vmcb->control.exit_info_1;
 	u32 id = svm->vmcb->control.exit_info_2 >> 32;
 	u32 index = svm->vmcb->control.exit_info_2 & 0xFF;
-	struct kvm_lapic *apic = svm->vcpu.arch.apic;
+	struct kvm_lapic *apic = vcpu->arch.apic;
 
-	trace_kvm_avic_incomplete_ipi(svm->vcpu.vcpu_id, icrh, icrl, id, index);
+	trace_kvm_avic_incomplete_ipi(vcpu->vcpu_id, icrh, icrl, id, index);
 
 	switch (id) {
 	case AVIC_IPI_FAILURE_INVALID_INT_TYPE:
@@ -347,11 +348,11 @@ int avic_incomplete_ipi_interception(struct vcpu_svm *svm)
 		 * set the appropriate IRR bits on the valid target
 		 * vcpus. So, we just need to kick the appropriate vcpu.
 		 */
-		avic_kick_target_vcpus(svm->vcpu.kvm, apic, icrl, icrh);
+		avic_kick_target_vcpus(vcpu->kvm, apic, icrl, icrh);
 		break;
 	case AVIC_IPI_FAILURE_INVALID_TARGET:
 		WARN_ONCE(1, "Invalid IPI target: index=%u, vcpu=%d, icr=%#0x:%#0x\n",
-			  index, svm->vcpu.vcpu_id, icrh, icrl);
+			  index, vcpu->vcpu_id, icrh, icrl);
 		break;
 	case AVIC_IPI_FAILURE_INVALID_BACKING_PAGE:
 		WARN_ONCE(1, "Invalid backing page\n");
@@ -539,8 +540,9 @@ static bool is_avic_unaccelerated_access_trap(u32 offset)
 	return ret;
 }
 
-int avic_unaccelerated_access_interception(struct vcpu_svm *svm)
+int avic_unaccelerated_access_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int ret = 0;
 	u32 offset = svm->vmcb->control.exit_info_1 &
 		     AVIC_UNACCEL_ACCESS_OFFSET_MASK;
@@ -550,7 +552,7 @@ int avic_unaccelerated_access_interception(struct vcpu_svm *svm)
 		     AVIC_UNACCEL_ACCESS_WRITE_MASK;
 	bool trap = is_avic_unaccelerated_access_trap(offset);
 
-	trace_kvm_avic_unaccelerated_access(svm->vcpu.vcpu_id, offset,
+	trace_kvm_avic_unaccelerated_access(vcpu->vcpu_id, offset,
 					    trap, write, vector);
 	if (trap) {
 		/* Handling Trap */
@@ -558,7 +560,7 @@ int avic_unaccelerated_access_interception(struct vcpu_svm *svm)
 		ret = avic_unaccel_trap_write(svm);
 	} else {
 		/* Handling Fault */
-		ret = kvm_emulate_instruction(&svm->vcpu, 0);
+		ret = kvm_emulate_instruction(vcpu, 0);
 	}
 
 	return ret;
@@ -572,7 +574,7 @@ int avic_init_vcpu(struct vcpu_svm *svm)
 	if (!avic || !irqchip_in_kernel(vcpu->kvm))
 		return 0;
 
-	ret = avic_init_backing_page(&svm->vcpu);
+	ret = avic_init_backing_page(vcpu);
 	if (ret)
 		return ret;
 
diff --git a/arch/x86/kvm/svm/nested.c b/arch/x86/kvm/svm/nested.c
index fb204eaa8bb3..540d43ba2cf4 100644
--- a/arch/x86/kvm/svm/nested.c
+++ b/arch/x86/kvm/svm/nested.c
@@ -29,6 +29,8 @@
 #include "lapic.h"
 #include "svm.h"
 
+#define CC KVM_NESTED_VMENTER_CONSISTENCY_CHECK
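
    The checks below wrap each condition in CC() so that a failed nested VMRUN
    consistency check can be reported rather than failing silently. Only the
    alias is visible in this hunk; a plausible shape for such a macro (an
    assumption, the real definition lives elsewhere in KVM) is a statement
    expression that logs the failing condition and still yields its boolean
    value:

    /* Assumed shape only; the real macro is defined outside this diff. */
    #define EXAMPLE_CONSISTENCY_CHECK(check)				\
    ({									\
    	bool failed = (check);						\
    	if (failed)							\
    		pr_debug("nested VM-Enter check failed: %s\n", #check);	\
    	failed;								\
    })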
+
 static void nested_svm_inject_npf_exit(struct kvm_vcpu *vcpu,
 				       struct x86_exception *fault)
 {
@@ -92,12 +94,12 @@ static unsigned long nested_svm_get_tdp_cr3(struct kvm_vcpu *vcpu)
 static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
-	struct vmcb *hsave = svm->nested.hsave;
 
 	WARN_ON(mmu_is_nested(vcpu));
 
 	vcpu->arch.mmu = &vcpu->arch.guest_mmu;
-	kvm_init_shadow_npt_mmu(vcpu, X86_CR0_PG, hsave->save.cr4, hsave->save.efer,
+	kvm_init_shadow_npt_mmu(vcpu, X86_CR0_PG, svm->vmcb01.ptr->save.cr4,
+				svm->vmcb01.ptr->save.efer,
 				svm->nested.ctl.nested_cr3);
 	vcpu->arch.mmu->get_guest_pgd     = nested_svm_get_tdp_cr3;
 	vcpu->arch.mmu->get_pdptr         = nested_svm_get_tdp_pdptr;
@@ -123,7 +125,7 @@ void recalc_intercepts(struct vcpu_svm *svm)
 		return;
 
 	c = &svm->vmcb->control;
-	h = &svm->nested.hsave->control;
+	h = &svm->vmcb01.ptr->control;
 	g = &svm->nested.ctl;
 
 	for (i = 0; i < MAX_INTERCEPT; i++)
@@ -213,44 +215,64 @@ static bool nested_svm_vmrun_msrpm(struct vcpu_svm *svm)
 	return true;
 }
 
-static bool svm_get_nested_state_pages(struct kvm_vcpu *vcpu)
+/*
+ * Bits 11:0 of the bitmap address are ignored by hardware.
+ */
+static bool nested_svm_check_bitmap_pa(struct kvm_vcpu *vcpu, u64 pa, u32 size)
 {
-	struct vcpu_svm *svm = to_svm(vcpu);
+	u64 addr = PAGE_ALIGN(pa);
 
-	if (WARN_ON(!is_guest_mode(vcpu)))
-		return true;
-
-	if (!nested_svm_vmrun_msrpm(svm)) {
-		vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
-		vcpu->run->internal.suberror =
-			KVM_INTERNAL_ERROR_EMULATION;
-		vcpu->run->internal.ndata = 0;
-		return false;
-	}
-
-	return true;
+	return kvm_vcpu_is_legal_gpa(vcpu, addr) &&
+	    kvm_vcpu_is_legal_gpa(vcpu, addr + size - 1);
 }
 
-static bool nested_vmcb_check_controls(struct vmcb_control_area *control)
+static bool nested_vmcb_check_controls(struct kvm_vcpu *vcpu,
+				       struct vmcb_control_area *control)
 {
-	if ((vmcb_is_intercept(control, INTERCEPT_VMRUN)) == 0)
+	if (CC(!vmcb_is_intercept(control, INTERCEPT_VMRUN)))
 		return false;
 
-	if (control->asid == 0)
+	if (CC(control->asid == 0))
 		return false;
 
-	if ((control->nested_ctl & SVM_NESTED_CTL_NP_ENABLE) &&
-	    !npt_enabled)
+	if (CC((control->nested_ctl & SVM_NESTED_CTL_NP_ENABLE) && !npt_enabled))
+		return false;
+
+	if (CC(!nested_svm_check_bitmap_pa(vcpu, control->msrpm_base_pa,
+					   MSRPM_SIZE)))
+		return false;
+	if (CC(!nested_svm_check_bitmap_pa(vcpu, control->iopm_base_pa,
+					   IOPM_SIZE)))
 		return false;
 
 	return true;
 }
 
-static bool nested_vmcb_check_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
+static bool nested_vmcb_check_cr3_cr4(struct kvm_vcpu *vcpu,
+				      struct vmcb_save_area *save)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
-	bool vmcb12_lma;
+	/*
+	 * These checks are also performed by KVM_SET_SREGS,
+	 * except that EFER.LMA is not checked by SVM against
+	 * CR0.PG && EFER.LME.
+	 */
+	if ((save->efer & EFER_LME) && (save->cr0 & X86_CR0_PG)) {
+		if (CC(!(save->cr4 & X86_CR4_PAE)) ||
+		    CC(!(save->cr0 & X86_CR0_PE)) ||
+		    CC(kvm_vcpu_is_illegal_gpa(vcpu, save->cr3)))
+			return false;
+	}
+
+	if (CC(!kvm_is_valid_cr4(vcpu, save->cr4)))
+		return false;
+
+	return true;
+}
 
+/* Common checks that apply to both L1 and L2 state.  */
+static bool nested_vmcb_valid_sregs(struct kvm_vcpu *vcpu,
+				    struct vmcb_save_area *save)
+{
 	/*
 	 * FIXME: these should be done after copying the fields,
 	 * to avoid TOC/TOU races.  For these save area checks
@@ -258,31 +280,27 @@ static bool nested_vmcb_check_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
 	 * kvm_set_cr4 handle failure; EFER_SVME is an exception
 	 * so it is force-set later in nested_prepare_vmcb_save.
 	 */
-	if ((vmcb12->save.efer & EFER_SVME) == 0)
+	if (CC(!(save->efer & EFER_SVME)))
 		return false;
 
-	if (((vmcb12->save.cr0 & X86_CR0_CD) == 0) && (vmcb12->save.cr0 & X86_CR0_NW))
+	if (CC((save->cr0 & X86_CR0_CD) == 0 && (save->cr0 & X86_CR0_NW)) ||
+	    CC(save->cr0 & ~0xffffffffULL))
 		return false;
 
-	if (!kvm_dr6_valid(vmcb12->save.dr6) || !kvm_dr7_valid(vmcb12->save.dr7))
+	if (CC(!kvm_dr6_valid(save->dr6)) || CC(!kvm_dr7_valid(save->dr7)))
 		return false;
 
-	vmcb12_lma = (vmcb12->save.efer & EFER_LME) && (vmcb12->save.cr0 & X86_CR0_PG);
+	if (!nested_vmcb_check_cr3_cr4(vcpu, save))
+		return false;
 
-	if (vmcb12_lma) {
-		if (!(vmcb12->save.cr4 & X86_CR4_PAE) ||
-		    !(vmcb12->save.cr0 & X86_CR0_PE) ||
-		    kvm_vcpu_is_illegal_gpa(vcpu, vmcb12->save.cr3))
-			return false;
-	}
-	if (!kvm_is_valid_cr4(&svm->vcpu, vmcb12->save.cr4))
+	if (CC(!kvm_valid_efer(vcpu, save->efer)))
 		return false;
 
 	return true;
 }
 
-static void load_nested_vmcb_control(struct vcpu_svm *svm,
-				     struct vmcb_control_area *control)
+static void nested_load_control_from_vmcb12(struct vcpu_svm *svm,
+					    struct vmcb_control_area *control)
 {
 	copy_vmcb_control_area(&svm->nested.ctl, control);
 
@@ -294,9 +312,9 @@ static void load_nested_vmcb_control(struct vcpu_svm *svm,
 
 /*
  * Synchronize fields that are written by the processor, so that
- * they can be copied back into the nested_vmcb.
+ * they can be copied back into the vmcb12.
  */
-void sync_nested_vmcb_control(struct vcpu_svm *svm)
+void nested_sync_control_from_vmcb02(struct vcpu_svm *svm)
 {
 	u32 mask;
 	svm->nested.ctl.event_inj      = svm->vmcb->control.event_inj;
@@ -324,8 +342,8 @@ void sync_nested_vmcb_control(struct vcpu_svm *svm)
  * Transfer any event that L0 or L1 wanted to inject into L2 to
  * EXIT_INT_INFO.
  */
-static void nested_vmcb_save_pending_event(struct vcpu_svm *svm,
-					   struct vmcb *vmcb12)
+static void nested_save_pending_event_to_vmcb12(struct vcpu_svm *svm,
+						struct vmcb *vmcb12)
 {
 	struct kvm_vcpu *vcpu = &svm->vcpu;
 	u32 exit_int_info = 0;
@@ -369,12 +387,12 @@ static inline bool nested_npt_enabled(struct vcpu_svm *svm)
 static int nested_svm_load_cr3(struct kvm_vcpu *vcpu, unsigned long cr3,
 			       bool nested_npt)
 {
-	if (kvm_vcpu_is_illegal_gpa(vcpu, cr3))
+	if (CC(kvm_vcpu_is_illegal_gpa(vcpu, cr3)))
 		return -EINVAL;
 
 	if (!nested_npt && is_pae_paging(vcpu) &&
 	    (cr3 != kvm_read_cr3(vcpu) || pdptrs_changed(vcpu))) {
-		if (!load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3))
+		if (CC(!load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3)))
 			return -EINVAL;
 	}
 
@@ -393,15 +411,42 @@ static int nested_svm_load_cr3(struct kvm_vcpu *vcpu, unsigned long cr3,
 	return 0;
 }
 
-static void nested_prepare_vmcb_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
+void nested_vmcb02_compute_g_pat(struct vcpu_svm *svm)
 {
+	if (!svm->nested.vmcb02.ptr)
+		return;
+
+	/* FIXME: merge g_pat from vmcb01 and vmcb12.  */
+	svm->nested.vmcb02.ptr->save.g_pat = svm->vmcb01.ptr->save.g_pat;
+}
+
+static void nested_vmcb02_prepare_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
+{
+	bool new_vmcb12 = false;
+
+	nested_vmcb02_compute_g_pat(svm);
+
 	/* Load the nested guest state */
-	svm->vmcb->save.es = vmcb12->save.es;
-	svm->vmcb->save.cs = vmcb12->save.cs;
-	svm->vmcb->save.ss = vmcb12->save.ss;
-	svm->vmcb->save.ds = vmcb12->save.ds;
-	svm->vmcb->save.gdtr = vmcb12->save.gdtr;
-	svm->vmcb->save.idtr = vmcb12->save.idtr;
+	if (svm->nested.vmcb12_gpa != svm->nested.last_vmcb12_gpa) {
+		new_vmcb12 = true;
+		svm->nested.last_vmcb12_gpa = svm->nested.vmcb12_gpa;
+	}
+
+	if (unlikely(new_vmcb12 || vmcb_is_dirty(vmcb12, VMCB_SEG))) {
+		svm->vmcb->save.es = vmcb12->save.es;
+		svm->vmcb->save.cs = vmcb12->save.cs;
+		svm->vmcb->save.ss = vmcb12->save.ss;
+		svm->vmcb->save.ds = vmcb12->save.ds;
+		svm->vmcb->save.cpl = vmcb12->save.cpl;
+		vmcb_mark_dirty(svm->vmcb, VMCB_SEG);
+	}
+
+	if (unlikely(new_vmcb12 || vmcb_is_dirty(vmcb12, VMCB_DT))) {
+		svm->vmcb->save.gdtr = vmcb12->save.gdtr;
+		svm->vmcb->save.idtr = vmcb12->save.idtr;
+		vmcb_mark_dirty(svm->vmcb, VMCB_DT);
+	}
+
 	kvm_set_rflags(&svm->vcpu, vmcb12->save.rflags | X86_EFLAGS_FIXED);
 
 	/*
@@ -413,7 +458,9 @@ static void nested_prepare_vmcb_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
 
 	svm_set_cr0(&svm->vcpu, vmcb12->save.cr0);
 	svm_set_cr4(&svm->vcpu, vmcb12->save.cr4);
-	svm->vmcb->save.cr2 = svm->vcpu.arch.cr2 = vmcb12->save.cr2;
+
+	svm->vcpu.arch.cr2 = vmcb12->save.cr2;
+
 	kvm_rax_write(&svm->vcpu, vmcb12->save.rax);
 	kvm_rsp_write(&svm->vcpu, vmcb12->save.rsp);
 	kvm_rip_write(&svm->vcpu, vmcb12->save.rip);
@@ -422,15 +469,41 @@ static void nested_prepare_vmcb_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
 	svm->vmcb->save.rax = vmcb12->save.rax;
 	svm->vmcb->save.rsp = vmcb12->save.rsp;
 	svm->vmcb->save.rip = vmcb12->save.rip;
-	svm->vmcb->save.dr7 = vmcb12->save.dr7 | DR7_FIXED_1;
-	svm->vcpu.arch.dr6  = vmcb12->save.dr6 | DR6_ACTIVE_LOW;
-	svm->vmcb->save.cpl = vmcb12->save.cpl;
+
+	/* These bits will be set properly on the first execution when new_vmcb12 is true */
+	if (unlikely(new_vmcb12 || vmcb_is_dirty(vmcb12, VMCB_DR))) {
+		svm->vmcb->save.dr7 = vmcb12->save.dr7 | DR7_FIXED_1;
+		svm->vcpu.arch.dr6  = vmcb12->save.dr6 | DR6_ACTIVE_LOW;
+		vmcb_mark_dirty(svm->vmcb, VMCB_DR);
+	}
 }
 
-static void nested_prepare_vmcb_control(struct vcpu_svm *svm)
+static void nested_vmcb02_prepare_control(struct vcpu_svm *svm)
 {
 	const u32 mask = V_INTR_MASKING_MASK | V_GIF_ENABLE_MASK | V_GIF_MASK;
 
+	/*
+	 * Filled at exit: exit_code, exit_code_hi, exit_info_1, exit_info_2,
+	 * exit_int_info, exit_int_info_err, next_rip, insn_len, insn_bytes.
+	 */
+
+	/*
+	 * Also covers avic_vapic_bar, avic_backing_page, avic_logical_id,
+	 * avic_physical_id.
+	 */
+	WARN_ON(svm->vmcb01.ptr->control.int_ctl & AVIC_ENABLE_MASK);
+
+	/* Copied from vmcb01.  msrpm_base can be overwritten later.  */
+	svm->vmcb->control.nested_ctl = svm->vmcb01.ptr->control.nested_ctl;
+	svm->vmcb->control.iopm_base_pa = svm->vmcb01.ptr->control.iopm_base_pa;
+	svm->vmcb->control.msrpm_base_pa = svm->vmcb01.ptr->control.msrpm_base_pa;
+
+	/* Done at vmrun: asid.  */
+
+	/* Also overwritten later if necessary.  */
+	svm->vmcb->control.tlb_ctl = TLB_CONTROL_DO_NOTHING;
+
+	/* nested_cr3.  */
 	if (nested_npt_enabled(svm))
 		nested_svm_init_mmu_context(&svm->vcpu);
 
@@ -439,7 +512,7 @@ static void nested_prepare_vmcb_control(struct vcpu_svm *svm)
 
 	svm->vmcb->control.int_ctl             =
 		(svm->nested.ctl.int_ctl & ~mask) |
-		(svm->nested.hsave->control.int_ctl & mask);
+		(svm->vmcb01.ptr->control.int_ctl & mask);
 
 	svm->vmcb->control.virt_ext            = svm->nested.ctl.virt_ext;
 	svm->vmcb->control.int_vector          = svm->nested.ctl.int_vector;
@@ -454,17 +527,28 @@ static void nested_prepare_vmcb_control(struct vcpu_svm *svm)
 	enter_guest_mode(&svm->vcpu);
 
 	/*
-	 * Merge guest and host intercepts - must be called  with vcpu in
-	 * guest-mode to take affect here
+	 * Merge guest and host intercepts - must be called with vcpu in
+	 * guest-mode to take effect.
 	 */
 	recalc_intercepts(svm);
+}
 
-	vmcb_mark_all_dirty(svm->vmcb);
+static void nested_svm_copy_common_state(struct vmcb *from_vmcb, struct vmcb *to_vmcb)
+{
+	/*
+	 * Some VMCB state is shared between L1 and L2 and thus has to be
+	 * moved at the time of nested vmrun and vmexit.
+	 *
+	 * VMLOAD/VMSAVE state would also belong in this category, but KVM
+	 * always performs VMLOAD and VMSAVE from the VMCB01.
+	 */
+	to_vmcb->save.spec_ctrl = from_vmcb->save.spec_ctrl;
 }
 
-int enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb12_gpa,
+int enter_svm_guest_mode(struct kvm_vcpu *vcpu, u64 vmcb12_gpa,
 			 struct vmcb *vmcb12)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int ret;
 
 	trace_kvm_nested_vmrun(svm->vmcb->save.rip, vmcb12_gpa,
@@ -482,8 +566,14 @@ int enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb12_gpa,
 
 
 	svm->nested.vmcb12_gpa = vmcb12_gpa;
-	nested_prepare_vmcb_control(svm);
-	nested_prepare_vmcb_save(svm, vmcb12);
+
+	WARN_ON(svm->vmcb == svm->nested.vmcb02.ptr);
+
+	nested_svm_copy_common_state(svm->vmcb01.ptr, svm->nested.vmcb02.ptr);
+
+	svm_switch_vmcb(svm, &svm->nested.vmcb02);
+	nested_vmcb02_prepare_control(svm);
+	nested_vmcb02_prepare_save(svm, vmcb12);
 
 	ret = nested_svm_load_cr3(&svm->vcpu, vmcb12->save.cr3,
 				  nested_npt_enabled(svm));
@@ -491,47 +581,48 @@ int enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb12_gpa,
 		return ret;
 
 	if (!npt_enabled)
-		svm->vcpu.arch.mmu->inject_page_fault = svm_inject_page_fault_nested;
+		vcpu->arch.mmu->inject_page_fault = svm_inject_page_fault_nested;
 
 	svm_set_gif(svm, true);
 
 	return 0;
 }
 
-int nested_svm_vmrun(struct vcpu_svm *svm)
+int nested_svm_vmrun(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int ret;
 	struct vmcb *vmcb12;
-	struct vmcb *hsave = svm->nested.hsave;
-	struct vmcb *vmcb = svm->vmcb;
 	struct kvm_host_map map;
 	u64 vmcb12_gpa;
 
-	if (is_smm(&svm->vcpu)) {
-		kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+	++vcpu->stat.nested_run;
+
+	if (is_smm(vcpu)) {
+		kvm_queue_exception(vcpu, UD_VECTOR);
 		return 1;
 	}
 
 	vmcb12_gpa = svm->vmcb->save.rax;
-	ret = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(vmcb12_gpa), &map);
+	ret = kvm_vcpu_map(vcpu, gpa_to_gfn(vmcb12_gpa), &map);
 	if (ret == -EINVAL) {
-		kvm_inject_gp(&svm->vcpu, 0);
+		kvm_inject_gp(vcpu, 0);
 		return 1;
 	} else if (ret) {
-		return kvm_skip_emulated_instruction(&svm->vcpu);
+		return kvm_skip_emulated_instruction(vcpu);
 	}
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
+	ret = kvm_skip_emulated_instruction(vcpu);
 
 	vmcb12 = map.hva;
 
 	if (WARN_ON_ONCE(!svm->nested.initialized))
 		return -EINVAL;
 
-	load_nested_vmcb_control(svm, &vmcb12->control);
+	nested_load_control_from_vmcb12(svm, &vmcb12->control);
 
-	if (!nested_vmcb_check_save(svm, vmcb12) ||
-	    !nested_vmcb_check_controls(&svm->nested.ctl)) {
+	if (!nested_vmcb_valid_sregs(vcpu, &vmcb12->save) ||
+	    !nested_vmcb_check_controls(vcpu, &svm->nested.ctl)) {
 		vmcb12->control.exit_code    = SVM_EXIT_ERR;
 		vmcb12->control.exit_code_hi = 0;
 		vmcb12->control.exit_info_1  = 0;
@@ -541,36 +632,25 @@ int nested_svm_vmrun(struct vcpu_svm *svm)
 
 
 	/* Clear internal status */
-	kvm_clear_exception_queue(&svm->vcpu);
-	kvm_clear_interrupt_queue(&svm->vcpu);
+	kvm_clear_exception_queue(vcpu);
+	kvm_clear_interrupt_queue(vcpu);
 
 	/*
-	 * Save the old vmcb, so we don't need to pick what we save, but can
-	 * restore everything when a VMEXIT occurs
+	 * Since vmcb01 is not in use, we can use it to store some of the L1
+	 * state.
 	 */
-	hsave->save.es     = vmcb->save.es;
-	hsave->save.cs     = vmcb->save.cs;
-	hsave->save.ss     = vmcb->save.ss;
-	hsave->save.ds     = vmcb->save.ds;
-	hsave->save.gdtr   = vmcb->save.gdtr;
-	hsave->save.idtr   = vmcb->save.idtr;
-	hsave->save.efer   = svm->vcpu.arch.efer;
-	hsave->save.cr0    = kvm_read_cr0(&svm->vcpu);
-	hsave->save.cr4    = svm->vcpu.arch.cr4;
-	hsave->save.rflags = kvm_get_rflags(&svm->vcpu);
-	hsave->save.rip    = kvm_rip_read(&svm->vcpu);
-	hsave->save.rsp    = vmcb->save.rsp;
-	hsave->save.rax    = vmcb->save.rax;
-	if (npt_enabled)
-		hsave->save.cr3    = vmcb->save.cr3;
-	else
-		hsave->save.cr3    = kvm_read_cr3(&svm->vcpu);
-
-	copy_vmcb_control_area(&hsave->control, &vmcb->control);
+	svm->vmcb01.ptr->save.efer   = vcpu->arch.efer;
+	svm->vmcb01.ptr->save.cr0    = kvm_read_cr0(vcpu);
+	svm->vmcb01.ptr->save.cr4    = vcpu->arch.cr4;
+	svm->vmcb01.ptr->save.rflags = kvm_get_rflags(vcpu);
+	svm->vmcb01.ptr->save.rip    = kvm_rip_read(vcpu);
+
+	if (!npt_enabled)
+		svm->vmcb01.ptr->save.cr3 = kvm_read_cr3(vcpu);
 
 	svm->nested.nested_run_pending = 1;
 
-	if (enter_svm_guest_mode(svm, vmcb12_gpa, vmcb12))
+	if (enter_svm_guest_mode(vcpu, vmcb12_gpa, vmcb12))
 		goto out_exit_err;
 
 	if (nested_svm_vmrun_msrpm(svm))
@@ -587,7 +667,7 @@ out_exit_err:
 	nested_svm_vmexit(svm);
 
 out:
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
+	kvm_vcpu_unmap(vcpu, &map, true);
 
 	return ret;
 }
@@ -610,27 +690,30 @@ void nested_svm_vmloadsave(struct vmcb *from_vmcb, struct vmcb *to_vmcb)
 
 int nested_svm_vmexit(struct vcpu_svm *svm)
 {
-	int rc;
+	struct kvm_vcpu *vcpu = &svm->vcpu;
 	struct vmcb *vmcb12;
-	struct vmcb *hsave = svm->nested.hsave;
 	struct vmcb *vmcb = svm->vmcb;
 	struct kvm_host_map map;
+	int rc;
 
-	rc = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(svm->nested.vmcb12_gpa), &map);
+	/* Triple faults in L2 should never escape. */
+	WARN_ON_ONCE(kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu));
+
+	rc = kvm_vcpu_map(vcpu, gpa_to_gfn(svm->nested.vmcb12_gpa), &map);
 	if (rc) {
 		if (rc == -EINVAL)
-			kvm_inject_gp(&svm->vcpu, 0);
+			kvm_inject_gp(vcpu, 0);
 		return 1;
 	}
 
 	vmcb12 = map.hva;
 
 	/* Exit Guest-Mode */
-	leave_guest_mode(&svm->vcpu);
+	leave_guest_mode(vcpu);
 	svm->nested.vmcb12_gpa = 0;
 	WARN_ON_ONCE(svm->nested.nested_run_pending);
 
-	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, &svm->vcpu);
+	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 
 	/* in case we halted in L2 */
 	svm->vcpu.arch.mp_state = KVM_MP_STATE_RUNNABLE;
@@ -644,14 +727,14 @@ int nested_svm_vmexit(struct vcpu_svm *svm)
 	vmcb12->save.gdtr   = vmcb->save.gdtr;
 	vmcb12->save.idtr   = vmcb->save.idtr;
 	vmcb12->save.efer   = svm->vcpu.arch.efer;
-	vmcb12->save.cr0    = kvm_read_cr0(&svm->vcpu);
-	vmcb12->save.cr3    = kvm_read_cr3(&svm->vcpu);
+	vmcb12->save.cr0    = kvm_read_cr0(vcpu);
+	vmcb12->save.cr3    = kvm_read_cr3(vcpu);
 	vmcb12->save.cr2    = vmcb->save.cr2;
 	vmcb12->save.cr4    = svm->vcpu.arch.cr4;
-	vmcb12->save.rflags = kvm_get_rflags(&svm->vcpu);
-	vmcb12->save.rip    = kvm_rip_read(&svm->vcpu);
-	vmcb12->save.rsp    = kvm_rsp_read(&svm->vcpu);
-	vmcb12->save.rax    = kvm_rax_read(&svm->vcpu);
+	vmcb12->save.rflags = kvm_get_rflags(vcpu);
+	vmcb12->save.rip    = kvm_rip_read(vcpu);
+	vmcb12->save.rsp    = kvm_rsp_read(vcpu);
+	vmcb12->save.rax    = kvm_rax_read(vcpu);
 	vmcb12->save.dr7    = vmcb->save.dr7;
 	vmcb12->save.dr6    = svm->vcpu.arch.dr6;
 	vmcb12->save.cpl    = vmcb->save.cpl;
@@ -663,7 +746,7 @@ int nested_svm_vmexit(struct vcpu_svm *svm)
 	vmcb12->control.exit_info_2       = vmcb->control.exit_info_2;
 
 	if (vmcb12->control.exit_code != SVM_EXIT_ERR)
-		nested_vmcb_save_pending_event(svm, vmcb12);
+		nested_save_pending_event_to_vmcb12(svm, vmcb12);
 
 	if (svm->nrips_enabled)
 		vmcb12->control.next_rip  = vmcb->control.next_rip;
@@ -678,37 +761,39 @@ int nested_svm_vmexit(struct vcpu_svm *svm)
 	vmcb12->control.pause_filter_thresh =
 		svm->vmcb->control.pause_filter_thresh;
 
-	/* Restore the original control entries */
-	copy_vmcb_control_area(&vmcb->control, &hsave->control);
+	nested_svm_copy_common_state(svm->nested.vmcb02.ptr, svm->vmcb01.ptr);
+
+	svm_switch_vmcb(svm, &svm->vmcb01);
+	WARN_ON_ONCE(svm->vmcb->control.exit_code != SVM_EXIT_VMRUN);
 
-	/* On vmexit the  GIF is set to false */
+	/*
+	 * On vmexit the GIF is set to false and
+	 * no event can be injected in L1.
+	 */
 	svm_set_gif(svm, false);
+	svm->vmcb->control.exit_int_info = 0;
 
-	svm->vmcb->control.tsc_offset = svm->vcpu.arch.tsc_offset =
-		svm->vcpu.arch.l1_tsc_offset;
+	svm->vcpu.arch.tsc_offset = svm->vcpu.arch.l1_tsc_offset;
+	if (svm->vmcb->control.tsc_offset != svm->vcpu.arch.tsc_offset) {
+		svm->vmcb->control.tsc_offset = svm->vcpu.arch.tsc_offset;
+		vmcb_mark_dirty(svm->vmcb, VMCB_INTERCEPTS);
+	}
 
 	svm->nested.ctl.nested_cr3 = 0;
 
-	/* Restore selected save entries */
-	svm->vmcb->save.es = hsave->save.es;
-	svm->vmcb->save.cs = hsave->save.cs;
-	svm->vmcb->save.ss = hsave->save.ss;
-	svm->vmcb->save.ds = hsave->save.ds;
-	svm->vmcb->save.gdtr = hsave->save.gdtr;
-	svm->vmcb->save.idtr = hsave->save.idtr;
-	kvm_set_rflags(&svm->vcpu, hsave->save.rflags);
-	kvm_set_rflags(&svm->vcpu, hsave->save.rflags | X86_EFLAGS_FIXED);
-	svm_set_efer(&svm->vcpu, hsave->save.efer);
-	svm_set_cr0(&svm->vcpu, hsave->save.cr0 | X86_CR0_PE);
-	svm_set_cr4(&svm->vcpu, hsave->save.cr4);
-	kvm_rax_write(&svm->vcpu, hsave->save.rax);
-	kvm_rsp_write(&svm->vcpu, hsave->save.rsp);
-	kvm_rip_write(&svm->vcpu, hsave->save.rip);
-	svm->vmcb->save.dr7 = DR7_FIXED_1;
-	svm->vmcb->save.cpl = 0;
-	svm->vmcb->control.exit_int_info = 0;
+	/*
+	 * Restore processor state that had been saved in vmcb01
+	 */
+	kvm_set_rflags(vcpu, svm->vmcb->save.rflags);
+	svm_set_efer(vcpu, svm->vmcb->save.efer);
+	svm_set_cr0(vcpu, svm->vmcb->save.cr0 | X86_CR0_PE);
+	svm_set_cr4(vcpu, svm->vmcb->save.cr4);
+	kvm_rax_write(vcpu, svm->vmcb->save.rax);
+	kvm_rsp_write(vcpu, svm->vmcb->save.rsp);
+	kvm_rip_write(vcpu, svm->vmcb->save.rip);
 
-	vmcb_mark_all_dirty(svm->vmcb);
+	svm->vcpu.arch.dr7 = DR7_FIXED_1;
+	kvm_update_dr7(&svm->vcpu);
 
 	trace_kvm_nested_vmexit_inject(vmcb12->control.exit_code,
 				       vmcb12->control.exit_info_1,
@@ -717,50 +802,62 @@ int nested_svm_vmexit(struct vcpu_svm *svm)
 				       vmcb12->control.exit_int_info_err,
 				       KVM_ISA_SVM);
 
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
+	kvm_vcpu_unmap(vcpu, &map, true);
 
-	nested_svm_uninit_mmu_context(&svm->vcpu);
+	nested_svm_uninit_mmu_context(vcpu);
 
-	rc = nested_svm_load_cr3(&svm->vcpu, hsave->save.cr3, false);
+	rc = nested_svm_load_cr3(vcpu, svm->vmcb->save.cr3, false);
 	if (rc)
 		return 1;
 
-	if (npt_enabled)
-		svm->vmcb->save.cr3 = hsave->save.cr3;
-
 	/*
 	 * Drop what we picked up for L2 via svm_complete_interrupts() so it
 	 * doesn't end up in L1.
 	 */
 	svm->vcpu.arch.nmi_injected = false;
-	kvm_clear_exception_queue(&svm->vcpu);
-	kvm_clear_interrupt_queue(&svm->vcpu);
+	kvm_clear_exception_queue(vcpu);
+	kvm_clear_interrupt_queue(vcpu);
+
+	/*
+	 * If we are here following the completion of a VMRUN that
+	 * is being single-stepped, queue the pending #DB intercept
+	 * right now so that it can be accounted for before we execute
+	 * L1's next instruction.
+	 */
+	if (unlikely(svm->vmcb->save.rflags & X86_EFLAGS_TF))
+		kvm_queue_exception(&(svm->vcpu), DB_VECTOR);
 
 	return 0;
 }
 
+static void nested_svm_triple_fault(struct kvm_vcpu *vcpu)
+{
+	nested_svm_simple_vmexit(to_svm(vcpu), SVM_EXIT_SHUTDOWN);
+}
+
 int svm_allocate_nested(struct vcpu_svm *svm)
 {
-	struct page *hsave_page;
+	struct page *vmcb02_page;
 
 	if (svm->nested.initialized)
 		return 0;
 
-	hsave_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
-	if (!hsave_page)
+	vmcb02_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+	if (!vmcb02_page)
 		return -ENOMEM;
-	svm->nested.hsave = page_address(hsave_page);
+	svm->nested.vmcb02.ptr = page_address(vmcb02_page);
+	svm->nested.vmcb02.pa = __sme_set(page_to_pfn(vmcb02_page) << PAGE_SHIFT);
 
 	svm->nested.msrpm = svm_vcpu_alloc_msrpm();
 	if (!svm->nested.msrpm)
-		goto err_free_hsave;
+		goto err_free_vmcb02;
 	svm_vcpu_init_msrpm(&svm->vcpu, svm->nested.msrpm);
 
 	svm->nested.initialized = true;
 	return 0;
 
-err_free_hsave:
-	__free_page(hsave_page);
+err_free_vmcb02:
+	__free_page(vmcb02_page);
 	return -ENOMEM;
 }
 
@@ -772,8 +869,8 @@ void svm_free_nested(struct vcpu_svm *svm)
 	svm_vcpu_free_msrpm(svm->nested.msrpm);
 	svm->nested.msrpm = NULL;
 
-	__free_page(virt_to_page(svm->nested.hsave));
-	svm->nested.hsave = NULL;
+	__free_page(virt_to_page(svm->nested.vmcb02.ptr));
+	svm->nested.vmcb02.ptr = NULL;
 
 	svm->nested.initialized = false;
 }
@@ -783,18 +880,19 @@ void svm_free_nested(struct vcpu_svm *svm)
  */
 void svm_leave_nested(struct vcpu_svm *svm)
 {
-	if (is_guest_mode(&svm->vcpu)) {
-		struct vmcb *hsave = svm->nested.hsave;
-		struct vmcb *vmcb = svm->vmcb;
+	struct kvm_vcpu *vcpu = &svm->vcpu;
 
+	if (is_guest_mode(vcpu)) {
 		svm->nested.nested_run_pending = 0;
-		leave_guest_mode(&svm->vcpu);
-		copy_vmcb_control_area(&vmcb->control, &hsave->control);
-		nested_svm_uninit_mmu_context(&svm->vcpu);
+		leave_guest_mode(vcpu);
+
+		svm_switch_vmcb(svm, &svm->nested.vmcb02);
+
+		nested_svm_uninit_mmu_context(vcpu);
 		vmcb_mark_all_dirty(svm->vmcb);
 	}
 
-	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, &svm->vcpu);
+	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 }
 
 static int nested_svm_exit_handled_msr(struct vcpu_svm *svm)
@@ -903,16 +1001,15 @@ int nested_svm_exit_handled(struct vcpu_svm *svm)
 	return vmexit;
 }
 
-int nested_svm_check_permissions(struct vcpu_svm *svm)
+int nested_svm_check_permissions(struct kvm_vcpu *vcpu)
 {
-	if (!(svm->vcpu.arch.efer & EFER_SVME) ||
-	    !is_paging(&svm->vcpu)) {
-		kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+	if (!(vcpu->arch.efer & EFER_SVME) || !is_paging(vcpu)) {
+		kvm_queue_exception(vcpu, UD_VECTOR);
 		return 1;
 	}
 
-	if (svm->vmcb->save.cpl) {
-		kvm_inject_gp(&svm->vcpu, 0);
+	if (to_svm(vcpu)->vmcb->save.cpl) {
+		kvm_inject_gp(vcpu, 0);
 		return 1;
 	}
 
@@ -960,50 +1057,11 @@ static void nested_svm_inject_exception_vmexit(struct vcpu_svm *svm)
 	nested_svm_vmexit(svm);
 }
 
-static void nested_svm_smi(struct vcpu_svm *svm)
-{
-	svm->vmcb->control.exit_code = SVM_EXIT_SMI;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
-static void nested_svm_nmi(struct vcpu_svm *svm)
-{
-	svm->vmcb->control.exit_code = SVM_EXIT_NMI;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
-static void nested_svm_intr(struct vcpu_svm *svm)
-{
-	trace_kvm_nested_intr_vmexit(svm->vmcb->save.rip);
-
-	svm->vmcb->control.exit_code   = SVM_EXIT_INTR;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
 static inline bool nested_exit_on_init(struct vcpu_svm *svm)
 {
 	return vmcb_is_intercept(&svm->nested.ctl, INTERCEPT_INIT);
 }
 
-static void nested_svm_init(struct vcpu_svm *svm)
-{
-	svm->vmcb->control.exit_code   = SVM_EXIT_INIT;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
-
 static int svm_check_nested_events(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
@@ -1017,12 +1075,18 @@ static int svm_check_nested_events(struct kvm_vcpu *vcpu)
 			return -EBUSY;
 		if (!nested_exit_on_init(svm))
 			return 0;
-		nested_svm_init(svm);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_INIT);
 		return 0;
 	}
 
 	if (vcpu->arch.exception.pending) {
-		if (block_nested_events)
+		/*
+		 * Only a pending nested run can block a pending exception.
+		 * Otherwise an injected NMI/interrupt should either be
+		 * lost or delivered to the nested hypervisor in the EXITINTINFO
+		 * vmcb field, while delivering the pending exception.
+		 */
+		if (svm->nested.nested_run_pending)
                         return -EBUSY;
 		if (!nested_exit_on_exception(svm))
 			return 0;
@@ -1035,7 +1099,7 @@ static int svm_check_nested_events(struct kvm_vcpu *vcpu)
 			return -EBUSY;
 		if (!nested_exit_on_smi(svm))
 			return 0;
-		nested_svm_smi(svm);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_SMI);
 		return 0;
 	}
 
@@ -1044,7 +1108,7 @@ static int svm_check_nested_events(struct kvm_vcpu *vcpu)
 			return -EBUSY;
 		if (!nested_exit_on_nmi(svm))
 			return 0;
-		nested_svm_nmi(svm);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_NMI);
 		return 0;
 	}
 
@@ -1053,7 +1117,8 @@ static int svm_check_nested_events(struct kvm_vcpu *vcpu)
 			return -EBUSY;
 		if (!nested_exit_on_intr(svm))
 			return 0;
-		nested_svm_intr(svm);
+		trace_kvm_nested_intr_vmexit(svm->vmcb->save.rip);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_INTR);
 		return 0;
 	}
 
@@ -1072,8 +1137,8 @@ int nested_svm_exit_special(struct vcpu_svm *svm)
 	case SVM_EXIT_EXCP_BASE ... SVM_EXIT_EXCP_BASE + 0x1f: {
 		u32 excp_bits = 1 << (exit_code - SVM_EXIT_EXCP_BASE);
 
-		if (get_host_vmcb(svm)->control.intercepts[INTERCEPT_EXCEPTION] &
-				excp_bits)
+		if (svm->vmcb01.ptr->control.intercepts[INTERCEPT_EXCEPTION] &
+		    excp_bits)
 			return NESTED_EXIT_HOST;
 		else if (exit_code == SVM_EXIT_EXCP_BASE + PF_VECTOR &&
 			 svm->vcpu.arch.apf.host_apf_flags)
@@ -1137,10 +1202,9 @@ static int svm_get_nested_state(struct kvm_vcpu *vcpu,
 	if (copy_to_user(&user_vmcb->control, &svm->nested.ctl,
 			 sizeof(user_vmcb->control)))
 		return -EFAULT;
-	if (copy_to_user(&user_vmcb->save, &svm->nested.hsave->save,
+	if (copy_to_user(&user_vmcb->save, &svm->vmcb01.ptr->save,
 			 sizeof(user_vmcb->save)))
 		return -EFAULT;
-
 out:
 	return kvm_state.size;
 }
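
For reference, the state assembled above is what userspace retrieves through KVM_GET_NESTED_STATE. A minimal userspace sketch, assuming the standard UAPI definitions (struct kvm_nested_state, KVM_STATE_NESTED_SVM_VMCB_SIZE) and an illustrative vcpu_fd; error reporting is trimmed:

#include <stdlib.h>
#include <sys/ioctl.h>
#include <linux/kvm.h>

/* vcpu_fd: an already-created KVM vCPU fd (illustrative). */
static struct kvm_nested_state *get_svm_nested_state(int vcpu_fd)
{
	size_t size = sizeof(struct kvm_nested_state) +
		      KVM_STATE_NESTED_SVM_VMCB_SIZE;
	struct kvm_nested_state *state = calloc(1, size);

	if (!state)
		return NULL;

	state->size = size;
	if (ioctl(vcpu_fd, KVM_GET_NESTED_STATE, state) < 0) {
		free(state);
		return NULL;
	}

	/*
	 * data.svm[0] now holds the cached vmcb12 control area plus the L1
	 * save area, which svm_get_nested_state() copies from vmcb01 above.
	 */
	return state;
}
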
@@ -1150,7 +1214,6 @@ static int svm_set_nested_state(struct kvm_vcpu *vcpu,
 				struct kvm_nested_state *kvm_state)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
-	struct vmcb *hsave = svm->nested.hsave;
 	struct vmcb __user *user_vmcb = (struct vmcb __user *)
 		&user_kvm_nested_state->data.svm[0];
 	struct vmcb_control_area *ctl;
@@ -1195,8 +1258,8 @@ static int svm_set_nested_state(struct kvm_vcpu *vcpu,
 		return -EINVAL;
 
 	ret  = -ENOMEM;
-	ctl  = kzalloc(sizeof(*ctl),  GFP_KERNEL);
-	save = kzalloc(sizeof(*save), GFP_KERNEL);
+	ctl  = kzalloc(sizeof(*ctl),  GFP_KERNEL_ACCOUNT);
+	save = kzalloc(sizeof(*save), GFP_KERNEL_ACCOUNT);
 	if (!ctl || !save)
 		goto out_free;
 
@@ -1207,12 +1270,12 @@ static int svm_set_nested_state(struct kvm_vcpu *vcpu,
 		goto out_free;
 
 	ret = -EINVAL;
-	if (!nested_vmcb_check_controls(ctl))
+	if (!nested_vmcb_check_controls(vcpu, ctl))
 		goto out_free;
 
 	/*
 	 * Processor state contains L2 state.  Check that it is
-	 * valid for guest mode (see nested_vmcb_checks).
+	 * valid for guest mode (see nested_vmcb_check_save).
 	 */
 	cr0 = kvm_read_cr0(vcpu);
         if (((cr0 & X86_CR0_CD) == 0) && (cr0 & X86_CR0_NW))
@@ -1221,29 +1284,48 @@ static int svm_set_nested_state(struct kvm_vcpu *vcpu,
 	/*
 	 * Validate host state saved from before VMRUN (see
 	 * nested_svm_check_permissions).
-	 * TODO: validate reserved bits for all saved state.
 	 */
-	if (!(save->cr0 & X86_CR0_PG))
-		goto out_free;
-	if (!(save->efer & EFER_SVME))
+	if (!(save->cr0 & X86_CR0_PG) ||
+	    !(save->cr0 & X86_CR0_PE) ||
+	    (save->rflags & X86_EFLAGS_VM) ||
+	    !nested_vmcb_valid_sregs(vcpu, save))
 		goto out_free;
 
 	/*
-	 * All checks done, we can enter guest mode.  L1 control fields
-	 * come from the nested save state.  Guest state is already
-	 * in the registers, the save area of the nested state instead
-	 * contains saved L1 state.
+	 * All checks done, we can enter guest mode. Userspace provides
+	 * vmcb12.control, which will be combined with L1 and stored into
+	 * vmcb02, and the L1 save state, which we store in vmcb01.
+	 * If needed, L2 registers are moved from the current VMCB to VMCB02.
 	 */
 
 	svm->nested.nested_run_pending =
 		!!(kvm_state->flags & KVM_STATE_NESTED_RUN_PENDING);
 
-	copy_vmcb_control_area(&hsave->control, &svm->vmcb->control);
-	hsave->save = *save;
-
 	svm->nested.vmcb12_gpa = kvm_state->hdr.svm.vmcb_pa;
-	load_nested_vmcb_control(svm, ctl);
-	nested_prepare_vmcb_control(svm);
+	if (svm->current_vmcb == &svm->vmcb01)
+		svm->nested.vmcb02.ptr->save = svm->vmcb01.ptr->save;
+
+	svm->vmcb01.ptr->save.es = save->es;
+	svm->vmcb01.ptr->save.cs = save->cs;
+	svm->vmcb01.ptr->save.ss = save->ss;
+	svm->vmcb01.ptr->save.ds = save->ds;
+	svm->vmcb01.ptr->save.gdtr = save->gdtr;
+	svm->vmcb01.ptr->save.idtr = save->idtr;
+	svm->vmcb01.ptr->save.rflags = save->rflags | X86_EFLAGS_FIXED;
+	svm->vmcb01.ptr->save.efer = save->efer;
+	svm->vmcb01.ptr->save.cr0 = save->cr0;
+	svm->vmcb01.ptr->save.cr3 = save->cr3;
+	svm->vmcb01.ptr->save.cr4 = save->cr4;
+	svm->vmcb01.ptr->save.rax = save->rax;
+	svm->vmcb01.ptr->save.rsp = save->rsp;
+	svm->vmcb01.ptr->save.rip = save->rip;
+	svm->vmcb01.ptr->save.cpl = 0;
+
+	nested_load_control_from_vmcb12(svm, ctl);
+
+	svm_switch_vmcb(svm, &svm->nested.vmcb02);
+
+	nested_vmcb02_prepare_control(svm);
 
 	kvm_make_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 	ret = 0;
@@ -1254,8 +1336,31 @@ out_free:
 	return ret;
 }
 
+static bool svm_get_nested_state_pages(struct kvm_vcpu *vcpu)
+{
+	struct vcpu_svm *svm = to_svm(vcpu);
+
+	if (WARN_ON(!is_guest_mode(vcpu)))
+		return true;
+
+	if (nested_svm_load_cr3(&svm->vcpu, vcpu->arch.cr3,
+				nested_npt_enabled(svm)))
+		return false;
+
+	if (!nested_svm_vmrun_msrpm(svm)) {
+		vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
+		vcpu->run->internal.suberror =
+			KVM_INTERNAL_ERROR_EMULATION;
+		vcpu->run->internal.ndata = 0;
+		return false;
+	}
+
+	return true;
+}
+
 struct kvm_x86_nested_ops svm_nested_ops = {
 	.check_events = svm_check_nested_events,
+	.triple_fault = nested_svm_triple_fault,
 	.get_nested_state_pages = svm_get_nested_state_pages,
 	.get_state = svm_get_nested_state,
 	.set_state = svm_set_nested_state,
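
The hsave page is gone after this conversion; instead each vCPU carries two VMCBs and flips between them. Below is a minimal, self-contained sketch of that bookkeeping. The field names mirror the vmcb01/vmcb02/current_vmcb usage visible in this diff, but the real svm_switch_vmcb() lives in svm.c and also handles the physical address used by VMRUN and the dirty bits, so treat this as illustration only.

struct vmcb;				/* hardware VMCB page, opaque here */

struct vmcb_info {
	struct vmcb *ptr;		/* VA of the VMCB page */
	unsigned long pa;		/* PA handed to VMRUN */
};

struct svm_vcpu_sketch {
	struct vmcb *vmcb;		/* VMCB currently being edited/run */
	struct vmcb_info *current_vmcb;
	struct vmcb_info vmcb01;	/* L1 guest state */
	struct vmcb_info vmcb02;	/* merged L1/L2 state for nested runs */
};

static void switch_vmcb(struct svm_vcpu_sketch *svm, struct vmcb_info *target)
{
	/*
	 * Every later svm->vmcb access now refers to the chosen VMCB;
	 * nested VMRUN switches to vmcb02, #VMEXIT switches back to vmcb01.
	 */
	svm->current_vmcb = target;
	svm->vmcb = target->ptr;
}
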
diff --git a/arch/x86/kvm/svm/sev.c b/arch/x86/kvm/svm/sev.c
index 415a49b8b8f8..1356ee095cd5 100644
--- a/arch/x86/kvm/svm/sev.c
+++ b/arch/x86/kvm/svm/sev.c
@@ -44,12 +44,25 @@
 #define MISC_CG_RES_SEV_ES MISC_CG_RES_TYPES
 #endif
 
+#ifdef CONFIG_KVM_AMD_SEV
+/* enable/disable SEV support */
+static bool sev_enabled = true;
+module_param_named(sev, sev_enabled, bool, 0444);
+
+/* enable/disable SEV-ES support */
+static bool sev_es_enabled = true;
+module_param_named(sev_es, sev_es_enabled, bool, 0444);
+#else
+#define sev_enabled false
+#define sev_es_enabled false
+#endif /* CONFIG_KVM_AMD_SEV */
+
 static u8 sev_enc_bit;
-static int sev_flush_asids(void);
 static DECLARE_RWSEM(sev_deactivate_lock);
 static DEFINE_MUTEX(sev_bitmap_lock);
 unsigned int max_sev_asid;
 static unsigned int min_sev_asid;
+static unsigned long sev_me_mask;
 static unsigned long *sev_asid_bitmap;
 static unsigned long *sev_reclaim_asid_bitmap;
 
@@ -61,9 +74,15 @@ struct enc_region {
 	unsigned long size;
 };
 
-static int sev_flush_asids(void)
+/* Called with the sev_bitmap_lock held, or on shutdown */
+static int sev_flush_asids(int min_asid, int max_asid)
 {
-	int ret, error = 0;
+	int ret, pos, error = 0;
+
+	/* Check if there are any ASIDs to reclaim before performing a flush */
+	pos = find_next_bit(sev_reclaim_asid_bitmap, max_asid, min_asid);
+	if (pos >= max_asid)
+		return -EBUSY;
 
 	/*
 	 * DEACTIVATE will clear the WBINVD indicator causing DF_FLUSH to fail,
@@ -82,17 +101,15 @@ static int sev_flush_asids(void)
 	return ret;
 }
 
+static inline bool is_mirroring_enc_context(struct kvm *kvm)
+{
+	return !!to_kvm_svm(kvm)->sev_info.enc_context_owner;
+}
+
 /* Must be called with the sev_bitmap_lock held */
 static bool __sev_recycle_asids(int min_asid, int max_asid)
 {
-	int pos;
-
-	/* Check if there are any ASIDs to reclaim before performing a flush */
-	pos = find_next_bit(sev_reclaim_asid_bitmap, max_sev_asid, min_asid);
-	if (pos >= max_asid)
-		return false;
-
-	if (sev_flush_asids())
+	if (sev_flush_asids(min_asid, max_asid))
 		return false;
 
 	/* The flush process will flush all reclaimable SEV and SEV-ES ASIDs */
@@ -184,49 +201,41 @@ static void sev_asid_free(struct kvm_sev_info *sev)
 
 static void sev_unbind_asid(struct kvm *kvm, unsigned int handle)
 {
-	struct sev_data_decommission *decommission;
-	struct sev_data_deactivate *data;
+	struct sev_data_decommission decommission;
+	struct sev_data_deactivate deactivate;
 
 	if (!handle)
 		return;
 
-	data = kzalloc(sizeof(*data), GFP_KERNEL);
-	if (!data)
-		return;
-
-	/* deactivate handle */
-	data->handle = handle;
+	deactivate.handle = handle;
 
 	/* Guard DEACTIVATE against WBINVD/DF_FLUSH used in ASID recycling */
 	down_read(&sev_deactivate_lock);
-	sev_guest_deactivate(data, NULL);
+	sev_guest_deactivate(&deactivate, NULL);
 	up_read(&sev_deactivate_lock);
 
-	kfree(data);
-
-	decommission = kzalloc(sizeof(*decommission), GFP_KERNEL);
-	if (!decommission)
-		return;
-
 	/* decommission handle */
-	decommission->handle = handle;
-	sev_guest_decommission(decommission, NULL);
-
-	kfree(decommission);
+	decommission.handle = handle;
+	sev_guest_decommission(&decommission, NULL);
 }
 
 static int sev_guest_init(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	bool es_active = argp->id == KVM_SEV_ES_INIT;
 	int asid, ret;
 
+	if (kvm->created_vcpus)
+		return -EINVAL;
+
 	ret = -EBUSY;
 	if (unlikely(sev->active))
 		return ret;
 
+	sev->es_active = es_active;
 	asid = sev_asid_new(sev);
 	if (asid < 0)
-		return ret;
+		goto e_no_asid;
 	sev->asid = asid;
 
 	ret = sev_platform_init(&argp->error);
@@ -234,6 +243,7 @@ static int sev_guest_init(struct kvm *kvm, struct kvm_sev_cmd *argp)
 		goto e_free;
 
 	sev->active = true;
+	sev->asid = asid;
 	INIT_LIST_HEAD(&sev->regions_list);
 
 	return 0;
@@ -241,34 +251,21 @@ static int sev_guest_init(struct kvm *kvm, struct kvm_sev_cmd *argp)
 e_free:
 	sev_asid_free(sev);
 	sev->asid = 0;
+e_no_asid:
+	sev->es_active = false;
 	return ret;
 }
 
-static int sev_es_guest_init(struct kvm *kvm, struct kvm_sev_cmd *argp)
-{
-	if (!sev_es)
-		return -ENOTTY;
-
-	to_kvm_svm(kvm)->sev_info.es_active = true;
-
-	return sev_guest_init(kvm, argp);
-}
-
 static int sev_bind_asid(struct kvm *kvm, unsigned int handle, int *error)
 {
-	struct sev_data_activate *data;
+	struct sev_data_activate activate;
 	int asid = sev_get_asid(kvm);
 	int ret;
 
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		return -ENOMEM;
-
 	/* activate ASID on the given handle */
-	data->handle = handle;
-	data->asid   = asid;
-	ret = sev_guest_activate(data, error);
-	kfree(data);
+	activate.handle = handle;
+	activate.asid   = asid;
+	ret = sev_guest_activate(&activate, error);
 
 	return ret;
 }
@@ -298,7 +295,7 @@ static int sev_issue_cmd(struct kvm *kvm, int id, void *data, int *error)
 static int sev_launch_start(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
-	struct sev_data_launch_start *start;
+	struct sev_data_launch_start start;
 	struct kvm_sev_launch_start params;
 	void *dh_blob, *session_blob;
 	int *error = &argp->error;
@@ -310,20 +307,16 @@ static int sev_launch_start(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data, sizeof(params)))
 		return -EFAULT;
 
-	start = kzalloc(sizeof(*start), GFP_KERNEL_ACCOUNT);
-	if (!start)
-		return -ENOMEM;
+	memset(&start, 0, sizeof(start));
 
 	dh_blob = NULL;
 	if (params.dh_uaddr) {
 		dh_blob = psp_copy_user_blob(params.dh_uaddr, params.dh_len);
-		if (IS_ERR(dh_blob)) {
-			ret = PTR_ERR(dh_blob);
-			goto e_free;
-		}
+		if (IS_ERR(dh_blob))
+			return PTR_ERR(dh_blob);
 
-		start->dh_cert_address = __sme_set(__pa(dh_blob));
-		start->dh_cert_len = params.dh_len;
+		start.dh_cert_address = __sme_set(__pa(dh_blob));
+		start.dh_cert_len = params.dh_len;
 	}
 
 	session_blob = NULL;
@@ -334,40 +327,38 @@ static int sev_launch_start(struct kvm *kvm, struct kvm_sev_cmd *argp)
 			goto e_free_dh;
 		}
 
-		start->session_address = __sme_set(__pa(session_blob));
-		start->session_len = params.session_len;
+		start.session_address = __sme_set(__pa(session_blob));
+		start.session_len = params.session_len;
 	}
 
-	start->handle = params.handle;
-	start->policy = params.policy;
+	start.handle = params.handle;
+	start.policy = params.policy;
 
 	/* create memory encryption context */
-	ret = __sev_issue_cmd(argp->sev_fd, SEV_CMD_LAUNCH_START, start, error);
+	ret = __sev_issue_cmd(argp->sev_fd, SEV_CMD_LAUNCH_START, &start, error);
 	if (ret)
 		goto e_free_session;
 
 	/* Bind ASID to this guest */
-	ret = sev_bind_asid(kvm, start->handle, error);
+	ret = sev_bind_asid(kvm, start.handle, error);
 	if (ret)
 		goto e_free_session;
 
 	/* return handle to userspace */
-	params.handle = start->handle;
+	params.handle = start.handle;
 	if (copy_to_user((void __user *)(uintptr_t)argp->data, &params, sizeof(params))) {
-		sev_unbind_asid(kvm, start->handle);
+		sev_unbind_asid(kvm, start.handle);
 		ret = -EFAULT;
 		goto e_free_session;
 	}
 
-	sev->handle = start->handle;
+	sev->handle = start.handle;
 	sev->fd = argp->sev_fd;
 
 e_free_session:
 	kfree(session_blob);
 e_free_dh:
 	kfree(dh_blob);
-e_free:
-	kfree(start);
 	return ret;
 }
 
@@ -486,7 +477,7 @@ static int sev_launch_update_data(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	unsigned long vaddr, vaddr_end, next_vaddr, npages, pages, size, i;
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
 	struct kvm_sev_launch_update_data params;
-	struct sev_data_launch_update_data *data;
+	struct sev_data_launch_update_data data;
 	struct page **inpages;
 	int ret;
 
@@ -496,20 +487,14 @@ static int sev_launch_update_data(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data, sizeof(params)))
 		return -EFAULT;
 
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		return -ENOMEM;
-
 	vaddr = params.uaddr;
 	size = params.len;
 	vaddr_end = vaddr + size;
 
 	/* Lock the user memory. */
 	inpages = sev_pin_memory(kvm, vaddr, size, &npages, 1);
-	if (IS_ERR(inpages)) {
-		ret = PTR_ERR(inpages);
-		goto e_free;
-	}
+	if (IS_ERR(inpages))
+		return PTR_ERR(inpages);
 
 	/*
 	 * Flush (on non-coherent CPUs) before LAUNCH_UPDATE encrypts pages in
@@ -517,6 +502,9 @@ static int sev_launch_update_data(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	 */
 	sev_clflush_pages(inpages, npages);
 
+	data.reserved = 0;
+	data.handle = sev->handle;
+
 	for (i = 0; vaddr < vaddr_end; vaddr = next_vaddr, i += pages) {
 		int offset, len;
 
@@ -531,10 +519,9 @@ static int sev_launch_update_data(struct kvm *kvm, struct kvm_sev_cmd *argp)
 
 		len = min_t(size_t, ((pages * PAGE_SIZE) - offset), size);
 
-		data->handle = sev->handle;
-		data->len = len;
-		data->address = __sme_page_pa(inpages[i]) + offset;
-		ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_DATA, data, &argp->error);
+		data.len = len;
+		data.address = __sme_page_pa(inpages[i]) + offset;
+		ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_DATA, &data, &argp->error);
 		if (ret)
 			goto e_unpin;
 
@@ -550,8 +537,6 @@ e_unpin:
 	}
 	/* unlock the user pages */
 	sev_unpin_memory(kvm, inpages, npages);
-e_free:
-	kfree(data);
 	return ret;
 }
 
@@ -603,23 +588,22 @@ static int sev_es_sync_vmsa(struct vcpu_svm *svm)
 static int sev_launch_update_vmsa(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
-	struct sev_data_launch_update_vmsa *vmsa;
+	struct sev_data_launch_update_vmsa vmsa;
+	struct kvm_vcpu *vcpu;
 	int i, ret;
 
 	if (!sev_es_guest(kvm))
 		return -ENOTTY;
 
-	vmsa = kzalloc(sizeof(*vmsa), GFP_KERNEL);
-	if (!vmsa)
-		return -ENOMEM;
+	vmsa.reserved = 0;
 
-	for (i = 0; i < kvm->created_vcpus; i++) {
-		struct vcpu_svm *svm = to_svm(kvm->vcpus[i]);
+	kvm_for_each_vcpu(i, vcpu, kvm) {
+		struct vcpu_svm *svm = to_svm(vcpu);
 
 		/* Perform some pre-encryption checks against the VMSA */
 		ret = sev_es_sync_vmsa(svm);
 		if (ret)
-			goto e_free;
+			return ret;
 
 		/*
 		 * The LAUNCH_UPDATE_VMSA command will perform in-place
@@ -629,27 +613,25 @@ static int sev_launch_update_vmsa(struct kvm *kvm, struct kvm_sev_cmd *argp)
 		 */
 		clflush_cache_range(svm->vmsa, PAGE_SIZE);
 
-		vmsa->handle = sev->handle;
-		vmsa->address = __sme_pa(svm->vmsa);
-		vmsa->len = PAGE_SIZE;
-		ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_VMSA, vmsa,
+		vmsa.handle = sev->handle;
+		vmsa.address = __sme_pa(svm->vmsa);
+		vmsa.len = PAGE_SIZE;
+		ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_VMSA, &vmsa,
 				    &argp->error);
 		if (ret)
-			goto e_free;
+			return ret;
 
 		svm->vcpu.arch.guest_state_protected = true;
 	}
 
-e_free:
-	kfree(vmsa);
-	return ret;
+	return 0;
 }
 
 static int sev_launch_measure(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	void __user *measure = (void __user *)(uintptr_t)argp->data;
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
-	struct sev_data_launch_measure *data;
+	struct sev_data_launch_measure data;
 	struct kvm_sev_launch_measure params;
 	void __user *p = NULL;
 	void *blob = NULL;
@@ -661,9 +643,7 @@ static int sev_launch_measure(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	if (copy_from_user(&params, measure, sizeof(params)))
 		return -EFAULT;
 
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		return -ENOMEM;
+	memset(&data, 0, sizeof(data));
 
 	/* User wants to query the blob length */
 	if (!params.len)
@@ -671,23 +651,20 @@ static int sev_launch_measure(struct kvm *kvm, struct kvm_sev_cmd *argp)
 
 	p = (void __user *)(uintptr_t)params.uaddr;
 	if (p) {
-		if (params.len > SEV_FW_BLOB_MAX_SIZE) {
-			ret = -EINVAL;
-			goto e_free;
-		}
+		if (params.len > SEV_FW_BLOB_MAX_SIZE)
+			return -EINVAL;
 
-		ret = -ENOMEM;
-		blob = kmalloc(params.len, GFP_KERNEL);
+		blob = kmalloc(params.len, GFP_KERNEL_ACCOUNT);
 		if (!blob)
-			goto e_free;
+			return -ENOMEM;
 
-		data->address = __psp_pa(blob);
-		data->len = params.len;
+		data.address = __psp_pa(blob);
+		data.len = params.len;
 	}
 
 cmd:
-	data->handle = sev->handle;
-	ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_MEASURE, data, &argp->error);
+	data.handle = sev->handle;
+	ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_MEASURE, &data, &argp->error);
 
 	/*
 	 * If we query the session length, FW responded with expected data.
@@ -704,63 +681,50 @@ cmd:
 	}
 
 done:
-	params.len = data->len;
+	params.len = data.len;
 	if (copy_to_user(measure, &params, sizeof(params)))
 		ret = -EFAULT;
 e_free_blob:
 	kfree(blob);
-e_free:
-	kfree(data);
 	return ret;
 }
 
 static int sev_launch_finish(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
-	struct sev_data_launch_finish *data;
-	int ret;
+	struct sev_data_launch_finish data;
 
 	if (!sev_guest(kvm))
 		return -ENOTTY;
 
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		return -ENOMEM;
-
-	data->handle = sev->handle;
-	ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_FINISH, data, &argp->error);
-
-	kfree(data);
-	return ret;
+	data.handle = sev->handle;
+	return sev_issue_cmd(kvm, SEV_CMD_LAUNCH_FINISH, &data, &argp->error);
 }
 
 static int sev_guest_status(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
 	struct kvm_sev_guest_status params;
-	struct sev_data_guest_status *data;
+	struct sev_data_guest_status data;
 	int ret;
 
 	if (!sev_guest(kvm))
 		return -ENOTTY;
 
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		return -ENOMEM;
+	memset(&data, 0, sizeof(data));
 
-	data->handle = sev->handle;
-	ret = sev_issue_cmd(kvm, SEV_CMD_GUEST_STATUS, data, &argp->error);
+	data.handle = sev->handle;
+	ret = sev_issue_cmd(kvm, SEV_CMD_GUEST_STATUS, &data, &argp->error);
 	if (ret)
-		goto e_free;
+		return ret;
 
-	params.policy = data->policy;
-	params.state = data->state;
-	params.handle = data->handle;
+	params.policy = data.policy;
+	params.state = data.state;
+	params.handle = data.handle;
 
 	if (copy_to_user((void __user *)(uintptr_t)argp->data, &params, sizeof(params)))
 		ret = -EFAULT;
-e_free:
-	kfree(data);
+
 	return ret;
 }
 
@@ -769,23 +733,17 @@ static int __sev_issue_dbg_cmd(struct kvm *kvm, unsigned long src,
 			       int *error, bool enc)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
-	struct sev_data_dbg *data;
-	int ret;
-
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		return -ENOMEM;
+	struct sev_data_dbg data;
 
-	data->handle = sev->handle;
-	data->dst_addr = dst;
-	data->src_addr = src;
-	data->len = size;
+	data.reserved = 0;
+	data.handle = sev->handle;
+	data.dst_addr = dst;
+	data.src_addr = src;
+	data.len = size;
 
-	ret = sev_issue_cmd(kvm,
-			    enc ? SEV_CMD_DBG_ENCRYPT : SEV_CMD_DBG_DECRYPT,
-			    data, error);
-	kfree(data);
-	return ret;
+	return sev_issue_cmd(kvm,
+			     enc ? SEV_CMD_DBG_ENCRYPT : SEV_CMD_DBG_DECRYPT,
+			     &data, error);
 }
 
 static int __sev_dbg_decrypt(struct kvm *kvm, unsigned long src_paddr,
@@ -1005,7 +963,7 @@ err:
 static int sev_launch_secret(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
-	struct sev_data_launch_secret *data;
+	struct sev_data_launch_secret data;
 	struct kvm_sev_launch_secret params;
 	struct page **pages;
 	void *blob, *hdr;
@@ -1037,41 +995,36 @@ static int sev_launch_secret(struct kvm *kvm, struct kvm_sev_cmd *argp)
 		goto e_unpin_memory;
 	}
 
-	ret = -ENOMEM;
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		goto e_unpin_memory;
+	memset(&data, 0, sizeof(data));
 
 	offset = params.guest_uaddr & (PAGE_SIZE - 1);
-	data->guest_address = __sme_page_pa(pages[0]) + offset;
-	data->guest_len = params.guest_len;
+	data.guest_address = __sme_page_pa(pages[0]) + offset;
+	data.guest_len = params.guest_len;
 
 	blob = psp_copy_user_blob(params.trans_uaddr, params.trans_len);
 	if (IS_ERR(blob)) {
 		ret = PTR_ERR(blob);
-		goto e_free;
+		goto e_unpin_memory;
 	}
 
-	data->trans_address = __psp_pa(blob);
-	data->trans_len = params.trans_len;
+	data.trans_address = __psp_pa(blob);
+	data.trans_len = params.trans_len;
 
 	hdr = psp_copy_user_blob(params.hdr_uaddr, params.hdr_len);
 	if (IS_ERR(hdr)) {
 		ret = PTR_ERR(hdr);
 		goto e_free_blob;
 	}
-	data->hdr_address = __psp_pa(hdr);
-	data->hdr_len = params.hdr_len;
+	data.hdr_address = __psp_pa(hdr);
+	data.hdr_len = params.hdr_len;
 
-	data->handle = sev->handle;
-	ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_SECRET, data, &argp->error);
+	data.handle = sev->handle;
+	ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_SECRET, &data, &argp->error);
 
 	kfree(hdr);
 
 e_free_blob:
 	kfree(blob);
-e_free:
-	kfree(data);
 e_unpin_memory:
 	/* content of memory is updated, mark pages dirty */
 	for (i = 0; i < n; i++) {
@@ -1086,7 +1039,7 @@ static int sev_get_attestation_report(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	void __user *report = (void __user *)(uintptr_t)argp->data;
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
-	struct sev_data_attestation_report *data;
+	struct sev_data_attestation_report data;
 	struct kvm_sev_attestation_report params;
 	void __user *p;
 	void *blob = NULL;
@@ -1098,9 +1051,7 @@ static int sev_get_attestation_report(struct kvm *kvm, struct kvm_sev_cmd *argp)
 	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data, sizeof(params)))
 		return -EFAULT;
 
-	data = kzalloc(sizeof(*data), GFP_KERNEL_ACCOUNT);
-	if (!data)
-		return -ENOMEM;
+	memset(&data, 0, sizeof(data));
 
 	/* User wants to query the blob length */
 	if (!params.len)
@@ -1108,23 +1059,20 @@ static int sev_get_attestation_report(struct kvm *kvm, struct kvm_sev_cmd *argp)
 
 	p = (void __user *)(uintptr_t)params.uaddr;
 	if (p) {
-		if (params.len > SEV_FW_BLOB_MAX_SIZE) {
-			ret = -EINVAL;
-			goto e_free;
-		}
+		if (params.len > SEV_FW_BLOB_MAX_SIZE)
+			return -EINVAL;
 
-		ret = -ENOMEM;
-		blob = kmalloc(params.len, GFP_KERNEL);
+		blob = kmalloc(params.len, GFP_KERNEL_ACCOUNT);
 		if (!blob)
-			goto e_free;
+			return -ENOMEM;
 
-		data->address = __psp_pa(blob);
-		data->len = params.len;
-		memcpy(data->mnonce, params.mnonce, sizeof(params.mnonce));
+		data.address = __psp_pa(blob);
+		data.len = params.len;
+		memcpy(data.mnonce, params.mnonce, sizeof(params.mnonce));
 	}
 cmd:
-	data->handle = sev->handle;
-	ret = sev_issue_cmd(kvm, SEV_CMD_ATTESTATION_REPORT, data, &argp->error);
+	data.handle = sev->handle;
+	ret = sev_issue_cmd(kvm, SEV_CMD_ATTESTATION_REPORT, &data, &argp->error);
 	/*
 	 * If we query the session length, FW responded with expected data.
 	 */
@@ -1140,22 +1088,417 @@ cmd:
 	}
 
 done:
-	params.len = data->len;
+	params.len = data.len;
 	if (copy_to_user(report, &params, sizeof(params)))
 		ret = -EFAULT;
 e_free_blob:
 	kfree(blob);
-e_free:
-	kfree(data);
 	return ret;
 }
 
+/* Userspace wants to query session length. */
+static int
+__sev_send_start_query_session_length(struct kvm *kvm, struct kvm_sev_cmd *argp,
+				      struct kvm_sev_send_start *params)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_send_start data;
+	int ret;
+
+	data.handle = sev->handle;
+	ret = sev_issue_cmd(kvm, SEV_CMD_SEND_START, &data, &argp->error);
+	if (ret < 0)
+		return ret;
+
+	params->session_len = data.session_len;
+	if (copy_to_user((void __user *)(uintptr_t)argp->data, params,
+				sizeof(struct kvm_sev_send_start)))
+		ret = -EFAULT;
+
+	return ret;
+}
+
+static int sev_send_start(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_send_start data;
+	struct kvm_sev_send_start params;
+	void *amd_certs, *session_data;
+	void *pdh_cert, *plat_certs;
+	int ret;
+
+	if (!sev_guest(kvm))
+		return -ENOTTY;
+
+	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data,
+				sizeof(struct kvm_sev_send_start)))
+		return -EFAULT;
+
+	/* if session_len is zero, userspace wants to query the session length */
+	if (!params.session_len)
+		return __sev_send_start_query_session_length(kvm, argp,
+				&params);
+
+	/* some sanity checks */
+	if (!params.pdh_cert_uaddr || !params.pdh_cert_len ||
+	    !params.session_uaddr || params.session_len > SEV_FW_BLOB_MAX_SIZE)
+		return -EINVAL;
+
+	/* allocate the memory to hold the session data blob */
+	session_data = kmalloc(params.session_len, GFP_KERNEL_ACCOUNT);
+	if (!session_data)
+		return -ENOMEM;
+
+	/* copy the certificate blobs from userspace */
+	pdh_cert = psp_copy_user_blob(params.pdh_cert_uaddr,
+				params.pdh_cert_len);
+	if (IS_ERR(pdh_cert)) {
+		ret = PTR_ERR(pdh_cert);
+		goto e_free_session;
+	}
+
+	plat_certs = psp_copy_user_blob(params.plat_certs_uaddr,
+				params.plat_certs_len);
+	if (IS_ERR(plat_certs)) {
+		ret = PTR_ERR(plat_certs);
+		goto e_free_pdh;
+	}
+
+	amd_certs = psp_copy_user_blob(params.amd_certs_uaddr,
+				params.amd_certs_len);
+	if (IS_ERR(amd_certs)) {
+		ret = PTR_ERR(amd_certs);
+		goto e_free_plat_cert;
+	}
+
+	/* populate the FW SEND_START field with system physical address */
+	memset(&data, 0, sizeof(data));
+	data.pdh_cert_address = __psp_pa(pdh_cert);
+	data.pdh_cert_len = params.pdh_cert_len;
+	data.plat_certs_address = __psp_pa(plat_certs);
+	data.plat_certs_len = params.plat_certs_len;
+	data.amd_certs_address = __psp_pa(amd_certs);
+	data.amd_certs_len = params.amd_certs_len;
+	data.session_address = __psp_pa(session_data);
+	data.session_len = params.session_len;
+	data.handle = sev->handle;
+
+	ret = sev_issue_cmd(kvm, SEV_CMD_SEND_START, &data, &argp->error);
+
+	if (!ret && copy_to_user((void __user *)(uintptr_t)params.session_uaddr,
+			session_data, params.session_len)) {
+		ret = -EFAULT;
+		goto e_free_amd_cert;
+	}
+
+	params.policy = data.policy;
+	params.session_len = data.session_len;
+	if (copy_to_user((void __user *)(uintptr_t)argp->data, &params,
+				sizeof(struct kvm_sev_send_start)))
+		ret = -EFAULT;
+
+e_free_amd_cert:
+	kfree(amd_certs);
+e_free_plat_cert:
+	kfree(plat_certs);
+e_free_pdh:
+	kfree(pdh_cert);
+e_free_session:
+	kfree(session_data);
+	return ret;
+}
+
+/* Userspace wants to query either header or trans length. */
+static int
+__sev_send_update_data_query_lengths(struct kvm *kvm, struct kvm_sev_cmd *argp,
+				     struct kvm_sev_send_update_data *params)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_send_update_data data;
+	int ret;
+
+	data.handle = sev->handle;
+	ret = sev_issue_cmd(kvm, SEV_CMD_SEND_UPDATE_DATA, &data, &argp->error);
+	if (ret < 0)
+		return ret;
+
+	params->hdr_len = data.hdr_len;
+	params->trans_len = data.trans_len;
+
+	if (copy_to_user((void __user *)(uintptr_t)argp->data, params,
+			 sizeof(struct kvm_sev_send_update_data)))
+		ret = -EFAULT;
+
+	return ret;
+}
+
+static int sev_send_update_data(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_send_update_data data;
+	struct kvm_sev_send_update_data params;
+	void *hdr, *trans_data;
+	struct page **guest_page;
+	unsigned long n;
+	int ret, offset;
+
+	if (!sev_guest(kvm))
+		return -ENOTTY;
+
+	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data,
+			sizeof(struct kvm_sev_send_update_data)))
+		return -EFAULT;
+
+	/* userspace wants to query either header or trans length */
+	if (!params.trans_len || !params.hdr_len)
+		return __sev_send_update_data_query_lengths(kvm, argp, &params);
+
+	if (!params.trans_uaddr || !params.guest_uaddr ||
+	    !params.guest_len || !params.hdr_uaddr)
+		return -EINVAL;
+
+	/* Check if we are crossing the page boundary */
+	offset = params.guest_uaddr & (PAGE_SIZE - 1);
+	if ((params.guest_len + offset > PAGE_SIZE))
+		return -EINVAL;
+
+	/* Pin guest memory */
+	guest_page = sev_pin_memory(kvm, params.guest_uaddr & PAGE_MASK,
+				    PAGE_SIZE, &n, 0);
+	if (!guest_page)
+		return -EFAULT;
+
+	/* allocate memory for header and transport buffer */
+	ret = -ENOMEM;
+	hdr = kmalloc(params.hdr_len, GFP_KERNEL_ACCOUNT);
+	if (!hdr)
+		goto e_unpin;
+
+	trans_data = kmalloc(params.trans_len, GFP_KERNEL_ACCOUNT);
+	if (!trans_data)
+		goto e_free_hdr;
+
+	memset(&data, 0, sizeof(data));
+	data.hdr_address = __psp_pa(hdr);
+	data.hdr_len = params.hdr_len;
+	data.trans_address = __psp_pa(trans_data);
+	data.trans_len = params.trans_len;
+
+	/* The SEND_UPDATE_DATA command requires C-bit to be always set. */
+	data.guest_address = (page_to_pfn(guest_page[0]) << PAGE_SHIFT) + offset;
+	data.guest_address |= sev_me_mask;
+	data.guest_len = params.guest_len;
+	data.handle = sev->handle;
+
+	ret = sev_issue_cmd(kvm, SEV_CMD_SEND_UPDATE_DATA, &data, &argp->error);
+
+	if (ret)
+		goto e_free_trans_data;
+
+	/* copy transport buffer to user space */
+	if (copy_to_user((void __user *)(uintptr_t)params.trans_uaddr,
+			 trans_data, params.trans_len)) {
+		ret = -EFAULT;
+		goto e_free_trans_data;
+	}
+
+	/* Copy packet header to userspace. */
+	ret = copy_to_user((void __user *)(uintptr_t)params.hdr_uaddr, hdr,
+				params.hdr_len);
+
+e_free_trans_data:
+	kfree(trans_data);
+e_free_hdr:
+	kfree(hdr);
+e_unpin:
+	sev_unpin_memory(kvm, guest_page, n);
+
+	return ret;
+}
+
+static int sev_send_finish(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_send_finish data;
+
+	if (!sev_guest(kvm))
+		return -ENOTTY;
+
+	data.handle = sev->handle;
+	return sev_issue_cmd(kvm, SEV_CMD_SEND_FINISH, &data, &argp->error);
+}
+
+static int sev_send_cancel(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_send_cancel data;
+
+	if (!sev_guest(kvm))
+		return -ENOTTY;
+
+	data.handle = sev->handle;
+	return sev_issue_cmd(kvm, SEV_CMD_SEND_CANCEL, &data, &argp->error);
+}
+
+static int sev_receive_start(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_receive_start start;
+	struct kvm_sev_receive_start params;
+	int *error = &argp->error;
+	void *session_data;
+	void *pdh_data;
+	int ret;
+
+	if (!sev_guest(kvm))
+		return -ENOTTY;
+
+	/* Get parameter from the userspace */
+	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data,
+			sizeof(struct kvm_sev_receive_start)))
+		return -EFAULT;
+
+	/* some sanity checks */
+	if (!params.pdh_uaddr || !params.pdh_len ||
+	    !params.session_uaddr || !params.session_len)
+		return -EINVAL;
+
+	pdh_data = psp_copy_user_blob(params.pdh_uaddr, params.pdh_len);
+	if (IS_ERR(pdh_data))
+		return PTR_ERR(pdh_data);
+
+	session_data = psp_copy_user_blob(params.session_uaddr,
+			params.session_len);
+	if (IS_ERR(session_data)) {
+		ret = PTR_ERR(session_data);
+		goto e_free_pdh;
+	}
+
+	memset(&start, 0, sizeof(start));
+	start.handle = params.handle;
+	start.policy = params.policy;
+	start.pdh_cert_address = __psp_pa(pdh_data);
+	start.pdh_cert_len = params.pdh_len;
+	start.session_address = __psp_pa(session_data);
+	start.session_len = params.session_len;
+
+	/* create memory encryption context */
+	ret = __sev_issue_cmd(argp->sev_fd, SEV_CMD_RECEIVE_START, &start,
+				error);
+	if (ret)
+		goto e_free_session;
+
+	/* Bind ASID to this guest */
+	ret = sev_bind_asid(kvm, start.handle, error);
+	if (ret)
+		goto e_free_session;
+
+	params.handle = start.handle;
+	if (copy_to_user((void __user *)(uintptr_t)argp->data,
+			 &params, sizeof(struct kvm_sev_receive_start))) {
+		ret = -EFAULT;
+		sev_unbind_asid(kvm, start.handle);
+		goto e_free_session;
+	}
+
+	sev->handle = start.handle;
+	sev->fd = argp->sev_fd;
+
+e_free_session:
+	kfree(session_data);
+e_free_pdh:
+	kfree(pdh_data);
+
+	return ret;
+}
+
+static int sev_receive_update_data(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct kvm_sev_receive_update_data params;
+	struct sev_data_receive_update_data data;
+	void *hdr = NULL, *trans = NULL;
+	struct page **guest_page;
+	unsigned long n;
+	int ret, offset;
+
+	if (!sev_guest(kvm))
+		return -EINVAL;
+
+	if (copy_from_user(&params, (void __user *)(uintptr_t)argp->data,
+			sizeof(struct kvm_sev_receive_update_data)))
+		return -EFAULT;
+
+	if (!params.hdr_uaddr || !params.hdr_len ||
+	    !params.guest_uaddr || !params.guest_len ||
+	    !params.trans_uaddr || !params.trans_len)
+		return -EINVAL;
+
+	/* Check if we are crossing the page boundary */
+	offset = params.guest_uaddr & (PAGE_SIZE - 1);
+	if ((params.guest_len + offset > PAGE_SIZE))
+		return -EINVAL;
+
+	hdr = psp_copy_user_blob(params.hdr_uaddr, params.hdr_len);
+	if (IS_ERR(hdr))
+		return PTR_ERR(hdr);
+
+	trans = psp_copy_user_blob(params.trans_uaddr, params.trans_len);
+	if (IS_ERR(trans)) {
+		ret = PTR_ERR(trans);
+		goto e_free_hdr;
+	}
+
+	memset(&data, 0, sizeof(data));
+	data.hdr_address = __psp_pa(hdr);
+	data.hdr_len = params.hdr_len;
+	data.trans_address = __psp_pa(trans);
+	data.trans_len = params.trans_len;
+
+	/* Pin guest memory */
+	ret = -EFAULT;
+	guest_page = sev_pin_memory(kvm, params.guest_uaddr & PAGE_MASK,
+				    PAGE_SIZE, &n, 0);
+	if (!guest_page)
+		goto e_free_trans;
+
+	/* The RECEIVE_UPDATE_DATA command requires C-bit to be always set. */
+	data.guest_address = (page_to_pfn(guest_page[0]) << PAGE_SHIFT) + offset;
+	data.guest_address |= sev_me_mask;
+	data.guest_len = params.guest_len;
+	data.handle = sev->handle;
+
+	ret = sev_issue_cmd(kvm, SEV_CMD_RECEIVE_UPDATE_DATA, &data,
+				&argp->error);
+
+	sev_unpin_memory(kvm, guest_page, n);
+
+e_free_trans:
+	kfree(trans);
+e_free_hdr:
+	kfree(hdr);
+
+	return ret;
+}
+
+static int sev_receive_finish(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
+	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	struct sev_data_receive_finish data;
+
+	if (!sev_guest(kvm))
+		return -ENOTTY;
+
+	data.handle = sev->handle;
+	return sev_issue_cmd(kvm, SEV_CMD_RECEIVE_FINISH, &data, &argp->error);
+}
+
 int svm_mem_enc_op(struct kvm *kvm, void __user *argp)
 {
 	struct kvm_sev_cmd sev_cmd;
 	int r;
 
-	if (!svm_sev_enabled() || !sev)
+	if (!sev_enabled)
 		return -ENOTTY;
 
 	if (!argp)
@@ -1166,13 +1509,22 @@ int svm_mem_enc_op(struct kvm *kvm, void __user *argp)
 
 	mutex_lock(&kvm->lock);
 
+	/* enc_context_owner handles all memory enc operations */
+	if (is_mirroring_enc_context(kvm)) {
+		r = -EINVAL;
+		goto out;
+	}
+
 	switch (sev_cmd.id) {
+	case KVM_SEV_ES_INIT:
+		if (!sev_es_enabled) {
+			r = -ENOTTY;
+			goto out;
+		}
+		fallthrough;
 	case KVM_SEV_INIT:
 		r = sev_guest_init(kvm, &sev_cmd);
 		break;
-	case KVM_SEV_ES_INIT:
-		r = sev_es_guest_init(kvm, &sev_cmd);
-		break;
 	case KVM_SEV_LAUNCH_START:
 		r = sev_launch_start(kvm, &sev_cmd);
 		break;
@@ -1203,6 +1555,27 @@ int svm_mem_enc_op(struct kvm *kvm, void __user *argp)
 	case KVM_SEV_GET_ATTESTATION_REPORT:
 		r = sev_get_attestation_report(kvm, &sev_cmd);
 		break;
+	case KVM_SEV_SEND_START:
+		r = sev_send_start(kvm, &sev_cmd);
+		break;
+	case KVM_SEV_SEND_UPDATE_DATA:
+		r = sev_send_update_data(kvm, &sev_cmd);
+		break;
+	case KVM_SEV_SEND_FINISH:
+		r = sev_send_finish(kvm, &sev_cmd);
+		break;
+	case KVM_SEV_SEND_CANCEL:
+		r = sev_send_cancel(kvm, &sev_cmd);
+		break;
+	case KVM_SEV_RECEIVE_START:
+		r = sev_receive_start(kvm, &sev_cmd);
+		break;
+	case KVM_SEV_RECEIVE_UPDATE_DATA:
+		r = sev_receive_update_data(kvm, &sev_cmd);
+		break;
+	case KVM_SEV_RECEIVE_FINISH:
+		r = sev_receive_finish(kvm, &sev_cmd);
+		break;
 	default:
 		r = -EINVAL;
 		goto out;
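
The new SEND_* and RECEIVE_* commands above reach KVM through the existing KVM_MEMORY_ENCRYPT_OP ioctl, and several of them use a query-then-fill convention: pass a zero length first, let the kernel write back the required size, then call again with a buffer. A hedged userspace sketch of the query step for KVM_SEV_SEND_START, assuming a UAPI header that already carries these definitions; the fd names are illustrative and firmware error handling is elided:

#include <stdint.h>
#include <sys/ioctl.h>
#include <linux/kvm.h>

/* vm_fd: the KVM VM fd; sev_fd: an open /dev/sev fd (both illustrative). */
static int sev_send_start_query_len(int vm_fd, int sev_fd, uint32_t *session_len)
{
	struct kvm_sev_send_start start = {};	/* session_len == 0 -> query */
	struct kvm_sev_cmd cmd = {
		.id     = KVM_SEV_SEND_START,
		.data   = (uint64_t)(uintptr_t)&start,
		.sev_fd = sev_fd,
	};
	int ret = ioctl(vm_fd, KVM_MEMORY_ENCRYPT_OP, &cmd);

	/*
	 * On the query path sev_send_start() copies the firmware-provided
	 * length back into the struct; real code should also inspect
	 * cmd.error, since the firmware may flag the probe as an error.
	 */
	*session_len = start.session_len;
	return ret;
}
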
@@ -1226,6 +1599,10 @@ int svm_register_enc_region(struct kvm *kvm,
 	if (!sev_guest(kvm))
 		return -ENOTTY;
 
+	/* If kvm is mirroring the encryption context, it isn't responsible for it */
+	if (is_mirroring_enc_context(kvm))
+		return -EINVAL;
+
 	if (range->addr > ULONG_MAX || range->size > ULONG_MAX)
 		return -EINVAL;
 
@@ -1292,6 +1669,10 @@ int svm_unregister_enc_region(struct kvm *kvm,
 	struct enc_region *region;
 	int ret;
 
+	/* If kvm is mirroring the encryption context, it isn't responsible for it */
+	if (is_mirroring_enc_context(kvm))
+		return -EINVAL;
+
 	mutex_lock(&kvm->lock);
 
 	if (!sev_guest(kvm)) {
@@ -1322,6 +1703,71 @@ failed:
 	return ret;
 }
 
+int svm_vm_copy_asid_from(struct kvm *kvm, unsigned int source_fd)
+{
+	struct file *source_kvm_file;
+	struct kvm *source_kvm;
+	struct kvm_sev_info *mirror_sev;
+	unsigned int asid;
+	int ret;
+
+	source_kvm_file = fget(source_fd);
+	if (!file_is_kvm(source_kvm_file)) {
+		ret = -EBADF;
+		goto e_source_put;
+	}
+
+	source_kvm = source_kvm_file->private_data;
+	mutex_lock(&source_kvm->lock);
+
+	if (!sev_guest(source_kvm)) {
+		ret = -EINVAL;
+		goto e_source_unlock;
+	}
+
+	/* Mirrors of mirrors should work, but let's not get silly */
+	if (is_mirroring_enc_context(source_kvm) || source_kvm == kvm) {
+		ret = -EINVAL;
+		goto e_source_unlock;
+	}
+
+	asid = to_kvm_svm(source_kvm)->sev_info.asid;
+
+	/*
+	 * The mirror kvm holds an enc_context_owner ref so its asid can't
+	 * disappear until we're done with it
+	 */
+	kvm_get_kvm(source_kvm);
+
+	fput(source_kvm_file);
+	mutex_unlock(&source_kvm->lock);
+	mutex_lock(&kvm->lock);
+
+	if (sev_guest(kvm)) {
+		ret = -EINVAL;
+		goto e_mirror_unlock;
+	}
+
+	/* Set enc_context_owner and copy its encryption context over */
+	mirror_sev = &to_kvm_svm(kvm)->sev_info;
+	mirror_sev->enc_context_owner = source_kvm;
+	mirror_sev->asid = asid;
+	mirror_sev->active = true;
+
+	mutex_unlock(&kvm->lock);
+	return 0;
+
+e_mirror_unlock:
+	mutex_unlock(&kvm->lock);
+	kvm_put_kvm(source_kvm);
+	return ret;
+e_source_unlock:
+	mutex_unlock(&source_kvm->lock);
+e_source_put:
+	fput(source_kvm_file);
+	return ret;
+}
+
 void sev_vm_destroy(struct kvm *kvm)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
@@ -1331,6 +1777,12 @@ void sev_vm_destroy(struct kvm *kvm)
 	if (!sev_guest(kvm))
 		return;
 
+	/* If this is a mirror_kvm, release the enc_context_owner and skip sev cleanup */
+	if (is_mirroring_enc_context(kvm)) {
+		kvm_put_kvm(sev->enc_context_owner);
+		return;
+	}
+
 	mutex_lock(&kvm->lock);
 
 	/*
@@ -1358,12 +1810,24 @@ void sev_vm_destroy(struct kvm *kvm)
 	sev_asid_free(sev);
 }
 
+void __init sev_set_cpu_caps(void)
+{
+	if (!sev_enabled)
+		kvm_cpu_cap_clear(X86_FEATURE_SEV);
+	if (!sev_es_enabled)
+		kvm_cpu_cap_clear(X86_FEATURE_SEV_ES);
+}
+
 void __init sev_hardware_setup(void)
 {
+#ifdef CONFIG_KVM_AMD_SEV
 	unsigned int eax, ebx, ecx, edx, sev_asid_count, sev_es_asid_count;
 	bool sev_es_supported = false;
 	bool sev_supported = false;
 
+	if (!sev_enabled || !npt_enabled)
+		goto out;
+
 	/* Does the CPU support SEV? */
 	if (!boot_cpu_has(X86_FEATURE_SEV))
 		goto out;
@@ -1376,12 +1840,12 @@ void __init sev_hardware_setup(void)
 
 	/* Maximum number of encrypted guests supported simultaneously */
 	max_sev_asid = ecx;
-
-	if (!svm_sev_enabled())
+	if (!max_sev_asid)
 		goto out;
 
 	/* Minimum ASID value that should be used for SEV guest */
 	min_sev_asid = edx;
+	sev_me_mask = 1UL << (ebx & 0x3f);
 
 	/* Initialize SEV ASID bitmaps */
 	sev_asid_bitmap = bitmap_zalloc(max_sev_asid, GFP_KERNEL);
@@ -1389,8 +1853,11 @@ void __init sev_hardware_setup(void)
 		goto out;
 
 	sev_reclaim_asid_bitmap = bitmap_zalloc(max_sev_asid, GFP_KERNEL);
-	if (!sev_reclaim_asid_bitmap)
+	if (!sev_reclaim_asid_bitmap) {
+		bitmap_free(sev_asid_bitmap);
+		sev_asid_bitmap = NULL;
 		goto out;
+	}
 
 	sev_asid_count = max_sev_asid - min_sev_asid + 1;
 	if (misc_cg_set_capacity(MISC_CG_RES_SEV, sev_asid_count))
@@ -1400,7 +1867,7 @@ void __init sev_hardware_setup(void)
 	sev_supported = true;
 
 	/* SEV-ES support requested? */
-	if (!sev_es)
+	if (!sev_es_enabled)
 		goto out;
 
 	/* Does the CPU support SEV-ES? */
@@ -1419,21 +1886,36 @@ void __init sev_hardware_setup(void)
 	sev_es_supported = true;
 
 out:
-	sev = sev_supported;
-	sev_es = sev_es_supported;
+	sev_enabled = sev_supported;
+	sev_es_enabled = sev_es_supported;
+#endif
 }
 
 void sev_hardware_teardown(void)
 {
-	if (!svm_sev_enabled())
+	if (!sev_enabled)
 		return;
 
+	/* No need to take sev_bitmap_lock, all VMs have been destroyed. */
+	sev_flush_asids(0, max_sev_asid);
+
 	bitmap_free(sev_asid_bitmap);
 	bitmap_free(sev_reclaim_asid_bitmap);
+
 	misc_cg_set_capacity(MISC_CG_RES_SEV, 0);
 	misc_cg_set_capacity(MISC_CG_RES_SEV_ES, 0);
+}
 
-	sev_flush_asids();
+int sev_cpu_init(struct svm_cpu_data *sd)
+{
+	if (!sev_enabled)
+		return 0;
+
+	sd->sev_vmcbs = kcalloc(max_sev_asid + 1, sizeof(void *), GFP_KERNEL);
+	if (!sd->sev_vmcbs)
+		return -ENOMEM;
+
+	return 0;
 }
 
 /*
@@ -1825,7 +2307,7 @@ static bool setup_vmgexit_scratch(struct vcpu_svm *svm, bool sync, u64 len)
 			       len, GHCB_SCRATCH_AREA_LIMIT);
 			return false;
 		}
-		scratch_va = kzalloc(len, GFP_KERNEL);
+		scratch_va = kzalloc(len, GFP_KERNEL_ACCOUNT);
 		if (!scratch_va)
 			return false;
 
@@ -1899,7 +2381,7 @@ static int sev_handle_vmgexit_msr_protocol(struct vcpu_svm *svm)
 		vcpu->arch.regs[VCPU_REGS_RAX] = cpuid_fn;
 		vcpu->arch.regs[VCPU_REGS_RCX] = 0;
 
-		ret = svm_invoke_exit_handler(svm, SVM_EXIT_CPUID);
+		ret = svm_invoke_exit_handler(vcpu, SVM_EXIT_CPUID);
 		if (!ret) {
 			ret = -EINVAL;
 			break;
@@ -1949,8 +2431,9 @@ static int sev_handle_vmgexit_msr_protocol(struct vcpu_svm *svm)
 	return ret;
 }
 
-int sev_handle_vmgexit(struct vcpu_svm *svm)
+int sev_handle_vmgexit(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	struct vmcb_control_area *control = &svm->vmcb->control;
 	u64 ghcb_gpa, exit_code;
 	struct ghcb *ghcb;
@@ -1962,13 +2445,13 @@ int sev_handle_vmgexit(struct vcpu_svm *svm)
 		return sev_handle_vmgexit_msr_protocol(svm);
 
 	if (!ghcb_gpa) {
-		vcpu_unimpl(&svm->vcpu, "vmgexit: GHCB gpa is not set\n");
+		vcpu_unimpl(vcpu, "vmgexit: GHCB gpa is not set\n");
 		return -EINVAL;
 	}
 
-	if (kvm_vcpu_map(&svm->vcpu, ghcb_gpa >> PAGE_SHIFT, &svm->ghcb_map)) {
+	if (kvm_vcpu_map(vcpu, ghcb_gpa >> PAGE_SHIFT, &svm->ghcb_map)) {
 		/* Unable to map GHCB from guest */
-		vcpu_unimpl(&svm->vcpu, "vmgexit: error mapping GHCB [%#llx] from guest\n",
+		vcpu_unimpl(vcpu, "vmgexit: error mapping GHCB [%#llx] from guest\n",
 			    ghcb_gpa);
 		return -EINVAL;
 	}
@@ -1976,7 +2459,7 @@ int sev_handle_vmgexit(struct vcpu_svm *svm)
 	svm->ghcb = svm->ghcb_map.hva;
 	ghcb = svm->ghcb_map.hva;
 
-	trace_kvm_vmgexit_enter(svm->vcpu.vcpu_id, ghcb);
+	trace_kvm_vmgexit_enter(vcpu->vcpu_id, ghcb);
 
 	exit_code = ghcb_get_sw_exit_code(ghcb);
 
@@ -1994,7 +2477,7 @@ int sev_handle_vmgexit(struct vcpu_svm *svm)
 		if (!setup_vmgexit_scratch(svm, true, control->exit_info_2))
 			break;
 
-		ret = kvm_sev_es_mmio_read(&svm->vcpu,
+		ret = kvm_sev_es_mmio_read(vcpu,
 					   control->exit_info_1,
 					   control->exit_info_2,
 					   svm->ghcb_sa);
@@ -2003,19 +2486,19 @@ int sev_handle_vmgexit(struct vcpu_svm *svm)
 		if (!setup_vmgexit_scratch(svm, false, control->exit_info_2))
 			break;
 
-		ret = kvm_sev_es_mmio_write(&svm->vcpu,
+		ret = kvm_sev_es_mmio_write(vcpu,
 					    control->exit_info_1,
 					    control->exit_info_2,
 					    svm->ghcb_sa);
 		break;
 	case SVM_VMGEXIT_NMI_COMPLETE:
-		ret = svm_invoke_exit_handler(svm, SVM_EXIT_IRET);
+		ret = svm_invoke_exit_handler(vcpu, SVM_EXIT_IRET);
 		break;
 	case SVM_VMGEXIT_AP_HLT_LOOP:
-		ret = kvm_emulate_ap_reset_hold(&svm->vcpu);
+		ret = kvm_emulate_ap_reset_hold(vcpu);
 		break;
 	case SVM_VMGEXIT_AP_JUMP_TABLE: {
-		struct kvm_sev_info *sev = &to_kvm_svm(svm->vcpu.kvm)->sev_info;
+		struct kvm_sev_info *sev = &to_kvm_svm(vcpu->kvm)->sev_info;
 
 		switch (control->exit_info_1) {
 		case 0:
@@ -2040,12 +2523,12 @@ int sev_handle_vmgexit(struct vcpu_svm *svm)
 		break;
 	}
 	case SVM_VMGEXIT_UNSUPPORTED_EVENT:
-		vcpu_unimpl(&svm->vcpu,
+		vcpu_unimpl(vcpu,
 			    "vmgexit: unsupported event - exit_info_1=%#llx, exit_info_2=%#llx\n",
 			    control->exit_info_1, control->exit_info_2);
 		break;
 	default:
-		ret = svm_invoke_exit_handler(svm, exit_code);
+		ret = svm_invoke_exit_handler(vcpu, exit_code);
 	}
 
 	return ret;
@@ -2154,5 +2637,8 @@ void sev_vcpu_deliver_sipi_vector(struct kvm_vcpu *vcpu, u8 vector)
 	 * the guest will set the CS and RIP. Set SW_EXIT_INFO_2 to a
 	 * non-zero value.
 	 */
+	if (!svm->ghcb)
+		return;
+
 	ghcb_set_sw_exit_info_2(svm->ghcb, 1);
 }
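
svm_vm_copy_asid_from() above is the backend that lets a second VM mirror an existing SEV encryption context; in the corresponding upstream series it is reached through the KVM_CAP_VM_COPY_ENC_CONTEXT_FROM capability, whose wiring lives outside this file. A rough userspace sketch under that assumption, with the fds purely illustrative:

#include <sys/ioctl.h>
#include <linux/kvm.h>

/* mirror_vm_fd / source_vm_fd: previously created VM fds (illustrative). */
static int mirror_sev_context(int mirror_vm_fd, int source_vm_fd)
{
	struct kvm_enable_cap cap = {
		.cap = KVM_CAP_VM_COPY_ENC_CONTEXT_FROM,
		.args[0] = source_vm_fd,
	};

	/*
	 * Ends up in svm_vm_copy_asid_from(): the mirror VM shares the
	 * source's ASID and, per is_mirroring_enc_context(), must not issue
	 * its own memory-encryption commands.
	 */
	return ioctl(mirror_vm_fd, KVM_ENABLE_CAP, &cap);
}
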
diff --git a/arch/x86/kvm/svm/svm.c b/arch/x86/kvm/svm/svm.c
index 6dad89248312..9790c73f2a32 100644
--- a/arch/x86/kvm/svm/svm.c
+++ b/arch/x86/kvm/svm/svm.c
@@ -56,9 +56,6 @@ static const struct x86_cpu_id svm_cpu_id[] = {
 MODULE_DEVICE_TABLE(x86cpu, svm_cpu_id);
 #endif
 
-#define IOPM_ALLOC_ORDER 2
-#define MSRPM_ALLOC_ORDER 1
-
 #define SEG_TYPE_LDT 2
 #define SEG_TYPE_BUSY_TSS16 3
 
@@ -95,6 +92,8 @@ static const struct svm_direct_access_msrs {
 } direct_access_msrs[MAX_DIRECT_ACCESS_MSRS] = {
 	{ .index = MSR_STAR,				.always = true  },
 	{ .index = MSR_IA32_SYSENTER_CS,		.always = true  },
+	{ .index = MSR_IA32_SYSENTER_EIP,		.always = false },
+	{ .index = MSR_IA32_SYSENTER_ESP,		.always = false },
 #ifdef CONFIG_X86_64
 	{ .index = MSR_GS_BASE,				.always = true  },
 	{ .index = MSR_FS_BASE,				.always = true  },
@@ -186,14 +185,6 @@ module_param(vls, int, 0444);
 static int vgif = true;
 module_param(vgif, int, 0444);
 
-/* enable/disable SEV support */
-int sev = IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT_ACTIVE_BY_DEFAULT);
-module_param(sev, int, 0444);
-
-/* enable/disable SEV-ES support */
-int sev_es = IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT_ACTIVE_BY_DEFAULT);
-module_param(sev_es, int, 0444);
-
 bool __read_mostly dump_invalid_vmcb;
 module_param(dump_invalid_vmcb, bool, 0644);
 
@@ -214,6 +205,15 @@ struct kvm_ldttss_desc {
 
 DEFINE_PER_CPU(struct svm_cpu_data *, svm_data);
 
+/*
+ * Only MSR_TSC_AUX is switched via the user return hook.  EFER is switched via
+ * the VMCB, and the SYSCALL/SYSENTER MSRs are handled by VMLOAD/VMSAVE.
+ *
+ * RDTSCP and RDPID are not used in the kernel, specifically to allow KVM to
+ * defer the restoration of TSC_AUX until the CPU returns to userspace.
+ */
+#define TSC_AUX_URET_SLOT	0
+
 static const u32 msrpm_ranges[] = {0, 0xc0000000, 0xc0010000};
 
 #define NUM_MSR_MAPS ARRAY_SIZE(msrpm_ranges)
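
The comment above explains why TSC_AUX alone goes through KVM's user-return MSR machinery. As a minimal sketch of how the guest's WRMSR would use the slot (the corresponding svm_set_msr() change is not part of this hunk, so the exact call site and return-value handling are assumptions; the usual KVM/SVM headers are implied):

static int svm_write_guest_tsc_aux_sketch(struct kvm_vcpu *vcpu, u64 data)
{
	int ret;

	/*
	 * Load the guest value now; the host value is restored lazily when
	 * this CPU next returns to userspace, via the slot registered with
	 * kvm_define_user_return_msr(TSC_AUX_URET_SLOT, MSR_TSC_AUX) in
	 * svm_hardware_setup() below.
	 */
	preempt_disable();
	ret = kvm_set_user_return_msr(TSC_AUX_URET_SLOT, data, -1ull);
	preempt_enable();

	return ret ? 1 : 0;
}
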
@@ -279,7 +279,7 @@ int svm_set_efer(struct kvm_vcpu *vcpu, u64 efer)
 			 * In this case we will return to the nested guest
 			 * as soon as we leave SMM.
 			 */
-			if (!is_smm(&svm->vcpu))
+			if (!is_smm(vcpu))
 				svm_free_nested(svm);
 
 		} else {
@@ -363,10 +363,10 @@ static void svm_queue_exception(struct kvm_vcpu *vcpu)
 	bool has_error_code = vcpu->arch.exception.has_error_code;
 	u32 error_code = vcpu->arch.exception.error_code;
 
-	kvm_deliver_exception_payload(&svm->vcpu);
+	kvm_deliver_exception_payload(vcpu);
 
 	if (nr == BP_VECTOR && !nrips) {
-		unsigned long rip, old_rip = kvm_rip_read(&svm->vcpu);
+		unsigned long rip, old_rip = kvm_rip_read(vcpu);
 
 		/*
 		 * For guest debugging where we have to reinject #BP if some
@@ -375,8 +375,8 @@ static void svm_queue_exception(struct kvm_vcpu *vcpu)
 		 * raises a fault that is not intercepted. Still better than
 		 * failing in all cases.
 		 */
-		(void)skip_emulated_instruction(&svm->vcpu);
-		rip = kvm_rip_read(&svm->vcpu);
+		(void)skip_emulated_instruction(vcpu);
+		rip = kvm_rip_read(vcpu);
 		svm->int3_rip = rip + svm->vmcb->save.cs.base;
 		svm->int3_injected = rip - old_rip;
 	}
@@ -553,23 +553,21 @@ static void svm_cpu_uninit(int cpu)
 static int svm_cpu_init(int cpu)
 {
 	struct svm_cpu_data *sd;
+	int ret = -ENOMEM;
 
 	sd = kzalloc(sizeof(struct svm_cpu_data), GFP_KERNEL);
 	if (!sd)
-		return -ENOMEM;
+		return ret;
 	sd->cpu = cpu;
 	sd->save_area = alloc_page(GFP_KERNEL);
 	if (!sd->save_area)
 		goto free_cpu_data;
+
 	clear_page(page_address(sd->save_area));
 
-	if (svm_sev_enabled()) {
-		sd->sev_vmcbs = kmalloc_array(max_sev_asid + 1,
-					      sizeof(void *),
-					      GFP_KERNEL);
-		if (!sd->sev_vmcbs)
-			goto free_save_area;
-	}
+	ret = sev_cpu_init(sd);
+	if (ret)
+		goto free_save_area;
 
 	per_cpu(svm_data, cpu) = sd;
 
@@ -579,7 +577,7 @@ free_save_area:
 	__free_page(sd->save_area);
 free_cpu_data:
 	kfree(sd);
-	return -ENOMEM;
+	return ret;
 
 }
 
@@ -681,14 +679,15 @@ void set_msr_interception(struct kvm_vcpu *vcpu, u32 *msrpm, u32 msr,
 
 u32 *svm_vcpu_alloc_msrpm(void)
 {
-	struct page *pages = alloc_pages(GFP_KERNEL_ACCOUNT, MSRPM_ALLOC_ORDER);
+	unsigned int order = get_order(MSRPM_SIZE);
+	struct page *pages = alloc_pages(GFP_KERNEL_ACCOUNT, order);
 	u32 *msrpm;
 
 	if (!pages)
 		return NULL;
 
 	msrpm = page_address(pages);
-	memset(msrpm, 0xff, PAGE_SIZE * (1 << MSRPM_ALLOC_ORDER));
+	memset(msrpm, 0xff, PAGE_SIZE * (1 << order));
 
 	return msrpm;
 }
@@ -707,7 +706,7 @@ void svm_vcpu_init_msrpm(struct kvm_vcpu *vcpu, u32 *msrpm)
 
 void svm_vcpu_free_msrpm(u32 *msrpm)
 {
-	__free_pages(virt_to_page(msrpm), MSRPM_ALLOC_ORDER);
+	__free_pages(virt_to_page(msrpm), get_order(MSRPM_SIZE));
 }
 
 static void svm_msr_filter_changed(struct kvm_vcpu *vcpu)
@@ -881,20 +880,20 @@ static __init void svm_adjust_mmio_mask(void)
 	 */
 	mask = (mask_bit < 52) ? rsvd_bits(mask_bit, 51) | PT_PRESENT_MASK : 0;
 
-	kvm_mmu_set_mmio_spte_mask(mask, PT_WRITABLE_MASK | PT_USER_MASK);
+	kvm_mmu_set_mmio_spte_mask(mask, mask, PT_WRITABLE_MASK | PT_USER_MASK);
 }
 
 static void svm_hardware_teardown(void)
 {
 	int cpu;
 
-	if (svm_sev_enabled())
-		sev_hardware_teardown();
+	sev_hardware_teardown();
 
 	for_each_possible_cpu(cpu)
 		svm_cpu_uninit(cpu);
 
-	__free_pages(pfn_to_page(iopm_base >> PAGE_SHIFT), IOPM_ALLOC_ORDER);
+	__free_pages(pfn_to_page(iopm_base >> PAGE_SHIFT),
+		     get_order(IOPM_SIZE));
 	iopm_base = 0;
 }
 
@@ -922,6 +921,9 @@ static __init void svm_set_cpu_caps(void)
 	if (boot_cpu_has(X86_FEATURE_LS_CFG_SSBD) ||
 	    boot_cpu_has(X86_FEATURE_AMD_SSBD))
 		kvm_cpu_cap_set(X86_FEATURE_VIRT_SSBD);
+
+	/* CPUID 0x8000001F (SME/SEV features) */
+	sev_set_cpu_caps();
 }
 
 static __init int svm_hardware_setup(void)
@@ -930,14 +932,15 @@ static __init int svm_hardware_setup(void)
 	struct page *iopm_pages;
 	void *iopm_va;
 	int r;
+	unsigned int order = get_order(IOPM_SIZE);
 
-	iopm_pages = alloc_pages(GFP_KERNEL, IOPM_ALLOC_ORDER);
+	iopm_pages = alloc_pages(GFP_KERNEL, order);
 
 	if (!iopm_pages)
 		return -ENOMEM;
 
 	iopm_va = page_address(iopm_pages);
-	memset(iopm_va, 0xff, PAGE_SIZE * (1 << IOPM_ALLOC_ORDER));
+	memset(iopm_va, 0xff, PAGE_SIZE * (1 << order));
 	iopm_base = page_to_pfn(iopm_pages) << PAGE_SHIFT;
 
 	init_msrpm_offsets();
@@ -956,6 +959,9 @@ static __init int svm_hardware_setup(void)
 		kvm_tsc_scaling_ratio_frac_bits = 32;
 	}
 
+	if (boot_cpu_has(X86_FEATURE_RDTSCP))
+		kvm_define_user_return_msr(TSC_AUX_URET_SLOT, MSR_TSC_AUX);
+
 	/* Check for pause filtering support */
 	if (!boot_cpu_has(X86_FEATURE_PAUSEFILTER)) {
 		pause_filter_count = 0;
@@ -969,21 +975,6 @@ static __init int svm_hardware_setup(void)
 		kvm_enable_efer_bits(EFER_SVME | EFER_LMSLE);
 	}
 
-	if (IS_ENABLED(CONFIG_KVM_AMD_SEV) && sev) {
-		sev_hardware_setup();
-	} else {
-		sev = false;
-		sev_es = false;
-	}
-
-	svm_adjust_mmio_mask();
-
-	for_each_possible_cpu(cpu) {
-		r = svm_cpu_init(cpu);
-		if (r)
-			goto err;
-	}
-
 	/*
 	 * KVM's MMU doesn't support using 2-level paging for itself, and thus
 	 * NPT isn't supported if the host is using 2-level paging since host
@@ -998,6 +989,17 @@ static __init int svm_hardware_setup(void)
 	kvm_configure_mmu(npt_enabled, get_max_npt_level(), PG_LEVEL_1G);
 	pr_info("kvm: Nested Paging %sabled\n", npt_enabled ? "en" : "dis");
 
+	/* Note, SEV setup consumes npt_enabled. */
+	sev_hardware_setup();
+
+	svm_adjust_mmio_mask();
+
+	for_each_possible_cpu(cpu) {
+		r = svm_cpu_init(cpu);
+		if (r)
+			goto err;
+	}
+
 	if (nrips) {
 		if (!boot_cpu_has(X86_FEATURE_NRIPS))
 			nrips = false;
@@ -1084,8 +1086,8 @@ static u64 svm_write_l1_tsc_offset(struct kvm_vcpu *vcpu, u64 offset)
 	if (is_guest_mode(vcpu)) {
 		/* Write L1's TSC offset.  */
 		g_tsc_offset = svm->vmcb->control.tsc_offset -
-			       svm->nested.hsave->control.tsc_offset;
-		svm->nested.hsave->control.tsc_offset = offset;
+			       svm->vmcb01.ptr->control.tsc_offset;
+		svm->vmcb01.ptr->control.tsc_offset = offset;
 	}
 
 	trace_kvm_write_tsc_offset(vcpu->vcpu_id,
@@ -1113,12 +1115,13 @@ static void svm_check_invpcid(struct vcpu_svm *svm)
 	}
 }
 
-static void init_vmcb(struct vcpu_svm *svm)
+static void init_vmcb(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	struct vmcb_control_area *control = &svm->vmcb->control;
 	struct vmcb_save_area *save = &svm->vmcb->save;
 
-	svm->vcpu.arch.hflags = 0;
+	vcpu->arch.hflags = 0;
 
 	svm_set_intercept(svm, INTERCEPT_CR0_READ);
 	svm_set_intercept(svm, INTERCEPT_CR3_READ);
@@ -1126,7 +1129,7 @@ static void init_vmcb(struct vcpu_svm *svm)
 	svm_set_intercept(svm, INTERCEPT_CR0_WRITE);
 	svm_set_intercept(svm, INTERCEPT_CR3_WRITE);
 	svm_set_intercept(svm, INTERCEPT_CR4_WRITE);
-	if (!kvm_vcpu_apicv_active(&svm->vcpu))
+	if (!kvm_vcpu_apicv_active(vcpu))
 		svm_set_intercept(svm, INTERCEPT_CR8_WRITE);
 
 	set_dr_intercepts(svm);
@@ -1170,12 +1173,12 @@ static void init_vmcb(struct vcpu_svm *svm)
 	svm_set_intercept(svm, INTERCEPT_RDPRU);
 	svm_set_intercept(svm, INTERCEPT_RSM);
 
-	if (!kvm_mwait_in_guest(svm->vcpu.kvm)) {
+	if (!kvm_mwait_in_guest(vcpu->kvm)) {
 		svm_set_intercept(svm, INTERCEPT_MONITOR);
 		svm_set_intercept(svm, INTERCEPT_MWAIT);
 	}
 
-	if (!kvm_hlt_in_guest(svm->vcpu.kvm))
+	if (!kvm_hlt_in_guest(vcpu->kvm))
 		svm_set_intercept(svm, INTERCEPT_HLT);
 
 	control->iopm_base_pa = __sme_set(iopm_base);
@@ -1201,19 +1204,19 @@ static void init_vmcb(struct vcpu_svm *svm)
 	init_sys_seg(&save->ldtr, SEG_TYPE_LDT);
 	init_sys_seg(&save->tr, SEG_TYPE_BUSY_TSS16);
 
-	svm_set_cr4(&svm->vcpu, 0);
-	svm_set_efer(&svm->vcpu, 0);
+	svm_set_cr4(vcpu, 0);
+	svm_set_efer(vcpu, 0);
 	save->dr6 = 0xffff0ff0;
-	kvm_set_rflags(&svm->vcpu, X86_EFLAGS_FIXED);
+	kvm_set_rflags(vcpu, X86_EFLAGS_FIXED);
 	save->rip = 0x0000fff0;
-	svm->vcpu.arch.regs[VCPU_REGS_RIP] = save->rip;
+	vcpu->arch.regs[VCPU_REGS_RIP] = save->rip;
 
 	/*
 	 * svm_set_cr0() sets PG and WP and clears NW and CD on save->cr0.
 	 * It also updates the guest-visible cr0 value.
 	 */
-	svm_set_cr0(&svm->vcpu, X86_CR0_NW | X86_CR0_CD | X86_CR0_ET);
-	kvm_mmu_reset_context(&svm->vcpu);
+	svm_set_cr0(vcpu, X86_CR0_NW | X86_CR0_CD | X86_CR0_ET);
+	kvm_mmu_reset_context(vcpu);
 
 	save->cr4 = X86_CR4_PAE;
 	/* rdx = ?? */
@@ -1225,17 +1228,18 @@ static void init_vmcb(struct vcpu_svm *svm)
 		clr_exception_intercept(svm, PF_VECTOR);
 		svm_clr_intercept(svm, INTERCEPT_CR3_READ);
 		svm_clr_intercept(svm, INTERCEPT_CR3_WRITE);
-		save->g_pat = svm->vcpu.arch.pat;
+		save->g_pat = vcpu->arch.pat;
 		save->cr3 = 0;
 		save->cr4 = 0;
 	}
-	svm->asid_generation = 0;
+	svm->current_vmcb->asid_generation = 0;
 	svm->asid = 0;
 
 	svm->nested.vmcb12_gpa = 0;
-	svm->vcpu.arch.hflags = 0;
+	svm->nested.last_vmcb12_gpa = 0;
+	vcpu->arch.hflags = 0;
 
-	if (!kvm_pause_in_guest(svm->vcpu.kvm)) {
+	if (!kvm_pause_in_guest(vcpu->kvm)) {
 		control->pause_filter_count = pause_filter_count;
 		if (pause_filter_thresh)
 			control->pause_filter_thresh = pause_filter_thresh;
@@ -1246,18 +1250,15 @@ static void init_vmcb(struct vcpu_svm *svm)
 
 	svm_check_invpcid(svm);
 
-	if (kvm_vcpu_apicv_active(&svm->vcpu))
-		avic_init_vmcb(svm);
-
 	/*
-	 * If hardware supports Virtual VMLOAD VMSAVE then enable it
-	 * in VMCB and clear intercepts to avoid #VMEXIT.
+	 * If the host supports V_SPEC_CTRL then disable the interception
+	 * of MSR_IA32_SPEC_CTRL.
 	 */
-	if (vls) {
-		svm_clr_intercept(svm, INTERCEPT_VMLOAD);
-		svm_clr_intercept(svm, INTERCEPT_VMSAVE);
-		svm->vmcb->control.virt_ext |= VIRTUAL_VMLOAD_VMSAVE_ENABLE_MASK;
-	}
+	if (boot_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SPEC_CTRL, 1, 1);
+
+	if (kvm_vcpu_apicv_active(vcpu))
+		avic_init_vmcb(svm);
 
 	if (vgif) {
 		svm_clr_intercept(svm, INTERCEPT_STGI);
@@ -1265,11 +1266,11 @@ static void init_vmcb(struct vcpu_svm *svm)
 		svm->vmcb->control.int_ctl |= V_GIF_ENABLE_MASK;
 	}
 
-	if (sev_guest(svm->vcpu.kvm)) {
+	if (sev_guest(vcpu->kvm)) {
 		svm->vmcb->control.nested_ctl |= SVM_NESTED_CTL_SEV_ENABLE;
 		clr_exception_intercept(svm, UD_VECTOR);
 
-		if (sev_es_guest(svm->vcpu.kvm)) {
+		if (sev_es_guest(vcpu->kvm)) {
 			/* Perform SEV-ES specific VMCB updates */
 			sev_es_init_vmcb(svm);
 		}
@@ -1291,12 +1292,12 @@ static void svm_vcpu_reset(struct kvm_vcpu *vcpu, bool init_event)
 	svm->virt_spec_ctrl = 0;
 
 	if (!init_event) {
-		svm->vcpu.arch.apic_base = APIC_DEFAULT_PHYS_BASE |
-					   MSR_IA32_APICBASE_ENABLE;
-		if (kvm_vcpu_is_reset_bsp(&svm->vcpu))
-			svm->vcpu.arch.apic_base |= MSR_IA32_APICBASE_BSP;
+		vcpu->arch.apic_base = APIC_DEFAULT_PHYS_BASE |
+				       MSR_IA32_APICBASE_ENABLE;
+		if (kvm_vcpu_is_reset_bsp(vcpu))
+			vcpu->arch.apic_base |= MSR_IA32_APICBASE_BSP;
 	}
-	init_vmcb(svm);
+	init_vmcb(vcpu);
 
 	kvm_cpuid(vcpu, &eax, &dummy, &dummy, &dummy, false);
 	kvm_rdx_write(vcpu, eax);
@@ -1305,10 +1306,16 @@ static void svm_vcpu_reset(struct kvm_vcpu *vcpu, bool init_event)
 		avic_update_vapic_bar(svm, APIC_DEFAULT_PHYS_BASE);
 }
 
+void svm_switch_vmcb(struct vcpu_svm *svm, struct kvm_vmcb_info *target_vmcb)
+{
+	svm->current_vmcb = target_vmcb;
+	svm->vmcb = target_vmcb->ptr;
+}
+
 static int svm_create_vcpu(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm;
-	struct page *vmcb_page;
+	struct page *vmcb01_page;
 	struct page *vmsa_page = NULL;
 	int err;
 
@@ -1316,11 +1323,11 @@ static int svm_create_vcpu(struct kvm_vcpu *vcpu)
 	svm = to_svm(vcpu);
 
 	err = -ENOMEM;
-	vmcb_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
-	if (!vmcb_page)
+	vmcb01_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+	if (!vmcb01_page)
 		goto out;
 
-	if (sev_es_guest(svm->vcpu.kvm)) {
+	if (sev_es_guest(vcpu->kvm)) {
 		/*
 		 * SEV-ES guests require a separate VMSA page used to contain
 		 * the encrypted register state of the guest.
@@ -1356,20 +1363,21 @@ static int svm_create_vcpu(struct kvm_vcpu *vcpu)
 
 	svm_vcpu_init_msrpm(vcpu, svm->msrpm);
 
-	svm->vmcb = page_address(vmcb_page);
-	svm->vmcb_pa = __sme_set(page_to_pfn(vmcb_page) << PAGE_SHIFT);
+	svm->vmcb01.ptr = page_address(vmcb01_page);
+	svm->vmcb01.pa = __sme_set(page_to_pfn(vmcb01_page) << PAGE_SHIFT);
 
 	if (vmsa_page)
 		svm->vmsa = page_address(vmsa_page);
 
-	svm->asid_generation = 0;
 	svm->guest_state_loaded = false;
-	init_vmcb(svm);
+
+	svm_switch_vmcb(svm, &svm->vmcb01);
+	init_vmcb(vcpu);
 
 	svm_init_osvw(vcpu);
 	vcpu->arch.microcode_version = 0x01000065;
 
-	if (sev_es_guest(svm->vcpu.kvm))
+	if (sev_es_guest(vcpu->kvm))
 		/* Perform SEV-ES specific VMCB creation updates */
 		sev_es_create_vcpu(svm);
 
@@ -1379,7 +1387,7 @@ error_free_vmsa_page:
 	if (vmsa_page)
 		__free_page(vmsa_page);
 error_free_vmcb_page:
-	__free_page(vmcb_page);
+	__free_page(vmcb01_page);
 out:
 	return err;
 }
@@ -1407,32 +1415,23 @@ static void svm_free_vcpu(struct kvm_vcpu *vcpu)
 
 	sev_free_vcpu(vcpu);
 
-	__free_page(pfn_to_page(__sme_clr(svm->vmcb_pa) >> PAGE_SHIFT));
-	__free_pages(virt_to_page(svm->msrpm), MSRPM_ALLOC_ORDER);
+	__free_page(pfn_to_page(__sme_clr(svm->vmcb01.pa) >> PAGE_SHIFT));
+	__free_pages(virt_to_page(svm->msrpm), get_order(MSRPM_SIZE));
 }
 
 static void svm_prepare_guest_switch(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 	struct svm_cpu_data *sd = per_cpu(svm_data, vcpu->cpu);
-	unsigned int i;
 
 	if (svm->guest_state_loaded)
 		return;
 
 	/*
-	 * Certain MSRs are restored on VMEXIT (sev-es), or vmload of host save
-	 * area (non-sev-es). Save ones that aren't so we can restore them
-	 * individually later.
-	 */
-	for (i = 0; i < NR_HOST_SAVE_USER_MSRS; i++)
-		rdmsrl(host_save_user_msrs[i], svm->host_user_msrs[i]);
-
-	/*
 	 * Save additional host state that will be restored on VMEXIT (sev-es)
 	 * or subsequent vmload of host save area.
 	 */
-	if (sev_es_guest(svm->vcpu.kvm)) {
+	if (sev_es_guest(vcpu->kvm)) {
 		sev_es_prepare_guest_switch(svm, vcpu->cpu);
 	} else {
 		vmsave(__sme_page_pa(sd->save_area));
@@ -1446,29 +1445,15 @@ static void svm_prepare_guest_switch(struct kvm_vcpu *vcpu)
 		}
 	}
 
-	/* This assumes that the kernel never uses MSR_TSC_AUX */
 	if (static_cpu_has(X86_FEATURE_RDTSCP))
-		wrmsrl(MSR_TSC_AUX, svm->tsc_aux);
+		kvm_set_user_return_msr(TSC_AUX_URET_SLOT, svm->tsc_aux, -1ull);
 
 	svm->guest_state_loaded = true;
 }
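
Instead of eagerly writing MSR_TSC_AUX on every guest state load, the hunk above hands it to KVM's user-return MSR machinery, which writes the guest value once and restores the host value only when the CPU actually returns to userspace. A minimal sketch of that deferral pattern, with a hypothetical msr_write() standing in for wrmsrl() and an invented struct (not KVM's actual layout):

#include <stdbool.h>
#include <stdio.h>
#include <stdint.h>

/* Hypothetical stand-in for wrmsrl(). */
static void msr_write(uint32_t msr, uint64_t val)
{
	printf("wrmsr %#x <- %#llx\n", (unsigned int)msr, (unsigned long long)val);
}

struct user_return_msr {
	uint32_t msr;
	uint64_t host_value;	/* value to restore before userspace runs */
	uint64_t curr_value;	/* last value written to hardware */
	bool dirty;		/* hardware currently differs from host_value */
};

/* Write the guest value now; remember that a restore is needed later. */
static void set_user_return_msr(struct user_return_msr *u, uint64_t guest_val)
{
	if (guest_val != u->curr_value) {
		msr_write(u->msr, guest_val);
		u->curr_value = guest_val;
	}
	u->dirty = (u->curr_value != u->host_value);
}

/* Runs on return to userspace, not on every vcpu_put(). */
static void user_return_notifier(struct user_return_msr *u)
{
	if (u->dirty) {
		msr_write(u->msr, u->host_value);
		u->curr_value = u->host_value;
		u->dirty = false;
	}
}

int main(void)
{
	struct user_return_msr tsc_aux = {
		.msr = 0xc0000103,	/* MSR_TSC_AUX */
		.host_value = 1, .curr_value = 1,
	};

	set_user_return_msr(&tsc_aux, 7);	/* guest's TSC_AUX */
	user_return_notifier(&tsc_aux);		/* host value back on exit */
	return 0;
}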
 
 static void svm_prepare_host_switch(struct kvm_vcpu *vcpu)
 {
-	struct vcpu_svm *svm = to_svm(vcpu);
-	unsigned int i;
-
-	if (!svm->guest_state_loaded)
-		return;
-
-	/*
-	 * Certain MSRs are restored on VMEXIT (sev-es), or vmload of host save
-	 * area (non-sev-es). Restore the ones that weren't.
-	 */
-	for (i = 0; i < NR_HOST_SAVE_USER_MSRS; i++)
-		wrmsrl(host_save_user_msrs[i], svm->host_user_msrs[i]);
-
-	svm->guest_state_loaded = false;
+	to_svm(vcpu)->guest_state_loaded = false;
 }
 
 static void svm_vcpu_load(struct kvm_vcpu *vcpu, int cpu)
@@ -1476,11 +1461,6 @@ static void svm_vcpu_load(struct kvm_vcpu *vcpu, int cpu)
 	struct vcpu_svm *svm = to_svm(vcpu);
 	struct svm_cpu_data *sd = per_cpu(svm_data, cpu);
 
-	if (unlikely(cpu != vcpu->cpu)) {
-		svm->asid_generation = 0;
-		vmcb_mark_all_dirty(svm->vmcb);
-	}
-
 	if (sd->current_vmcb != svm->vmcb) {
 		sd->current_vmcb = svm->vmcb;
 		indirect_branch_prediction_barrier();
@@ -1564,7 +1544,7 @@ static void svm_clear_vintr(struct vcpu_svm *svm)
 	/* Drop int_ctl fields related to VINTR injection.  */
 	svm->vmcb->control.int_ctl &= mask;
 	if (is_guest_mode(&svm->vcpu)) {
-		svm->nested.hsave->control.int_ctl &= mask;
+		svm->vmcb01.ptr->control.int_ctl &= mask;
 
 		WARN_ON((svm->vmcb->control.int_ctl & V_TPR_MASK) !=
 			(svm->nested.ctl.int_ctl & V_TPR_MASK));
@@ -1577,16 +1557,17 @@ static void svm_clear_vintr(struct vcpu_svm *svm)
 static struct vmcb_seg *svm_seg(struct kvm_vcpu *vcpu, int seg)
 {
 	struct vmcb_save_area *save = &to_svm(vcpu)->vmcb->save;
+	struct vmcb_save_area *save01 = &to_svm(vcpu)->vmcb01.ptr->save;
 
 	switch (seg) {
 	case VCPU_SREG_CS: return &save->cs;
 	case VCPU_SREG_DS: return &save->ds;
 	case VCPU_SREG_ES: return &save->es;
-	case VCPU_SREG_FS: return &save->fs;
-	case VCPU_SREG_GS: return &save->gs;
+	case VCPU_SREG_FS: return &save01->fs;
+	case VCPU_SREG_GS: return &save01->gs;
 	case VCPU_SREG_SS: return &save->ss;
-	case VCPU_SREG_TR: return &save->tr;
-	case VCPU_SREG_LDTR: return &save->ldtr;
+	case VCPU_SREG_TR: return &save01->tr;
+	case VCPU_SREG_LDTR: return &save01->ldtr;
 	}
 	BUG();
 	return NULL;
@@ -1709,37 +1690,10 @@ static void svm_set_gdt(struct kvm_vcpu *vcpu, struct desc_ptr *dt)
 	vmcb_mark_dirty(svm->vmcb, VMCB_DT);
 }
 
-static void update_cr0_intercept(struct vcpu_svm *svm)
-{
-	ulong gcr0;
-	u64 *hcr0;
-
-	/*
-	 * SEV-ES guests must always keep the CR intercepts cleared. CR
-	 * tracking is done using the CR write traps.
-	 */
-	if (sev_es_guest(svm->vcpu.kvm))
-		return;
-
-	gcr0 = svm->vcpu.arch.cr0;
-	hcr0 = &svm->vmcb->save.cr0;
-	*hcr0 = (*hcr0 & ~SVM_CR0_SELECTIVE_MASK)
-		| (gcr0 & SVM_CR0_SELECTIVE_MASK);
-
-	vmcb_mark_dirty(svm->vmcb, VMCB_CR);
-
-	if (gcr0 == *hcr0) {
-		svm_clr_intercept(svm, INTERCEPT_CR0_READ);
-		svm_clr_intercept(svm, INTERCEPT_CR0_WRITE);
-	} else {
-		svm_set_intercept(svm, INTERCEPT_CR0_READ);
-		svm_set_intercept(svm, INTERCEPT_CR0_WRITE);
-	}
-}
-
 void svm_set_cr0(struct kvm_vcpu *vcpu, unsigned long cr0)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
+	u64 hcr0 = cr0;
 
 #ifdef CONFIG_X86_64
 	if (vcpu->arch.efer & EFER_LME && !vcpu->arch.guest_state_protected) {
@@ -1757,7 +1711,7 @@ void svm_set_cr0(struct kvm_vcpu *vcpu, unsigned long cr0)
 	vcpu->arch.cr0 = cr0;
 
 	if (!npt_enabled)
-		cr0 |= X86_CR0_PG | X86_CR0_WP;
+		hcr0 |= X86_CR0_PG | X86_CR0_WP;
 
 	/*
 	 * re-enable caching here because the QEMU bios
@@ -1765,10 +1719,26 @@ void svm_set_cr0(struct kvm_vcpu *vcpu, unsigned long cr0)
 	 * reboot
 	 */
 	if (kvm_check_has_quirk(vcpu->kvm, KVM_X86_QUIRK_CD_NW_CLEARED))
-		cr0 &= ~(X86_CR0_CD | X86_CR0_NW);
-	svm->vmcb->save.cr0 = cr0;
+		hcr0 &= ~(X86_CR0_CD | X86_CR0_NW);
+
+	svm->vmcb->save.cr0 = hcr0;
 	vmcb_mark_dirty(svm->vmcb, VMCB_CR);
-	update_cr0_intercept(svm);
+
+	/*
+	 * SEV-ES guests must always keep the CR intercepts cleared. CR
+	 * tracking is done using the CR write traps.
+	 */
+	if (sev_es_guest(vcpu->kvm))
+		return;
+
+	if (hcr0 == cr0) {
+		/* Selective CR0 write remains on.  */
+		svm_clr_intercept(svm, INTERCEPT_CR0_READ);
+		svm_clr_intercept(svm, INTERCEPT_CR0_WRITE);
+	} else {
+		svm_set_intercept(svm, INTERCEPT_CR0_READ);
+		svm_set_intercept(svm, INTERCEPT_CR0_WRITE);
+	}
 }
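
svm_set_cr0() now builds the host view of CR0 (hcr0) separately from the guest-visible value and keeps CR0 read/write intercepts only when the two differ. A compact sketch of that decision, with the CR0 bit constants spelled out and booleans approximating npt_enabled and the CD/NW quirk check:

#include <stdbool.h>
#include <stdio.h>

#define X86_CR0_PG	(1UL << 31)
#define X86_CR0_WP	(1UL << 16)
#define X86_CR0_CD	(1UL << 30)
#define X86_CR0_NW	(1UL << 29)

/* Returns true if CR0 reads/writes must stay intercepted. */
static bool cr0_needs_intercept(unsigned long cr0, bool npt_enabled,
				bool cd_nw_cleared_quirk)
{
	unsigned long hcr0 = cr0;

	if (!npt_enabled)		/* shadow paging forces PG/WP on */
		hcr0 |= X86_CR0_PG | X86_CR0_WP;
	if (cd_nw_cleared_quirk)	/* keep caching on for the BIOS */
		hcr0 &= ~(X86_CR0_CD | X86_CR0_NW);

	/* Identical views mean the guest can touch CR0 without exiting. */
	return hcr0 != cr0;
}

int main(void)
{
	printf("npt, nothing rewritten: %d\n",
	       cr0_needs_intercept(0x80010031UL, true, false));	/* 0 */
	printf("shadow paging:          %d\n",
	       cr0_needs_intercept(0x60000010UL, false, false));	/* 1 */
	return 0;
}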
 
 static bool svm_is_valid_cr4(struct kvm_vcpu *vcpu, unsigned long cr4)
@@ -1847,7 +1817,7 @@ static void new_asid(struct vcpu_svm *svm, struct svm_cpu_data *sd)
 		vmcb_mark_dirty(svm->vmcb, VMCB_ASID);
 	}
 
-	svm->asid_generation = sd->asid_generation;
+	svm->current_vmcb->asid_generation = sd->asid_generation;
 	svm->asid = sd->next_asid++;
 }
 
@@ -1896,39 +1866,43 @@ static void svm_set_dr7(struct kvm_vcpu *vcpu, unsigned long value)
 	vmcb_mark_dirty(svm->vmcb, VMCB_DR);
 }
 
-static int pf_interception(struct vcpu_svm *svm)
+static int pf_interception(struct kvm_vcpu *vcpu)
 {
-	u64 fault_address = __sme_clr(svm->vmcb->control.exit_info_2);
+	struct vcpu_svm *svm = to_svm(vcpu);
+
+	u64 fault_address = svm->vmcb->control.exit_info_2;
 	u64 error_code = svm->vmcb->control.exit_info_1;
 
-	return kvm_handle_page_fault(&svm->vcpu, error_code, fault_address,
+	return kvm_handle_page_fault(vcpu, error_code, fault_address,
 			static_cpu_has(X86_FEATURE_DECODEASSISTS) ?
 			svm->vmcb->control.insn_bytes : NULL,
 			svm->vmcb->control.insn_len);
 }
 
-static int npf_interception(struct vcpu_svm *svm)
+static int npf_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
+
 	u64 fault_address = __sme_clr(svm->vmcb->control.exit_info_2);
 	u64 error_code = svm->vmcb->control.exit_info_1;
 
 	trace_kvm_page_fault(fault_address, error_code);
-	return kvm_mmu_page_fault(&svm->vcpu, fault_address, error_code,
+	return kvm_mmu_page_fault(vcpu, fault_address, error_code,
 			static_cpu_has(X86_FEATURE_DECODEASSISTS) ?
 			svm->vmcb->control.insn_bytes : NULL,
 			svm->vmcb->control.insn_len);
 }
 
-static int db_interception(struct vcpu_svm *svm)
+static int db_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct kvm_run *kvm_run = vcpu->run;
+	struct vcpu_svm *svm = to_svm(vcpu);
 
-	if (!(svm->vcpu.guest_debug &
+	if (!(vcpu->guest_debug &
 	      (KVM_GUESTDBG_SINGLESTEP | KVM_GUESTDBG_USE_HW_BP)) &&
 		!svm->nmi_singlestep) {
 		u32 payload = svm->vmcb->save.dr6 ^ DR6_ACTIVE_LOW;
-		kvm_queue_exception_p(&svm->vcpu, DB_VECTOR, payload);
+		kvm_queue_exception_p(vcpu, DB_VECTOR, payload);
 		return 1;
 	}
 
@@ -1938,7 +1912,7 @@ static int db_interception(struct vcpu_svm *svm)
 		kvm_make_request(KVM_REQ_EVENT, vcpu);
 	}
 
-	if (svm->vcpu.guest_debug &
+	if (vcpu->guest_debug &
 	    (KVM_GUESTDBG_SINGLESTEP | KVM_GUESTDBG_USE_HW_BP)) {
 		kvm_run->exit_reason = KVM_EXIT_DEBUG;
 		kvm_run->debug.arch.dr6 = svm->vmcb->save.dr6;
@@ -1952,9 +1926,10 @@ static int db_interception(struct vcpu_svm *svm)
 	return 1;
 }
 
-static int bp_interception(struct vcpu_svm *svm)
+static int bp_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
+	struct vcpu_svm *svm = to_svm(vcpu);
+	struct kvm_run *kvm_run = vcpu->run;
 
 	kvm_run->exit_reason = KVM_EXIT_DEBUG;
 	kvm_run->debug.arch.pc = svm->vmcb->save.cs.base + svm->vmcb->save.rip;
@@ -1962,14 +1937,14 @@ static int bp_interception(struct vcpu_svm *svm)
 	return 0;
 }
 
-static int ud_interception(struct vcpu_svm *svm)
+static int ud_interception(struct kvm_vcpu *vcpu)
 {
-	return handle_ud(&svm->vcpu);
+	return handle_ud(vcpu);
 }
 
-static int ac_interception(struct vcpu_svm *svm)
+static int ac_interception(struct kvm_vcpu *vcpu)
 {
-	kvm_queue_exception_e(&svm->vcpu, AC_VECTOR, 0);
+	kvm_queue_exception_e(vcpu, AC_VECTOR, 0);
 	return 1;
 }
 
@@ -2012,7 +1987,7 @@ static bool is_erratum_383(void)
 	return true;
 }
 
-static void svm_handle_mce(struct vcpu_svm *svm)
+static void svm_handle_mce(struct kvm_vcpu *vcpu)
 {
 	if (is_erratum_383()) {
 		/*
@@ -2021,7 +1996,7 @@ static void svm_handle_mce(struct vcpu_svm *svm)
 		 */
 		pr_err("KVM: Guest triggered AMD Erratum 383\n");
 
-		kvm_make_request(KVM_REQ_TRIPLE_FAULT, &svm->vcpu);
+		kvm_make_request(KVM_REQ_TRIPLE_FAULT, vcpu);
 
 		return;
 	}
@@ -2033,20 +2008,21 @@ static void svm_handle_mce(struct vcpu_svm *svm)
 	kvm_machine_check();
 }
 
-static int mc_interception(struct vcpu_svm *svm)
+static int mc_interception(struct kvm_vcpu *vcpu)
 {
 	return 1;
 }
 
-static int shutdown_interception(struct vcpu_svm *svm)
+static int shutdown_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
+	struct kvm_run *kvm_run = vcpu->run;
+	struct vcpu_svm *svm = to_svm(vcpu);
 
 	/*
 	 * The VM save area has already been encrypted so it
 	 * cannot be reinitialized - just terminate.
 	 */
-	if (sev_es_guest(svm->vcpu.kvm))
+	if (sev_es_guest(vcpu->kvm))
 		return -EINVAL;
 
 	/*
@@ -2054,20 +2030,20 @@ static int shutdown_interception(struct vcpu_svm *svm)
 	 * so reinitialize it.
 	 */
 	clear_page(svm->vmcb);
-	init_vmcb(svm);
+	init_vmcb(vcpu);
 
 	kvm_run->exit_reason = KVM_EXIT_SHUTDOWN;
 	return 0;
 }
 
-static int io_interception(struct vcpu_svm *svm)
+static int io_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u32 io_info = svm->vmcb->control.exit_info_1; /* address size bug? */
 	int size, in, string;
 	unsigned port;
 
-	++svm->vcpu.stat.io_exits;
+	++vcpu->stat.io_exits;
 	string = (io_info & SVM_IOIO_STR_MASK) != 0;
 	in = (io_info & SVM_IOIO_TYPE_MASK) != 0;
 	port = io_info >> 16;
@@ -2082,93 +2058,69 @@ static int io_interception(struct vcpu_svm *svm)
 
 	svm->next_rip = svm->vmcb->control.exit_info_2;
 
-	return kvm_fast_pio(&svm->vcpu, size, port, in);
-}
-
-static int nmi_interception(struct vcpu_svm *svm)
-{
-	return 1;
+	return kvm_fast_pio(vcpu, size, port, in);
 }
 
-static int intr_interception(struct vcpu_svm *svm)
+static int nmi_interception(struct kvm_vcpu *vcpu)
 {
-	++svm->vcpu.stat.irq_exits;
 	return 1;
 }
 
-static int nop_on_interception(struct vcpu_svm *svm)
+static int intr_interception(struct kvm_vcpu *vcpu)
 {
+	++vcpu->stat.irq_exits;
 	return 1;
 }
 
-static int halt_interception(struct vcpu_svm *svm)
+static int vmload_vmsave_interception(struct kvm_vcpu *vcpu, bool vmload)
 {
-	return kvm_emulate_halt(&svm->vcpu);
-}
-
-static int vmmcall_interception(struct vcpu_svm *svm)
-{
-	return kvm_emulate_hypercall(&svm->vcpu);
-}
-
-static int vmload_interception(struct vcpu_svm *svm)
-{
-	struct vmcb *nested_vmcb;
+	struct vcpu_svm *svm = to_svm(vcpu);
+	struct vmcb *vmcb12;
 	struct kvm_host_map map;
 	int ret;
 
-	if (nested_svm_check_permissions(svm))
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	ret = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(svm->vmcb->save.rax), &map);
+	ret = kvm_vcpu_map(vcpu, gpa_to_gfn(svm->vmcb->save.rax), &map);
 	if (ret) {
 		if (ret == -EINVAL)
-			kvm_inject_gp(&svm->vcpu, 0);
+			kvm_inject_gp(vcpu, 0);
 		return 1;
 	}
 
-	nested_vmcb = map.hva;
+	vmcb12 = map.hva;
+
+	ret = kvm_skip_emulated_instruction(vcpu);
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
+	if (vmload) {
+		nested_svm_vmloadsave(vmcb12, svm->vmcb);
+		svm->sysenter_eip_hi = 0;
+		svm->sysenter_esp_hi = 0;
+	} else
+		nested_svm_vmloadsave(svm->vmcb, vmcb12);
 
-	nested_svm_vmloadsave(nested_vmcb, svm->vmcb);
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
+	kvm_vcpu_unmap(vcpu, &map, true);
 
 	return ret;
 }
 
-static int vmsave_interception(struct vcpu_svm *svm)
+static int vmload_interception(struct kvm_vcpu *vcpu)
 {
-	struct vmcb *nested_vmcb;
-	struct kvm_host_map map;
-	int ret;
-
-	if (nested_svm_check_permissions(svm))
-		return 1;
-
-	ret = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(svm->vmcb->save.rax), &map);
-	if (ret) {
-		if (ret == -EINVAL)
-			kvm_inject_gp(&svm->vcpu, 0);
-		return 1;
-	}
-
-	nested_vmcb = map.hva;
-
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
-
-	nested_svm_vmloadsave(svm->vmcb, nested_vmcb);
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
+	return vmload_vmsave_interception(vcpu, true);
+}
 
-	return ret;
+static int vmsave_interception(struct kvm_vcpu *vcpu)
+{
+	return vmload_vmsave_interception(vcpu, false);
 }
 
-static int vmrun_interception(struct vcpu_svm *svm)
+static int vmrun_interception(struct kvm_vcpu *vcpu)
 {
-	if (nested_svm_check_permissions(svm))
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	return nested_svm_vmrun(svm);
+	return nested_svm_vmrun(vcpu);
 }
 
 enum {
@@ -2207,7 +2159,7 @@ static int emulate_svm_instr(struct kvm_vcpu *vcpu, int opcode)
 		[SVM_INSTR_VMLOAD] = SVM_EXIT_VMLOAD,
 		[SVM_INSTR_VMSAVE] = SVM_EXIT_VMSAVE,
 	};
-	int (*const svm_instr_handlers[])(struct vcpu_svm *svm) = {
+	int (*const svm_instr_handlers[])(struct kvm_vcpu *vcpu) = {
 		[SVM_INSTR_VMRUN] = vmrun_interception,
 		[SVM_INSTR_VMLOAD] = vmload_interception,
 		[SVM_INSTR_VMSAVE] = vmsave_interception,
@@ -2216,17 +2168,13 @@ static int emulate_svm_instr(struct kvm_vcpu *vcpu, int opcode)
 	int ret;
 
 	if (is_guest_mode(vcpu)) {
-		svm->vmcb->control.exit_code = guest_mode_exit_codes[opcode];
-		svm->vmcb->control.exit_info_1 = 0;
-		svm->vmcb->control.exit_info_2 = 0;
-
 		/* Returns '1' or -errno on failure, '0' on success. */
-		ret = nested_svm_vmexit(svm);
+		ret = nested_svm_simple_vmexit(svm, guest_mode_exit_codes[opcode]);
 		if (ret)
 			return ret;
 		return 1;
 	}
-	return svm_instr_handlers[opcode](svm);
+	return svm_instr_handlers[opcode](vcpu);
 }
 
 /*
@@ -2237,9 +2185,9 @@ static int emulate_svm_instr(struct kvm_vcpu *vcpu, int opcode)
  *      regions (e.g. SMM memory on host).
  *   2) VMware backdoor
  */
-static int gp_interception(struct vcpu_svm *svm)
+static int gp_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u32 error_code = svm->vmcb->control.exit_info_1;
 	int opcode;
 
@@ -2304,73 +2252,58 @@ void svm_set_gif(struct vcpu_svm *svm, bool value)
 	}
 }
 
-static int stgi_interception(struct vcpu_svm *svm)
+static int stgi_interception(struct kvm_vcpu *vcpu)
 {
 	int ret;
 
-	if (nested_svm_check_permissions(svm))
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
-	svm_set_gif(svm, true);
+	ret = kvm_skip_emulated_instruction(vcpu);
+	svm_set_gif(to_svm(vcpu), true);
 	return ret;
 }
 
-static int clgi_interception(struct vcpu_svm *svm)
+static int clgi_interception(struct kvm_vcpu *vcpu)
 {
 	int ret;
 
-	if (nested_svm_check_permissions(svm))
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
-	svm_set_gif(svm, false);
+	ret = kvm_skip_emulated_instruction(vcpu);
+	svm_set_gif(to_svm(vcpu), false);
 	return ret;
 }
 
-static int invlpga_interception(struct vcpu_svm *svm)
+static int invlpga_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
-
-	trace_kvm_invlpga(svm->vmcb->save.rip, kvm_rcx_read(&svm->vcpu),
-			  kvm_rax_read(&svm->vcpu));
-
-	/* Let's treat INVLPGA the same as INVLPG (can be optimized!) */
-	kvm_mmu_invlpg(vcpu, kvm_rax_read(&svm->vcpu));
+	gva_t gva = kvm_rax_read(vcpu);
+	u32 asid = kvm_rcx_read(vcpu);
 
-	return kvm_skip_emulated_instruction(&svm->vcpu);
-}
+	/* FIXME: Handle an address size prefix. */
+	if (!is_long_mode(vcpu))
+		gva = (u32)gva;
 
-static int skinit_interception(struct vcpu_svm *svm)
-{
-	trace_kvm_skinit(svm->vmcb->save.rip, kvm_rax_read(&svm->vcpu));
+	trace_kvm_invlpga(to_svm(vcpu)->vmcb->save.rip, asid, gva);
 
-	kvm_queue_exception(&svm->vcpu, UD_VECTOR);
-	return 1;
-}
+	/* Let's treat INVLPGA the same as INVLPG (can be optimized!) */
+	kvm_mmu_invlpg(vcpu, gva);
 
-static int wbinvd_interception(struct vcpu_svm *svm)
-{
-	return kvm_emulate_wbinvd(&svm->vcpu);
+	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int xsetbv_interception(struct vcpu_svm *svm)
+static int skinit_interception(struct kvm_vcpu *vcpu)
 {
-	u64 new_bv = kvm_read_edx_eax(&svm->vcpu);
-	u32 index = kvm_rcx_read(&svm->vcpu);
+	trace_kvm_skinit(to_svm(vcpu)->vmcb->save.rip, kvm_rax_read(vcpu));
 
-	int err = kvm_set_xcr(&svm->vcpu, index, new_bv);
-	return kvm_complete_insn_gp(&svm->vcpu, err);
-}
-
-static int rdpru_interception(struct vcpu_svm *svm)
-{
-	kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+	kvm_queue_exception(vcpu, UD_VECTOR);
 	return 1;
 }
 
-static int task_switch_interception(struct vcpu_svm *svm)
+static int task_switch_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u16 tss_selector;
 	int reason;
 	int int_type = svm->vmcb->control.exit_int_info &
@@ -2399,7 +2332,7 @@ static int task_switch_interception(struct vcpu_svm *svm)
 	if (reason == TASK_SWITCH_GATE) {
 		switch (type) {
 		case SVM_EXITINTINFO_TYPE_NMI:
-			svm->vcpu.arch.nmi_injected = false;
+			vcpu->arch.nmi_injected = false;
 			break;
 		case SVM_EXITINTINFO_TYPE_EXEPT:
 			if (svm->vmcb->control.exit_info_2 &
@@ -2408,10 +2341,10 @@ static int task_switch_interception(struct vcpu_svm *svm)
 				error_code =
 					(u32)svm->vmcb->control.exit_info_2;
 			}
-			kvm_clear_exception_queue(&svm->vcpu);
+			kvm_clear_exception_queue(vcpu);
 			break;
 		case SVM_EXITINTINFO_TYPE_INTR:
-			kvm_clear_interrupt_queue(&svm->vcpu);
+			kvm_clear_interrupt_queue(vcpu);
 			break;
 		default:
 			break;
@@ -2422,77 +2355,58 @@ static int task_switch_interception(struct vcpu_svm *svm)
 	    int_type == SVM_EXITINTINFO_TYPE_SOFT ||
 	    (int_type == SVM_EXITINTINFO_TYPE_EXEPT &&
 	     (int_vec == OF_VECTOR || int_vec == BP_VECTOR))) {
-		if (!skip_emulated_instruction(&svm->vcpu))
+		if (!skip_emulated_instruction(vcpu))
 			return 0;
 	}
 
 	if (int_type != SVM_EXITINTINFO_TYPE_SOFT)
 		int_vec = -1;
 
-	return kvm_task_switch(&svm->vcpu, tss_selector, int_vec, reason,
+	return kvm_task_switch(vcpu, tss_selector, int_vec, reason,
 			       has_error_code, error_code);
 }
 
-static int cpuid_interception(struct vcpu_svm *svm)
+static int iret_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_cpuid(&svm->vcpu);
-}
+	struct vcpu_svm *svm = to_svm(vcpu);
 
-static int iret_interception(struct vcpu_svm *svm)
-{
-	++svm->vcpu.stat.nmi_window_exits;
-	svm->vcpu.arch.hflags |= HF_IRET_MASK;
-	if (!sev_es_guest(svm->vcpu.kvm)) {
+	++vcpu->stat.nmi_window_exits;
+	vcpu->arch.hflags |= HF_IRET_MASK;
+	if (!sev_es_guest(vcpu->kvm)) {
 		svm_clr_intercept(svm, INTERCEPT_IRET);
-		svm->nmi_iret_rip = kvm_rip_read(&svm->vcpu);
+		svm->nmi_iret_rip = kvm_rip_read(vcpu);
 	}
-	kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
+	kvm_make_request(KVM_REQ_EVENT, vcpu);
 	return 1;
 }
 
-static int invd_interception(struct vcpu_svm *svm)
-{
-	/* Treat an INVD instruction as a NOP and just skip it. */
-	return kvm_skip_emulated_instruction(&svm->vcpu);
-}
-
-static int invlpg_interception(struct vcpu_svm *svm)
+static int invlpg_interception(struct kvm_vcpu *vcpu)
 {
 	if (!static_cpu_has(X86_FEATURE_DECODEASSISTS))
-		return kvm_emulate_instruction(&svm->vcpu, 0);
+		return kvm_emulate_instruction(vcpu, 0);
 
-	kvm_mmu_invlpg(&svm->vcpu, svm->vmcb->control.exit_info_1);
-	return kvm_skip_emulated_instruction(&svm->vcpu);
+	kvm_mmu_invlpg(vcpu, to_svm(vcpu)->vmcb->control.exit_info_1);
+	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int emulate_on_interception(struct vcpu_svm *svm)
+static int emulate_on_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_instruction(&svm->vcpu, 0);
+	return kvm_emulate_instruction(vcpu, 0);
 }
 
-static int rsm_interception(struct vcpu_svm *svm)
+static int rsm_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_instruction_from_buffer(&svm->vcpu, rsm_ins_bytes, 2);
+	return kvm_emulate_instruction_from_buffer(vcpu, rsm_ins_bytes, 2);
 }
 
-static int rdpmc_interception(struct vcpu_svm *svm)
-{
-	int err;
-
-	if (!nrips)
-		return emulate_on_interception(svm);
-
-	err = kvm_rdpmc(&svm->vcpu);
-	return kvm_complete_insn_gp(&svm->vcpu, err);
-}
-
-static bool check_selective_cr0_intercepted(struct vcpu_svm *svm,
+static bool check_selective_cr0_intercepted(struct kvm_vcpu *vcpu,
 					    unsigned long val)
 {
-	unsigned long cr0 = svm->vcpu.arch.cr0;
+	struct vcpu_svm *svm = to_svm(vcpu);
+	unsigned long cr0 = vcpu->arch.cr0;
 	bool ret = false;
 
-	if (!is_guest_mode(&svm->vcpu) ||
+	if (!is_guest_mode(vcpu) ||
 	    (!(vmcb_is_intercept(&svm->nested.ctl, INTERCEPT_SELECTIVE_CR0))))
 		return false;
 
@@ -2509,17 +2423,18 @@ static bool check_selective_cr0_intercepted(struct vcpu_svm *svm,
 
 #define CR_VALID (1ULL << 63)
 
-static int cr_interception(struct vcpu_svm *svm)
+static int cr_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int reg, cr;
 	unsigned long val;
 	int err;
 
 	if (!static_cpu_has(X86_FEATURE_DECODEASSISTS))
-		return emulate_on_interception(svm);
+		return emulate_on_interception(vcpu);
 
 	if (unlikely((svm->vmcb->control.exit_info_1 & CR_VALID) == 0))
-		return emulate_on_interception(svm);
+		return emulate_on_interception(vcpu);
 
 	reg = svm->vmcb->control.exit_info_1 & SVM_EXITINFO_REG_MASK;
 	if (svm->vmcb->control.exit_code == SVM_EXIT_CR0_SEL_WRITE)
@@ -2530,61 +2445,61 @@ static int cr_interception(struct vcpu_svm *svm)
 	err = 0;
 	if (cr >= 16) { /* mov to cr */
 		cr -= 16;
-		val = kvm_register_read(&svm->vcpu, reg);
+		val = kvm_register_read(vcpu, reg);
 		trace_kvm_cr_write(cr, val);
 		switch (cr) {
 		case 0:
-			if (!check_selective_cr0_intercepted(svm, val))
-				err = kvm_set_cr0(&svm->vcpu, val);
+			if (!check_selective_cr0_intercepted(vcpu, val))
+				err = kvm_set_cr0(vcpu, val);
 			else
 				return 1;
 
 			break;
 		case 3:
-			err = kvm_set_cr3(&svm->vcpu, val);
+			err = kvm_set_cr3(vcpu, val);
 			break;
 		case 4:
-			err = kvm_set_cr4(&svm->vcpu, val);
+			err = kvm_set_cr4(vcpu, val);
 			break;
 		case 8:
-			err = kvm_set_cr8(&svm->vcpu, val);
+			err = kvm_set_cr8(vcpu, val);
 			break;
 		default:
 			WARN(1, "unhandled write to CR%d", cr);
-			kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+			kvm_queue_exception(vcpu, UD_VECTOR);
 			return 1;
 		}
 	} else { /* mov from cr */
 		switch (cr) {
 		case 0:
-			val = kvm_read_cr0(&svm->vcpu);
+			val = kvm_read_cr0(vcpu);
 			break;
 		case 2:
-			val = svm->vcpu.arch.cr2;
+			val = vcpu->arch.cr2;
 			break;
 		case 3:
-			val = kvm_read_cr3(&svm->vcpu);
+			val = kvm_read_cr3(vcpu);
 			break;
 		case 4:
-			val = kvm_read_cr4(&svm->vcpu);
+			val = kvm_read_cr4(vcpu);
 			break;
 		case 8:
-			val = kvm_get_cr8(&svm->vcpu);
+			val = kvm_get_cr8(vcpu);
 			break;
 		default:
 			WARN(1, "unhandled read from CR%d", cr);
-			kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+			kvm_queue_exception(vcpu, UD_VECTOR);
 			return 1;
 		}
-		kvm_register_write(&svm->vcpu, reg, val);
+		kvm_register_write(vcpu, reg, val);
 		trace_kvm_cr_read(cr, val);
 	}
-	return kvm_complete_insn_gp(&svm->vcpu, err);
+	return kvm_complete_insn_gp(vcpu, err);
 }
 
-static int cr_trap(struct vcpu_svm *svm)
+static int cr_trap(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	unsigned long old_value, new_value;
 	unsigned int cr;
 	int ret = 0;
@@ -2606,7 +2521,7 @@ static int cr_trap(struct vcpu_svm *svm)
 		kvm_post_set_cr4(vcpu, old_value, new_value);
 		break;
 	case 8:
-		ret = kvm_set_cr8(&svm->vcpu, new_value);
+		ret = kvm_set_cr8(vcpu, new_value);
 		break;
 	default:
 		WARN(1, "unhandled CR%d write trap", cr);
@@ -2617,57 +2532,57 @@ static int cr_trap(struct vcpu_svm *svm)
 	return kvm_complete_insn_gp(vcpu, ret);
 }
 
-static int dr_interception(struct vcpu_svm *svm)
+static int dr_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int reg, dr;
 	unsigned long val;
 	int err = 0;
 
-	if (svm->vcpu.guest_debug == 0) {
+	if (vcpu->guest_debug == 0) {
 		/*
 		 * No more DR vmexits; force a reload of the debug registers
 		 * and reenter on this instruction.  The next vmexit will
 		 * retrieve the full state of the debug registers.
 		 */
 		clr_dr_intercepts(svm);
-		svm->vcpu.arch.switch_db_regs |= KVM_DEBUGREG_WONT_EXIT;
+		vcpu->arch.switch_db_regs |= KVM_DEBUGREG_WONT_EXIT;
 		return 1;
 	}
 
 	if (!boot_cpu_has(X86_FEATURE_DECODEASSISTS))
-		return emulate_on_interception(svm);
+		return emulate_on_interception(vcpu);
 
 	reg = svm->vmcb->control.exit_info_1 & SVM_EXITINFO_REG_MASK;
 	dr = svm->vmcb->control.exit_code - SVM_EXIT_READ_DR0;
 	if (dr >= 16) { /* mov to DRn  */
 		dr -= 16;
-		val = kvm_register_read(&svm->vcpu, reg);
-		err = kvm_set_dr(&svm->vcpu, dr, val);
+		val = kvm_register_read(vcpu, reg);
+		err = kvm_set_dr(vcpu, dr, val);
 	} else {
-		kvm_get_dr(&svm->vcpu, dr, &val);
-		kvm_register_write(&svm->vcpu, reg, val);
+		kvm_get_dr(vcpu, dr, &val);
+		kvm_register_write(vcpu, reg, val);
 	}
 
-	return kvm_complete_insn_gp(&svm->vcpu, err);
+	return kvm_complete_insn_gp(vcpu, err);
 }
 
-static int cr8_write_interception(struct vcpu_svm *svm)
+static int cr8_write_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
 	int r;
 
-	u8 cr8_prev = kvm_get_cr8(&svm->vcpu);
+	u8 cr8_prev = kvm_get_cr8(vcpu);
 	/* instruction emulation calls kvm_set_cr8() */
-	r = cr_interception(svm);
-	if (lapic_in_kernel(&svm->vcpu))
+	r = cr_interception(vcpu);
+	if (lapic_in_kernel(vcpu))
 		return r;
-	if (cr8_prev <= kvm_get_cr8(&svm->vcpu))
+	if (cr8_prev <= kvm_get_cr8(vcpu))
 		return r;
-	kvm_run->exit_reason = KVM_EXIT_SET_TPR;
+	vcpu->run->exit_reason = KVM_EXIT_SET_TPR;
 	return 0;
 }
 
-static int efer_trap(struct vcpu_svm *svm)
+static int efer_trap(struct kvm_vcpu *vcpu)
 {
 	struct msr_data msr_info;
 	int ret;
@@ -2680,10 +2595,10 @@ static int efer_trap(struct vcpu_svm *svm)
 	 */
 	msr_info.host_initiated = false;
 	msr_info.index = MSR_EFER;
-	msr_info.data = svm->vmcb->control.exit_info_1 & ~EFER_SVME;
-	ret = kvm_set_msr_common(&svm->vcpu, &msr_info);
+	msr_info.data = to_svm(vcpu)->vmcb->control.exit_info_1 & ~EFER_SVME;
+	ret = kvm_set_msr_common(vcpu, &msr_info);
 
-	return kvm_complete_insn_gp(&svm->vcpu, ret);
+	return kvm_complete_insn_gp(vcpu, ret);
 }
 
 static int svm_get_msr_feature(struct kvm_msr_entry *msr)
@@ -2710,34 +2625,41 @@ static int svm_get_msr(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
 
 	switch (msr_info->index) {
 	case MSR_STAR:
-		msr_info->data = svm->vmcb->save.star;
+		msr_info->data = svm->vmcb01.ptr->save.star;
 		break;
 #ifdef CONFIG_X86_64
 	case MSR_LSTAR:
-		msr_info->data = svm->vmcb->save.lstar;
+		msr_info->data = svm->vmcb01.ptr->save.lstar;
 		break;
 	case MSR_CSTAR:
-		msr_info->data = svm->vmcb->save.cstar;
+		msr_info->data = svm->vmcb01.ptr->save.cstar;
 		break;
 	case MSR_KERNEL_GS_BASE:
-		msr_info->data = svm->vmcb->save.kernel_gs_base;
+		msr_info->data = svm->vmcb01.ptr->save.kernel_gs_base;
 		break;
 	case MSR_SYSCALL_MASK:
-		msr_info->data = svm->vmcb->save.sfmask;
+		msr_info->data = svm->vmcb01.ptr->save.sfmask;
 		break;
 #endif
 	case MSR_IA32_SYSENTER_CS:
-		msr_info->data = svm->vmcb->save.sysenter_cs;
+		msr_info->data = svm->vmcb01.ptr->save.sysenter_cs;
 		break;
 	case MSR_IA32_SYSENTER_EIP:
-		msr_info->data = svm->sysenter_eip;
+		msr_info->data = (u32)svm->vmcb01.ptr->save.sysenter_eip;
+		if (guest_cpuid_is_intel(vcpu))
+			msr_info->data |= (u64)svm->sysenter_eip_hi << 32;
 		break;
 	case MSR_IA32_SYSENTER_ESP:
-		msr_info->data = svm->sysenter_esp;
+		msr_info->data = svm->vmcb01.ptr->save.sysenter_esp;
+		if (guest_cpuid_is_intel(vcpu))
+			msr_info->data |= (u64)svm->sysenter_esp_hi << 32;
 		break;
 	case MSR_TSC_AUX:
 		if (!boot_cpu_has(X86_FEATURE_RDTSCP))
 			return 1;
+		if (!msr_info->host_initiated &&
+		    !guest_cpuid_has(vcpu, X86_FEATURE_RDTSCP))
+			return 1;
 		msr_info->data = svm->tsc_aux;
 		break;
 	/*
@@ -2771,7 +2693,10 @@ static int svm_get_msr(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
 		    !guest_has_spec_ctrl_msr(vcpu))
 			return 1;
 
-		msr_info->data = svm->spec_ctrl;
+		if (boot_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+			msr_info->data = svm->vmcb->save.spec_ctrl;
+		else
+			msr_info->data = svm->spec_ctrl;
 		break;
 	case MSR_AMD64_VIRT_SPEC_CTRL:
 		if (!msr_info->host_initiated &&
@@ -2809,8 +2734,8 @@ static int svm_get_msr(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
 static int svm_complete_emulated_msr(struct kvm_vcpu *vcpu, int err)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
-	if (!sev_es_guest(svm->vcpu.kvm) || !err)
-		return kvm_complete_insn_gp(&svm->vcpu, err);
+	if (!err || !sev_es_guest(vcpu->kvm) || WARN_ON_ONCE(!svm->ghcb))
+		return kvm_complete_insn_gp(vcpu, err);
 
 	ghcb_set_sw_exit_info_1(svm->ghcb, 1);
 	ghcb_set_sw_exit_info_2(svm->ghcb,
@@ -2820,11 +2745,6 @@ static int svm_complete_emulated_msr(struct kvm_vcpu *vcpu, int err)
 	return 1;
 }
 
-static int rdmsr_interception(struct vcpu_svm *svm)
-{
-	return kvm_emulate_rdmsr(&svm->vcpu);
-}
-
 static int svm_set_vm_cr(struct kvm_vcpu *vcpu, u64 data)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
@@ -2853,6 +2773,7 @@ static int svm_set_vm_cr(struct kvm_vcpu *vcpu, u64 data)
 static int svm_set_msr(struct kvm_vcpu *vcpu, struct msr_data *msr)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
+	int r;
 
 	u32 ecx = msr->index;
 	u64 data = msr->data;
@@ -2861,7 +2782,9 @@ static int svm_set_msr(struct kvm_vcpu *vcpu, struct msr_data *msr)
 		if (!kvm_mtrr_valid(vcpu, MSR_IA32_CR_PAT, data))
 			return 1;
 		vcpu->arch.pat = data;
-		svm->vmcb->save.g_pat = data;
+		svm->vmcb01.ptr->save.g_pat = data;
+		if (is_guest_mode(vcpu))
+			nested_vmcb02_compute_g_pat(svm);
 		vmcb_mark_dirty(svm->vmcb, VMCB_NPT);
 		break;
 	case MSR_IA32_SPEC_CTRL:
@@ -2872,7 +2795,10 @@ static int svm_set_msr(struct kvm_vcpu *vcpu, struct msr_data *msr)
 		if (kvm_spec_ctrl_test_value(data))
 			return 1;
 
-		svm->spec_ctrl = data;
+		if (boot_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+			svm->vmcb->save.spec_ctrl = data;
+		else
+			svm->spec_ctrl = data;
 		if (!data)
 			break;
 
@@ -2915,44 +2841,70 @@ static int svm_set_msr(struct kvm_vcpu *vcpu, struct msr_data *msr)
 		svm->virt_spec_ctrl = data;
 		break;
 	case MSR_STAR:
-		svm->vmcb->save.star = data;
+		svm->vmcb01.ptr->save.star = data;
 		break;
 #ifdef CONFIG_X86_64
 	case MSR_LSTAR:
-		svm->vmcb->save.lstar = data;
+		svm->vmcb01.ptr->save.lstar = data;
 		break;
 	case MSR_CSTAR:
-		svm->vmcb->save.cstar = data;
+		svm->vmcb01.ptr->save.cstar = data;
 		break;
 	case MSR_KERNEL_GS_BASE:
-		svm->vmcb->save.kernel_gs_base = data;
+		svm->vmcb01.ptr->save.kernel_gs_base = data;
 		break;
 	case MSR_SYSCALL_MASK:
-		svm->vmcb->save.sfmask = data;
+		svm->vmcb01.ptr->save.sfmask = data;
 		break;
 #endif
 	case MSR_IA32_SYSENTER_CS:
-		svm->vmcb->save.sysenter_cs = data;
+		svm->vmcb01.ptr->save.sysenter_cs = data;
 		break;
 	case MSR_IA32_SYSENTER_EIP:
-		svm->sysenter_eip = data;
-		svm->vmcb->save.sysenter_eip = data;
+		svm->vmcb01.ptr->save.sysenter_eip = (u32)data;
+		/*
+		 * We only intercept the MSR_IA32_SYSENTER_{EIP|ESP} msrs
+		 * when we spoof an Intel vendor ID (for cross vendor migration).
+		 * In this case we use this intercept to track the high
+		 * 32 bit part of these msrs to support Intel's
+		 * implementation of SYSENTER/SYSEXIT.
+		 */
+		svm->sysenter_eip_hi = guest_cpuid_is_intel(vcpu) ? (data >> 32) : 0;
 		break;
 	case MSR_IA32_SYSENTER_ESP:
-		svm->sysenter_esp = data;
-		svm->vmcb->save.sysenter_esp = data;
+		svm->vmcb01.ptr->save.sysenter_esp = (u32)data;
+		svm->sysenter_esp_hi = guest_cpuid_is_intel(vcpu) ? (data >> 32) : 0;
 		break;
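
For guests that advertise an Intel vendor ID, the cases above keep only the low 32 bits of SYSENTER_EIP/ESP in the VMCB (which is all AMD hardware honors) and shadow the high half in the vcpu; the read path earlier in this patch recombines them. A small sketch of that split, with illustrative field names rather than the actual struct layout:

#include <stdbool.h>
#include <stdio.h>
#include <stdint.h>

struct sysenter_shadow {
	uint32_t vmcb_lo;	/* what the VMCB save area actually holds */
	uint32_t hi;		/* shadowed high half for Intel-style guests */
};

static void sysenter_write(struct sysenter_shadow *s, uint64_t data,
			   bool guest_is_intel)
{
	s->vmcb_lo = (uint32_t)data;
	s->hi = guest_is_intel ? (uint32_t)(data >> 32) : 0;
}

static uint64_t sysenter_read(const struct sysenter_shadow *s,
			      bool guest_is_intel)
{
	uint64_t data = s->vmcb_lo;

	if (guest_is_intel)
		data |= (uint64_t)s->hi << 32;
	return data;
}

int main(void)
{
	struct sysenter_shadow s = { 0 };

	sysenter_write(&s, 0xffffffff81000000ULL, true);
	printf("read back: %#llx\n",
	       (unsigned long long)sysenter_read(&s, true));
	return 0;
}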
 	case MSR_TSC_AUX:
 		if (!boot_cpu_has(X86_FEATURE_RDTSCP))
 			return 1;
 
+		if (!msr->host_initiated &&
+		    !guest_cpuid_has(vcpu, X86_FEATURE_RDTSCP))
+			return 1;
+
+		/*
+		 * Per Intel's SDM, bits 63:32 are reserved, but AMD's APM has
+		 * incomplete and conflicting architectural behavior.  Current
+		 * AMD CPUs completely ignore bits 63:32, i.e. they aren't
+		 * reserved and always read as zeros.  Emulate AMD CPU behavior
+		 * to avoid explosions if the vCPU is migrated from an AMD host
+		 * to an Intel host.
+		 */
+		data = (u32)data;
+
 		/*
-		 * This is rare, so we update the MSR here instead of using
-		 * direct_access_msrs.  Doing that would require a rdmsr in
-		 * svm_vcpu_put.
+		 * TSC_AUX is usually changed only during boot and never read
+		 * directly.  Intercept TSC_AUX instead of exposing it to the
+		 * guest via direct_access_msrs, and switch it via user return.
 		 */
+		preempt_disable();
+		r = kvm_set_user_return_msr(TSC_AUX_URET_SLOT, data, -1ull);
+		preempt_enable();
+		if (r)
+			return 1;
+
 		svm->tsc_aux = data;
-		wrmsrl(MSR_TSC_AUX, svm->tsc_aux);
 		break;
 	case MSR_IA32_DEBUGCTLMSR:
 		if (!boot_cpu_has(X86_FEATURE_LBRV)) {
@@ -3006,38 +2958,32 @@ static int svm_set_msr(struct kvm_vcpu *vcpu, struct msr_data *msr)
 	return 0;
 }
 
-static int wrmsr_interception(struct vcpu_svm *svm)
-{
-	return kvm_emulate_wrmsr(&svm->vcpu);
-}
-
-static int msr_interception(struct vcpu_svm *svm)
+static int msr_interception(struct kvm_vcpu *vcpu)
 {
-	if (svm->vmcb->control.exit_info_1)
-		return wrmsr_interception(svm);
+	if (to_svm(vcpu)->vmcb->control.exit_info_1)
+		return kvm_emulate_wrmsr(vcpu);
 	else
-		return rdmsr_interception(svm);
+		return kvm_emulate_rdmsr(vcpu);
 }
 
-static int interrupt_window_interception(struct vcpu_svm *svm)
+static int interrupt_window_interception(struct kvm_vcpu *vcpu)
 {
-	kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
-	svm_clear_vintr(svm);
+	kvm_make_request(KVM_REQ_EVENT, vcpu);
+	svm_clear_vintr(to_svm(vcpu));
 
 	/*
 	 * For AVIC, the only reason to end up here is ExtINTs.
 	 * In this case AVIC was temporarily disabled for
 	 * requesting the IRQ window and we have to re-enable it.
 	 */
-	svm_toggle_avic_for_irq_window(&svm->vcpu, true);
+	svm_toggle_avic_for_irq_window(vcpu, true);
 
-	++svm->vcpu.stat.irq_window_exits;
+	++vcpu->stat.irq_window_exits;
 	return 1;
 }
 
-static int pause_interception(struct vcpu_svm *svm)
+static int pause_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
 	bool in_kernel;
 
 	/*
@@ -3045,35 +2991,18 @@ static int pause_interception(struct vcpu_svm *svm)
 	 * vcpu->arch.preempted_in_kernel can never be true.  Just
 	 * set in_kernel to false as well.
 	 */
-	in_kernel = !sev_es_guest(svm->vcpu.kvm) && svm_get_cpl(vcpu) == 0;
+	in_kernel = !sev_es_guest(vcpu->kvm) && svm_get_cpl(vcpu) == 0;
 
 	if (!kvm_pause_in_guest(vcpu->kvm))
 		grow_ple_window(vcpu);
 
 	kvm_vcpu_on_spin(vcpu, in_kernel);
-	return 1;
-}
-
-static int nop_interception(struct vcpu_svm *svm)
-{
-	return kvm_skip_emulated_instruction(&(svm->vcpu));
+	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int monitor_interception(struct vcpu_svm *svm)
+static int invpcid_interception(struct kvm_vcpu *vcpu)
 {
-	printk_once(KERN_WARNING "kvm: MONITOR instruction emulated as NOP!\n");
-	return nop_interception(svm);
-}
-
-static int mwait_interception(struct vcpu_svm *svm)
-{
-	printk_once(KERN_WARNING "kvm: MWAIT instruction emulated as NOP!\n");
-	return nop_interception(svm);
-}
-
-static int invpcid_interception(struct vcpu_svm *svm)
-{
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	unsigned long type;
 	gva_t gva;
 
@@ -3098,7 +3027,7 @@ static int invpcid_interception(struct vcpu_svm *svm)
 	return kvm_handle_invpcid(vcpu, type, gva);
 }
 
-static int (*const svm_exit_handlers[])(struct vcpu_svm *svm) = {
+static int (*const svm_exit_handlers[])(struct kvm_vcpu *vcpu) = {
 	[SVM_EXIT_READ_CR0]			= cr_interception,
 	[SVM_EXIT_READ_CR3]			= cr_interception,
 	[SVM_EXIT_READ_CR4]			= cr_interception,
@@ -3133,15 +3062,15 @@ static int (*const svm_exit_handlers[])(struct vcpu_svm *svm) = {
 	[SVM_EXIT_EXCP_BASE + GP_VECTOR]	= gp_interception,
 	[SVM_EXIT_INTR]				= intr_interception,
 	[SVM_EXIT_NMI]				= nmi_interception,
-	[SVM_EXIT_SMI]				= nop_on_interception,
-	[SVM_EXIT_INIT]				= nop_on_interception,
+	[SVM_EXIT_SMI]				= kvm_emulate_as_nop,
+	[SVM_EXIT_INIT]				= kvm_emulate_as_nop,
 	[SVM_EXIT_VINTR]			= interrupt_window_interception,
-	[SVM_EXIT_RDPMC]			= rdpmc_interception,
-	[SVM_EXIT_CPUID]			= cpuid_interception,
+	[SVM_EXIT_RDPMC]			= kvm_emulate_rdpmc,
+	[SVM_EXIT_CPUID]			= kvm_emulate_cpuid,
 	[SVM_EXIT_IRET]                         = iret_interception,
-	[SVM_EXIT_INVD]                         = invd_interception,
+	[SVM_EXIT_INVD]                         = kvm_emulate_invd,
 	[SVM_EXIT_PAUSE]			= pause_interception,
-	[SVM_EXIT_HLT]				= halt_interception,
+	[SVM_EXIT_HLT]				= kvm_emulate_halt,
 	[SVM_EXIT_INVLPG]			= invlpg_interception,
 	[SVM_EXIT_INVLPGA]			= invlpga_interception,
 	[SVM_EXIT_IOIO]				= io_interception,
@@ -3149,17 +3078,17 @@ static int (*const svm_exit_handlers[])(struct vcpu_svm *svm) = {
 	[SVM_EXIT_TASK_SWITCH]			= task_switch_interception,
 	[SVM_EXIT_SHUTDOWN]			= shutdown_interception,
 	[SVM_EXIT_VMRUN]			= vmrun_interception,
-	[SVM_EXIT_VMMCALL]			= vmmcall_interception,
+	[SVM_EXIT_VMMCALL]			= kvm_emulate_hypercall,
 	[SVM_EXIT_VMLOAD]			= vmload_interception,
 	[SVM_EXIT_VMSAVE]			= vmsave_interception,
 	[SVM_EXIT_STGI]				= stgi_interception,
 	[SVM_EXIT_CLGI]				= clgi_interception,
 	[SVM_EXIT_SKINIT]			= skinit_interception,
-	[SVM_EXIT_WBINVD]                       = wbinvd_interception,
-	[SVM_EXIT_MONITOR]			= monitor_interception,
-	[SVM_EXIT_MWAIT]			= mwait_interception,
-	[SVM_EXIT_XSETBV]			= xsetbv_interception,
-	[SVM_EXIT_RDPRU]			= rdpru_interception,
+	[SVM_EXIT_WBINVD]                       = kvm_emulate_wbinvd,
+	[SVM_EXIT_MONITOR]			= kvm_emulate_monitor,
+	[SVM_EXIT_MWAIT]			= kvm_emulate_mwait,
+	[SVM_EXIT_XSETBV]			= kvm_emulate_xsetbv,
+	[SVM_EXIT_RDPRU]			= kvm_handle_invalid_op,
 	[SVM_EXIT_EFER_WRITE_TRAP]		= efer_trap,
 	[SVM_EXIT_CR0_WRITE_TRAP]		= cr_trap,
 	[SVM_EXIT_CR4_WRITE_TRAP]		= cr_trap,
@@ -3177,6 +3106,7 @@ static void dump_vmcb(struct kvm_vcpu *vcpu)
 	struct vcpu_svm *svm = to_svm(vcpu);
 	struct vmcb_control_area *control = &svm->vmcb->control;
 	struct vmcb_save_area *save = &svm->vmcb->save;
+	struct vmcb_save_area *save01 = &svm->vmcb01.ptr->save;
 
 	if (!dump_invalid_vmcb) {
 		pr_warn_ratelimited("set kvm_amd.dump_invalid_vmcb=1 to dump internal KVM state.\n");
@@ -3239,28 +3169,28 @@ static void dump_vmcb(struct kvm_vcpu *vcpu)
 	       save->ds.limit, save->ds.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "fs:",
-	       save->fs.selector, save->fs.attrib,
-	       save->fs.limit, save->fs.base);
+	       save01->fs.selector, save01->fs.attrib,
+	       save01->fs.limit, save01->fs.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "gs:",
-	       save->gs.selector, save->gs.attrib,
-	       save->gs.limit, save->gs.base);
+	       save01->gs.selector, save01->gs.attrib,
+	       save01->gs.limit, save01->gs.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "gdtr:",
 	       save->gdtr.selector, save->gdtr.attrib,
 	       save->gdtr.limit, save->gdtr.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "ldtr:",
-	       save->ldtr.selector, save->ldtr.attrib,
-	       save->ldtr.limit, save->ldtr.base);
+	       save01->ldtr.selector, save01->ldtr.attrib,
+	       save01->ldtr.limit, save01->ldtr.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "idtr:",
 	       save->idtr.selector, save->idtr.attrib,
 	       save->idtr.limit, save->idtr.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "tr:",
-	       save->tr.selector, save->tr.attrib,
-	       save->tr.limit, save->tr.base);
+	       save01->tr.selector, save01->tr.attrib,
+	       save01->tr.limit, save01->tr.base);
 	pr_err("cpl:            %d                efer:         %016llx\n",
 		save->cpl, save->efer);
 	pr_err("%-15s %016llx %-13s %016llx\n",
@@ -3274,15 +3204,15 @@ static void dump_vmcb(struct kvm_vcpu *vcpu)
 	pr_err("%-15s %016llx %-13s %016llx\n",
 	       "rsp:", save->rsp, "rax:", save->rax);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "star:", save->star, "lstar:", save->lstar);
+	       "star:", save01->star, "lstar:", save01->lstar);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "cstar:", save->cstar, "sfmask:", save->sfmask);
+	       "cstar:", save01->cstar, "sfmask:", save01->sfmask);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "kernel_gs_base:", save->kernel_gs_base,
-	       "sysenter_cs:", save->sysenter_cs);
+	       "kernel_gs_base:", save01->kernel_gs_base,
+	       "sysenter_cs:", save01->sysenter_cs);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "sysenter_esp:", save->sysenter_esp,
-	       "sysenter_eip:", save->sysenter_eip);
+	       "sysenter_esp:", save01->sysenter_esp,
+	       "sysenter_eip:", save01->sysenter_eip);
 	pr_err("%-15s %016llx %-13s %016llx\n",
 	       "gpat:", save->g_pat, "dbgctl:", save->dbgctl);
 	pr_err("%-15s %016llx %-13s %016llx\n",
@@ -3309,24 +3239,24 @@ static int svm_handle_invalid_exit(struct kvm_vcpu *vcpu, u64 exit_code)
 	return -EINVAL;
 }
 
-int svm_invoke_exit_handler(struct vcpu_svm *svm, u64 exit_code)
+int svm_invoke_exit_handler(struct kvm_vcpu *vcpu, u64 exit_code)
 {
-	if (svm_handle_invalid_exit(&svm->vcpu, exit_code))
+	if (svm_handle_invalid_exit(vcpu, exit_code))
 		return 0;
 
 #ifdef CONFIG_RETPOLINE
 	if (exit_code == SVM_EXIT_MSR)
-		return msr_interception(svm);
+		return msr_interception(vcpu);
 	else if (exit_code == SVM_EXIT_VINTR)
-		return interrupt_window_interception(svm);
+		return interrupt_window_interception(vcpu);
 	else if (exit_code == SVM_EXIT_INTR)
-		return intr_interception(svm);
+		return intr_interception(vcpu);
 	else if (exit_code == SVM_EXIT_HLT)
-		return halt_interception(svm);
+		return kvm_emulate_halt(vcpu);
 	else if (exit_code == SVM_EXIT_NPF)
-		return npf_interception(svm);
+		return npf_interception(vcpu);
 #endif
-	return svm_exit_handlers[exit_code](svm);
+	return svm_exit_handlers[exit_code](vcpu);
 }
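
The CONFIG_RETPOLINE block above special-cases the hottest exit codes with direct calls so they bypass the retpoline-protected indirect jump through svm_exit_handlers[]. A minimal illustration of the pattern, with hypothetical handlers and a stand-in config macro:

#include <stdio.h>

enum exit_code { EXIT_MSR, EXIT_INTR, EXIT_NPF, EXIT_OTHER, NR_EXITS };

static int handle_msr(void)   { return 1; }
static int handle_intr(void)  { return 1; }
static int handle_npf(void)   { return 1; }
static int handle_other(void) { return 1; }

static int (*const handlers[NR_EXITS])(void) = {
	[EXIT_MSR]   = handle_msr,
	[EXIT_INTR]  = handle_intr,
	[EXIT_NPF]   = handle_npf,
	[EXIT_OTHER] = handle_other,
};

static int invoke(enum exit_code code)
{
#ifdef USE_DIRECT_DISPATCH		/* analogue of CONFIG_RETPOLINE */
	/* Direct calls: predictable branches instead of an indirect jump. */
	if (code == EXIT_MSR)
		return handle_msr();
	if (code == EXIT_INTR)
		return handle_intr();
	if (code == EXIT_NPF)
		return handle_npf();
#endif
	return handlers[code]();	/* generic path: indirect call */
}

int main(void)
{
	printf("%d\n", invoke(EXIT_MSR));
	return 0;
}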
 
 static void svm_get_exit_info(struct kvm_vcpu *vcpu, u64 *info1, u64 *info2,
@@ -3395,7 +3325,7 @@ static int handle_exit(struct kvm_vcpu *vcpu, fastpath_t exit_fastpath)
 	if (exit_fastpath != EXIT_FASTPATH_NONE)
 		return 1;
 
-	return svm_invoke_exit_handler(svm, exit_code);
+	return svm_invoke_exit_handler(vcpu, exit_code);
 }
 
 static void reload_tss(struct kvm_vcpu *vcpu)
@@ -3406,15 +3336,27 @@ static void reload_tss(struct kvm_vcpu *vcpu)
 	load_TR_desc();
 }
 
-static void pre_svm_run(struct vcpu_svm *svm)
+static void pre_svm_run(struct kvm_vcpu *vcpu)
 {
-	struct svm_cpu_data *sd = per_cpu(svm_data, svm->vcpu.cpu);
+	struct svm_cpu_data *sd = per_cpu(svm_data, vcpu->cpu);
+	struct vcpu_svm *svm = to_svm(vcpu);
 
-	if (sev_guest(svm->vcpu.kvm))
-		return pre_sev_run(svm, svm->vcpu.cpu);
+	/*
+	 * If the previous vmrun of the vmcb occurred on a different physical
+	 * cpu, then mark the vmcb dirty and assign a new asid.  Hardware's
+	 * vmcb clean bits are per logical CPU, as are KVM's asid assignments.
+	 */
+	if (unlikely(svm->current_vmcb->cpu != vcpu->cpu)) {
+		svm->current_vmcb->asid_generation = 0;
+		vmcb_mark_all_dirty(svm->vmcb);
+		svm->current_vmcb->cpu = vcpu->cpu;
+	}
+
+	if (sev_guest(vcpu->kvm))
+		return pre_sev_run(svm, vcpu->cpu);
 
 	/* FIXME: handle wraparound of asid_generation */
-	if (svm->asid_generation != sd->asid_generation)
+	if (svm->current_vmcb->asid_generation != sd->asid_generation)
 		new_asid(svm, sd);
 }
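
pre_svm_run() above ties the ASID and clean-bits state to the physical CPU a vmcb last ran on: a CPU change resets that vmcb's asid_generation and dirties everything, and new_asid() later hands out a fresh ASID whenever the generations disagree. A stand-alone sketch of that generation scheme, with hypothetical per-CPU and per-vmcb structs (cf. svm_cpu_data and kvm_vmcb_info):

#include <stdio.h>

/* Hypothetical per-CPU ASID allocator state. */
struct cpu_asids {
	unsigned int generation;	/* bumped whenever the ASID space wraps */
	unsigned int next_asid;
	unsigned int max_asid;
};

/* Hypothetical per-VMCB tracking. */
struct vmcb_track {
	unsigned int generation;	/* generation this vmcb's ASID came from */
	unsigned int asid;
	int cpu;			/* CPU of the previous VMRUN */
};

static void new_asid(struct vmcb_track *v, struct cpu_asids *c)
{
	if (c->next_asid > c->max_asid) {
		/* Wrapped: start a new generation (the real code also flushes). */
		c->generation++;
		c->next_asid = 1;
	}
	v->generation = c->generation;
	v->asid = c->next_asid++;
}

static void pre_run(struct vmcb_track *v, struct cpu_asids *c, int cpu)
{
	if (v->cpu != cpu) {		/* migrated: old ASID means nothing here */
		v->generation = 0;
		v->cpu = cpu;
	}
	if (v->generation != c->generation)
		new_asid(v, c);
}

int main(void)
{
	struct cpu_asids cpu0 = { .generation = 1, .next_asid = 1, .max_asid = 7 };
	struct vmcb_track vmcb = { .cpu = -1 };

	pre_run(&vmcb, &cpu0, 0);
	printf("asid %u, generation %u\n", vmcb.asid, vmcb.generation);
	return 0;
}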
 
@@ -3424,7 +3366,7 @@ static void svm_inject_nmi(struct kvm_vcpu *vcpu)
 
 	svm->vmcb->control.event_inj = SVM_EVTINJ_VALID | SVM_EVTINJ_TYPE_NMI;
 	vcpu->arch.hflags |= HF_NMI_MASK;
-	if (!sev_es_guest(svm->vcpu.kvm))
+	if (!sev_es_guest(vcpu->kvm))
 		svm_set_intercept(svm, INTERCEPT_IRET);
 	++vcpu->stat.nmi_injections;
 }
@@ -3478,7 +3420,7 @@ bool svm_nmi_blocked(struct kvm_vcpu *vcpu)
 		return false;
 
 	ret = (vmcb->control.int_state & SVM_INTERRUPT_SHADOW_MASK) ||
-	      (svm->vcpu.arch.hflags & HF_NMI_MASK);
+	      (vcpu->arch.hflags & HF_NMI_MASK);
 
 	return ret;
 }
@@ -3498,9 +3440,7 @@ static int svm_nmi_allowed(struct kvm_vcpu *vcpu, bool for_injection)
 
 static bool svm_get_nmi_mask(struct kvm_vcpu *vcpu)
 {
-	struct vcpu_svm *svm = to_svm(vcpu);
-
-	return !!(svm->vcpu.arch.hflags & HF_NMI_MASK);
+	return !!(vcpu->arch.hflags & HF_NMI_MASK);
 }
 
 static void svm_set_nmi_mask(struct kvm_vcpu *vcpu, bool masked)
@@ -3508,12 +3448,12 @@ static void svm_set_nmi_mask(struct kvm_vcpu *vcpu, bool masked)
 	struct vcpu_svm *svm = to_svm(vcpu);
 
 	if (masked) {
-		svm->vcpu.arch.hflags |= HF_NMI_MASK;
-		if (!sev_es_guest(svm->vcpu.kvm))
+		vcpu->arch.hflags |= HF_NMI_MASK;
+		if (!sev_es_guest(vcpu->kvm))
 			svm_set_intercept(svm, INTERCEPT_IRET);
 	} else {
-		svm->vcpu.arch.hflags &= ~HF_NMI_MASK;
-		if (!sev_es_guest(svm->vcpu.kvm))
+		vcpu->arch.hflags &= ~HF_NMI_MASK;
+		if (!sev_es_guest(vcpu->kvm))
 			svm_clr_intercept(svm, INTERCEPT_IRET);
 	}
 }
@@ -3526,7 +3466,7 @@ bool svm_interrupt_blocked(struct kvm_vcpu *vcpu)
 	if (!gif_set(svm))
 		return true;
 
-	if (sev_es_guest(svm->vcpu.kvm)) {
+	if (sev_es_guest(vcpu->kvm)) {
 		/*
 		 * SEV-ES guests do not expose RFLAGS. Use the VMCB interrupt mask
 		 * bit to determine the state of the IF flag.
@@ -3536,7 +3476,7 @@ bool svm_interrupt_blocked(struct kvm_vcpu *vcpu)
 	} else if (is_guest_mode(vcpu)) {
 		/* As long as interrupts are being delivered...  */
 		if ((svm->nested.ctl.int_ctl & V_INTR_MASKING_MASK)
-		    ? !(svm->nested.hsave->save.rflags & X86_EFLAGS_IF)
+		    ? !(svm->vmcb01.ptr->save.rflags & X86_EFLAGS_IF)
 		    : !(kvm_get_rflags(vcpu) & X86_EFLAGS_IF))
 			return true;
 
@@ -3595,8 +3535,7 @@ static void svm_enable_nmi_window(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 
-	if ((svm->vcpu.arch.hflags & (HF_NMI_MASK | HF_IRET_MASK))
-	    == HF_NMI_MASK)
+	if ((vcpu->arch.hflags & (HF_NMI_MASK | HF_IRET_MASK)) == HF_NMI_MASK)
 		return; /* IRET will cause a vm exit */
 
 	if (!gif_set(svm)) {
@@ -3638,7 +3577,7 @@ void svm_flush_tlb(struct kvm_vcpu *vcpu)
 	if (static_cpu_has(X86_FEATURE_FLUSHBYASID))
 		svm->vmcb->control.tlb_ctl = TLB_CONTROL_FLUSH_ASID;
 	else
-		svm->asid_generation--;
+		svm->current_vmcb->asid_generation--;
 }
 
 static void svm_flush_tlb_gva(struct kvm_vcpu *vcpu, gva_t gva)
@@ -3675,8 +3614,9 @@ static inline void sync_lapic_to_cr8(struct kvm_vcpu *vcpu)
 	svm->vmcb->control.int_ctl |= cr8 & V_TPR_MASK;
 }
 
-static void svm_complete_interrupts(struct vcpu_svm *svm)
+static void svm_complete_interrupts(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u8 vector;
 	int type;
 	u32 exitintinfo = svm->vmcb->control.exit_int_info;
@@ -3688,28 +3628,28 @@ static void svm_complete_interrupts(struct vcpu_svm *svm)
 	 * If we've made progress since setting HF_IRET_MASK, we've
 	 * executed an IRET and can allow NMI injection.
 	 */
-	if ((svm->vcpu.arch.hflags & HF_IRET_MASK) &&
-	    (sev_es_guest(svm->vcpu.kvm) ||
-	     kvm_rip_read(&svm->vcpu) != svm->nmi_iret_rip)) {
-		svm->vcpu.arch.hflags &= ~(HF_NMI_MASK | HF_IRET_MASK);
-		kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
+	if ((vcpu->arch.hflags & HF_IRET_MASK) &&
+	    (sev_es_guest(vcpu->kvm) ||
+	     kvm_rip_read(vcpu) != svm->nmi_iret_rip)) {
+		vcpu->arch.hflags &= ~(HF_NMI_MASK | HF_IRET_MASK);
+		kvm_make_request(KVM_REQ_EVENT, vcpu);
 	}
 
-	svm->vcpu.arch.nmi_injected = false;
-	kvm_clear_exception_queue(&svm->vcpu);
-	kvm_clear_interrupt_queue(&svm->vcpu);
+	vcpu->arch.nmi_injected = false;
+	kvm_clear_exception_queue(vcpu);
+	kvm_clear_interrupt_queue(vcpu);
 
 	if (!(exitintinfo & SVM_EXITINTINFO_VALID))
 		return;
 
-	kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
+	kvm_make_request(KVM_REQ_EVENT, vcpu);
 
 	vector = exitintinfo & SVM_EXITINTINFO_VEC_MASK;
 	type = exitintinfo & SVM_EXITINTINFO_TYPE_MASK;
 
 	switch (type) {
 	case SVM_EXITINTINFO_TYPE_NMI:
-		svm->vcpu.arch.nmi_injected = true;
+		vcpu->arch.nmi_injected = true;
 		break;
 	case SVM_EXITINTINFO_TYPE_EXEPT:
 		/*
@@ -3725,21 +3665,20 @@ static void svm_complete_interrupts(struct vcpu_svm *svm)
 		 */
 		if (kvm_exception_is_soft(vector)) {
 			if (vector == BP_VECTOR && int3_injected &&
-			    kvm_is_linear_rip(&svm->vcpu, svm->int3_rip))
-				kvm_rip_write(&svm->vcpu,
-					      kvm_rip_read(&svm->vcpu) -
-					      int3_injected);
+			    kvm_is_linear_rip(vcpu, svm->int3_rip))
+				kvm_rip_write(vcpu,
+					      kvm_rip_read(vcpu) - int3_injected);
 			break;
 		}
 		if (exitintinfo & SVM_EXITINTINFO_VALID_ERR) {
 			u32 err = svm->vmcb->control.exit_int_info_err;
-			kvm_requeue_exception_e(&svm->vcpu, vector, err);
+			kvm_requeue_exception_e(vcpu, vector, err);
 
 		} else
-			kvm_requeue_exception(&svm->vcpu, vector);
+			kvm_requeue_exception(vcpu, vector);
 		break;
 	case SVM_EXITINTINFO_TYPE_INTR:
-		kvm_queue_interrupt(&svm->vcpu, vector, false);
+		kvm_queue_interrupt(vcpu, vector, false);
 		break;
 	default:
 		break;
@@ -3754,7 +3693,7 @@ static void svm_cancel_injection(struct kvm_vcpu *vcpu)
 	control->exit_int_info = control->event_inj;
 	control->exit_int_info_err = control->event_inj_err;
 	control->event_inj = 0;
-	svm_complete_interrupts(svm);
+	svm_complete_interrupts(vcpu);
 }
 
 static fastpath_t svm_exit_handlers_fastpath(struct kvm_vcpu *vcpu)
@@ -3766,9 +3705,11 @@ static fastpath_t svm_exit_handlers_fastpath(struct kvm_vcpu *vcpu)
 	return EXIT_FASTPATH_NONE;
 }
 
-static noinstr void svm_vcpu_enter_exit(struct kvm_vcpu *vcpu,
-					struct vcpu_svm *svm)
+static noinstr void svm_vcpu_enter_exit(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
+	unsigned long vmcb_pa = svm->current_vmcb->pa;
+
 	/*
 	 * VMENTER enables interrupts (host state), but the kernel state is
 	 * interrupts disabled when this is invoked. Also tell RCU about
@@ -3789,12 +3730,20 @@ static noinstr void svm_vcpu_enter_exit(struct kvm_vcpu *vcpu,
 	guest_enter_irqoff();
 	lockdep_hardirqs_on(CALLER_ADDR0);
 
-	if (sev_es_guest(svm->vcpu.kvm)) {
-		__svm_sev_es_vcpu_run(svm->vmcb_pa);
+	if (sev_es_guest(vcpu->kvm)) {
+		__svm_sev_es_vcpu_run(vmcb_pa);
 	} else {
 		struct svm_cpu_data *sd = per_cpu(svm_data, vcpu->cpu);
 
-		__svm_vcpu_run(svm->vmcb_pa, (unsigned long *)&svm->vcpu.arch.regs);
+		/*
+		 * Use a single vmcb (vmcb01 because it's always valid) for
+		 * context switching guest state via VMLOAD/VMSAVE, that way
+		 * the state doesn't need to be copied between vmcb01 and
+		 * vmcb02 when switching vmcbs for nested virtualization.
+		 */
+		vmload(svm->vmcb01.pa);
+		__svm_vcpu_run(vmcb_pa, (unsigned long *)&vcpu->arch.regs);
+		vmsave(svm->vmcb01.pa);
 
 		vmload(__sme_page_pa(sd->save_area));
 	}
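The vmload()/vmsave() calls above come from asm/svm.h; conceptually they are thin inline-asm wrappers along the lines of the sketch below (the in-tree versions additionally use the exception table to tolerate faults during reboot, much like the vmenter.S fixups later in this diff).

static __always_inline void vmload(unsigned long pa)
{
	/* VMLOAD takes the VMCB physical address implicitly in rAX. */
	asm volatile("vmload %0" : : "a" (pa) : "memory");
}

static __always_inline void vmsave(unsigned long pa)
{
	/* VMSAVE saves additional state into the VMCB addressed by rAX. */
	asm volatile("vmsave %0" : : "a" (pa) : "memory");
}
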
@@ -3845,7 +3794,7 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 		smp_send_reschedule(vcpu->cpu);
 	}
 
-	pre_svm_run(svm);
+	pre_svm_run(vcpu);
 
 	sync_lapic_to_cr8(vcpu);
 
@@ -3859,7 +3808,7 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	 * Run with all-zero DR6 unless needed, so that we can get the exact cause
 	 * of a #DB.
 	 */
-	if (unlikely(svm->vcpu.arch.switch_db_regs & KVM_DEBUGREG_WONT_EXIT))
+	if (unlikely(vcpu->arch.switch_db_regs & KVM_DEBUGREG_WONT_EXIT))
 		svm_set_dr6(svm, vcpu->arch.dr6);
 	else
 		svm_set_dr6(svm, DR6_ACTIVE_LOW);
@@ -3875,9 +3824,10 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	 * is no need to worry about the conditional branch over the wrmsr
 	 * being speculatively taken.
 	 */
-	x86_spec_ctrl_set_guest(svm->spec_ctrl, svm->virt_spec_ctrl);
+	if (!static_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+		x86_spec_ctrl_set_guest(svm->spec_ctrl, svm->virt_spec_ctrl);
 
-	svm_vcpu_enter_exit(vcpu, svm);
+	svm_vcpu_enter_exit(vcpu);
 
 	/*
 	 * We do not use IBRS in the kernel. If this vCPU has used the
@@ -3894,15 +3844,17 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	 * If the L02 MSR bitmap does not intercept the MSR, then we need to
 	 * save it.
 	 */
-	if (unlikely(!msr_write_intercepted(vcpu, MSR_IA32_SPEC_CTRL)))
+	if (!static_cpu_has(X86_FEATURE_V_SPEC_CTRL) &&
+	    unlikely(!msr_write_intercepted(vcpu, MSR_IA32_SPEC_CTRL)))
 		svm->spec_ctrl = native_read_msr(MSR_IA32_SPEC_CTRL);
 
-	if (!sev_es_guest(svm->vcpu.kvm))
+	if (!sev_es_guest(vcpu->kvm))
 		reload_tss(vcpu);
 
-	x86_spec_ctrl_restore_host(svm->spec_ctrl, svm->virt_spec_ctrl);
+	if (!static_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+		x86_spec_ctrl_restore_host(svm->spec_ctrl, svm->virt_spec_ctrl);
 
-	if (!sev_es_guest(svm->vcpu.kvm)) {
+	if (!sev_es_guest(vcpu->kvm)) {
 		vcpu->arch.cr2 = svm->vmcb->save.cr2;
 		vcpu->arch.regs[VCPU_REGS_RAX] = svm->vmcb->save.rax;
 		vcpu->arch.regs[VCPU_REGS_RSP] = svm->vmcb->save.rsp;
@@ -3910,7 +3862,7 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	}
 
 	if (unlikely(svm->vmcb->control.exit_code == SVM_EXIT_NMI))
-		kvm_before_interrupt(&svm->vcpu);
+		kvm_before_interrupt(vcpu);
 
 	kvm_load_host_xsave_state(vcpu);
 	stgi();
@@ -3918,13 +3870,13 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	/* Any pending NMI will happen here */
 
 	if (unlikely(svm->vmcb->control.exit_code == SVM_EXIT_NMI))
-		kvm_after_interrupt(&svm->vcpu);
+		kvm_after_interrupt(vcpu);
 
 	sync_cr8_to_lapic(vcpu);
 
 	svm->next_rip = 0;
-	if (is_guest_mode(&svm->vcpu)) {
-		sync_nested_vmcb_control(svm);
+	if (is_guest_mode(vcpu)) {
+		nested_sync_control_from_vmcb02(svm);
 		svm->nested.nested_run_pending = 0;
 	}
 
@@ -3933,7 +3885,7 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 
 	/* if exit due to PF check for async PF */
 	if (svm->vmcb->control.exit_code == SVM_EXIT_EXCP_BASE + PF_VECTOR)
-		svm->vcpu.arch.apf.host_apf_flags =
+		vcpu->arch.apf.host_apf_flags =
 			kvm_read_and_reset_apf_flags();
 
 	if (npt_enabled) {
@@ -3947,9 +3899,9 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	 */
 	if (unlikely(svm->vmcb->control.exit_code ==
 		     SVM_EXIT_EXCP_BASE + MC_VECTOR))
-		svm_handle_mce(svm);
+		svm_handle_mce(vcpu);
 
-	svm_complete_interrupts(svm);
+	svm_complete_interrupts(vcpu);
 
 	if (is_guest_mode(vcpu))
 		return EXIT_FASTPATH_NONE;
@@ -3957,21 +3909,26 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	return svm_exit_handlers_fastpath(vcpu);
 }
 
-static void svm_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long root,
+static void svm_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa,
 			     int root_level)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 	unsigned long cr3;
 
-	cr3 = __sme_set(root);
 	if (npt_enabled) {
-		svm->vmcb->control.nested_cr3 = cr3;
+		svm->vmcb->control.nested_cr3 = __sme_set(root_hpa);
 		vmcb_mark_dirty(svm->vmcb, VMCB_NPT);
 
 		/* Loading L2's CR3 is handled by enter_svm_guest_mode.  */
 		if (!test_bit(VCPU_EXREG_CR3, (ulong *)&vcpu->arch.regs_avail))
 			return;
 		cr3 = vcpu->arch.cr3;
+	} else if (vcpu->arch.mmu->shadow_root_level >= PT64_ROOT_4LEVEL) {
+		cr3 = __sme_set(root_hpa) | kvm_get_active_pcid(vcpu);
+	} else {
+		/* PCID in the guest should be impossible with a 32-bit MMU. */
+		WARN_ON_ONCE(kvm_get_active_pcid(vcpu));
+		cr3 = root_hpa;
 	}
 
 	svm->vmcb->save.cr3 = cr3;
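kvm_get_active_pcid() is added to mmu.h elsewhere in this series (see the diffstat); roughly, it returns the PCID bits of the guest's current CR3 when CR4.PCIDE is enabled, e.g.:

static inline unsigned long kvm_get_pcid(struct kvm_vcpu *vcpu, gpa_t cr3)
{
	BUILD_BUG_ON((X86_CR3_PCID_MASK & PAGE_MASK) != 0);

	return kvm_read_cr4_bits(vcpu, X86_CR4_PCIDE) ? cr3 & X86_CR3_PCID_MASK
						      : 0;
}

static inline unsigned long kvm_get_active_pcid(struct kvm_vcpu *vcpu)
{
	return kvm_get_pcid(vcpu, kvm_read_cr3(vcpu));
}
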
@@ -4048,7 +4005,7 @@ static void svm_vcpu_after_set_cpuid(struct kvm_vcpu *vcpu)
 
 	/* Update nrips enabled cache */
 	svm->nrips_enabled = kvm_cpu_cap_has(X86_FEATURE_NRIPS) &&
-			     guest_cpuid_has(&svm->vcpu, X86_FEATURE_NRIPS);
+			     guest_cpuid_has(vcpu, X86_FEATURE_NRIPS);
 
 	/* Check again if INVPCID interception is required */
 	svm_check_invpcid(svm);
@@ -4060,24 +4017,50 @@ static void svm_vcpu_after_set_cpuid(struct kvm_vcpu *vcpu)
 			vcpu->arch.reserved_gpa_bits &= ~(1UL << (best->ebx & 0x3f));
 	}
 
-	if (!kvm_vcpu_apicv_active(vcpu))
-		return;
+	if (kvm_vcpu_apicv_active(vcpu)) {
+		/*
+		 * AVIC does not work with an x2APIC mode guest. If the X2APIC feature
+		 * is exposed to the guest, disable AVIC.
+		 */
+		if (guest_cpuid_has(vcpu, X86_FEATURE_X2APIC))
+			kvm_request_apicv_update(vcpu->kvm, false,
+						 APICV_INHIBIT_REASON_X2APIC);
 
-	/*
-	 * AVIC does not work with an x2APIC mode guest. If the X2APIC feature
-	 * is exposed to the guest, disable AVIC.
-	 */
-	if (guest_cpuid_has(vcpu, X86_FEATURE_X2APIC))
-		kvm_request_apicv_update(vcpu->kvm, false,
-					 APICV_INHIBIT_REASON_X2APIC);
+		/*
+		 * Currently, AVIC does not work with nested virtualization.
+		 * So, we disable AVIC when cpuid for SVM is set in the L1 guest.
+		 */
+		if (nested && guest_cpuid_has(vcpu, X86_FEATURE_SVM))
+			kvm_request_apicv_update(vcpu->kvm, false,
+						 APICV_INHIBIT_REASON_NESTED);
+	}
 
-	/*
-	 * Currently, AVIC does not work with nested virtualization.
-	 * So, we disable AVIC when cpuid for SVM is set in the L1 guest.
-	 */
-	if (nested && guest_cpuid_has(vcpu, X86_FEATURE_SVM))
-		kvm_request_apicv_update(vcpu->kvm, false,
-					 APICV_INHIBIT_REASON_NESTED);
+	if (guest_cpuid_is_intel(vcpu)) {
+		/*
+		 * We must intercept SYSENTER_EIP and SYSENTER_ESP
+		 * accesses because the processor only stores 32 bits.
+		 * For the same reason we cannot use virtual VMLOAD/VMSAVE.
+		 */
+		svm_set_intercept(svm, INTERCEPT_VMLOAD);
+		svm_set_intercept(svm, INTERCEPT_VMSAVE);
+		svm->vmcb->control.virt_ext &= ~VIRTUAL_VMLOAD_VMSAVE_ENABLE_MASK;
+
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_EIP, 0, 0);
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_ESP, 0, 0);
+	} else {
+		/*
+		 * If hardware supports Virtual VMLOAD VMSAVE then enable it
+		 * in VMCB and clear intercepts to avoid #VMEXIT.
+		 */
+		if (vls) {
+			svm_clr_intercept(svm, INTERCEPT_VMLOAD);
+			svm_clr_intercept(svm, INTERCEPT_VMSAVE);
+			svm->vmcb->control.virt_ext |= VIRTUAL_VMLOAD_VMSAVE_ENABLE_MASK;
+		}
+		/* No need to intercept these MSRs */
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_EIP, 1, 1);
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_ESP, 1, 1);
+	}
 }
 
 static bool svm_has_wbinvd_exit(void)
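Because AMD hardware only stores the low 32 bits of the SYSENTER MSRs in the VMCB, the MSR read path for an Intel guest (elsewhere in svm.c, not shown in this excerpt) reassembles the full 64-bit value from the sysenter_*_hi fields added to struct vcpu_svm below. A hypothetical helper illustrating the idea, not the in-tree code:

static u64 svm_sysenter_eip_read(struct vcpu_svm *svm)
{
	u64 val = (u32)svm->vmcb01.ptr->save.sysenter_eip;

	/* Only an Intel guest sees (and can set) the high 32 bits. */
	if (guest_cpuid_is_intel(&svm->vcpu))
		val |= (u64)svm->sysenter_eip_hi << 32;

	return val;
}
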
@@ -4349,15 +4332,15 @@ static int svm_pre_leave_smm(struct kvm_vcpu *vcpu, const char *smstate)
 			if (!(saved_efer & EFER_SVME))
 				return 1;
 
-			if (kvm_vcpu_map(&svm->vcpu,
+			if (kvm_vcpu_map(vcpu,
 					 gpa_to_gfn(vmcb12_gpa), &map) == -EINVAL)
 				return 1;
 
 			if (svm_allocate_nested(svm))
 				return 1;
 
-			ret = enter_svm_guest_mode(svm, vmcb12_gpa, map.hva);
-			kvm_vcpu_unmap(&svm->vcpu, &map, true);
+			ret = enter_svm_guest_mode(vcpu, vmcb12_gpa, map.hva);
+			kvm_vcpu_unmap(vcpu, &map, true);
 		}
 	}
 
@@ -4612,6 +4595,8 @@ static struct kvm_x86_ops svm_x86_ops __initdata = {
 	.mem_enc_reg_region = svm_register_enc_region,
 	.mem_enc_unreg_region = svm_unregister_enc_region,
 
+	.vm_copy_enc_context_from = svm_vm_copy_asid_from,
+
 	.can_emulate_instruction = svm_can_emulate_instruction,
 
 	.apic_init_signal_blocked = svm_apic_init_signal_blocked,
diff --git a/arch/x86/kvm/svm/svm.h b/arch/x86/kvm/svm/svm.h
index 9806aaebc37f..84b3133c2251 100644
--- a/arch/x86/kvm/svm/svm.h
+++ b/arch/x86/kvm/svm/svm.h
@@ -23,12 +23,10 @@
 
 #define __sme_page_pa(x) __sme_set(page_to_pfn(x) << PAGE_SHIFT)
 
-static const u32 host_save_user_msrs[] = {
-	MSR_TSC_AUX,
-};
-#define NR_HOST_SAVE_USER_MSRS ARRAY_SIZE(host_save_user_msrs)
+#define	IOPM_SIZE PAGE_SIZE * 3
+#define	MSRPM_SIZE PAGE_SIZE * 2
 
-#define MAX_DIRECT_ACCESS_MSRS	18
+#define MAX_DIRECT_ACCESS_MSRS	20
 #define MSRPM_OFFSETS	16
 extern u32 msrpm_offsets[MSRPM_OFFSETS] __read_mostly;
 extern bool npt_enabled;
@@ -65,6 +63,7 @@ struct kvm_sev_info {
 	unsigned long pages_locked; /* Number of pages locked */
 	struct list_head regions_list;  /* List of registered regions */
 	u64 ap_jump_table;	/* SEV-ES AP Jump Table address */
+	struct kvm *enc_context_owner; /* Owner of copied encryption context */
 	struct misc_cg *misc_cg; /* For misc cgroup accounting */
 };
 
@@ -82,11 +81,19 @@ struct kvm_svm {
 
 struct kvm_vcpu;
 
+struct kvm_vmcb_info {
+	struct vmcb *ptr;
+	unsigned long pa;
+	int cpu;
+	uint64_t asid_generation;
+};
+
 struct svm_nested_state {
-	struct vmcb *hsave;
+	struct kvm_vmcb_info vmcb02;
 	u64 hsave_msr;
 	u64 vm_cr_msr;
 	u64 vmcb12_gpa;
+	u64 last_vmcb12_gpa;
 
 	/* These are the merged vectors */
 	u32 *msrpm;
@@ -103,21 +110,20 @@ struct svm_nested_state {
 
 struct vcpu_svm {
 	struct kvm_vcpu vcpu;
+	/* vmcb always points at current_vmcb->ptr, it's purely a shorthand. */
 	struct vmcb *vmcb;
-	unsigned long vmcb_pa;
+	struct kvm_vmcb_info vmcb01;
+	struct kvm_vmcb_info *current_vmcb;
 	struct svm_cpu_data *svm_data;
 	u32 asid;
-	uint64_t asid_generation;
-	uint64_t sysenter_esp;
-	uint64_t sysenter_eip;
+	u32 sysenter_esp_hi;
+	u32 sysenter_eip_hi;
 	uint64_t tsc_aux;
 
 	u64 msr_decfg;
 
 	u64 next_rip;
 
-	u64 host_user_msrs[NR_HOST_SAVE_USER_MSRS];
-
 	u64 spec_ctrl;
 	/*
 	 * Contains guest-controlled bits of VIRT_SPEC_CTRL, which will be
@@ -240,17 +246,14 @@ static inline void vmcb_mark_dirty(struct vmcb *vmcb, int bit)
 	vmcb->control.clean &= ~(1 << bit);
 }
 
-static inline struct vcpu_svm *to_svm(struct kvm_vcpu *vcpu)
+static inline bool vmcb_is_dirty(struct vmcb *vmcb, int bit)
 {
-	return container_of(vcpu, struct vcpu_svm, vcpu);
+	return !test_bit(bit, (unsigned long *)&vmcb->control.clean);
 }
 
-static inline struct vmcb *get_host_vmcb(struct vcpu_svm *svm)
+static inline struct vcpu_svm *to_svm(struct kvm_vcpu *vcpu)
 {
-	if (is_guest_mode(&svm->vcpu))
-		return svm->nested.hsave;
-	else
-		return svm->vmcb;
+	return container_of(vcpu, struct vcpu_svm, vcpu);
 }
 
 static inline void vmcb_set_intercept(struct vmcb_control_area *control, u32 bit)
@@ -273,7 +276,7 @@ static inline bool vmcb_is_intercept(struct vmcb_control_area *control, u32 bit)
 
 static inline void set_dr_intercepts(struct vcpu_svm *svm)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	if (!sev_es_guest(svm->vcpu.kvm)) {
 		vmcb_set_intercept(&vmcb->control, INTERCEPT_DR0_READ);
@@ -300,7 +303,7 @@ static inline void set_dr_intercepts(struct vcpu_svm *svm)
 
 static inline void clr_dr_intercepts(struct vcpu_svm *svm)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	vmcb->control.intercepts[INTERCEPT_DR] = 0;
 
@@ -315,7 +318,7 @@ static inline void clr_dr_intercepts(struct vcpu_svm *svm)
 
 static inline void set_exception_intercept(struct vcpu_svm *svm, u32 bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	WARN_ON_ONCE(bit >= 32);
 	vmcb_set_intercept(&vmcb->control, INTERCEPT_EXCEPTION_OFFSET + bit);
@@ -325,7 +328,7 @@ static inline void set_exception_intercept(struct vcpu_svm *svm, u32 bit)
 
 static inline void clr_exception_intercept(struct vcpu_svm *svm, u32 bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	WARN_ON_ONCE(bit >= 32);
 	vmcb_clr_intercept(&vmcb->control, INTERCEPT_EXCEPTION_OFFSET + bit);
@@ -335,7 +338,7 @@ static inline void clr_exception_intercept(struct vcpu_svm *svm, u32 bit)
 
 static inline void svm_set_intercept(struct vcpu_svm *svm, int bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	vmcb_set_intercept(&vmcb->control, bit);
 
@@ -344,7 +347,7 @@ static inline void svm_set_intercept(struct vcpu_svm *svm, int bit)
 
 static inline void svm_clr_intercept(struct vcpu_svm *svm, int bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	vmcb_clr_intercept(&vmcb->control, bit);
 
@@ -388,8 +391,6 @@ static inline bool gif_set(struct vcpu_svm *svm)
 /* svm.c */
 #define MSR_INVALID				0xffffffffU
 
-extern int sev;
-extern int sev_es;
 extern bool dump_invalid_vmcb;
 
 u32 svm_msrpm_offset(u32 msr);
@@ -406,7 +407,7 @@ bool svm_smi_blocked(struct kvm_vcpu *vcpu);
 bool svm_nmi_blocked(struct kvm_vcpu *vcpu);
 bool svm_interrupt_blocked(struct kvm_vcpu *vcpu);
 void svm_set_gif(struct vcpu_svm *svm, bool value);
-int svm_invoke_exit_handler(struct vcpu_svm *svm, u64 exit_code);
+int svm_invoke_exit_handler(struct kvm_vcpu *vcpu, u64 exit_code);
 void set_msr_interception(struct kvm_vcpu *vcpu, u32 *msrpm, u32 msr,
 			  int read, int write);
 
@@ -438,20 +439,30 @@ static inline bool nested_exit_on_nmi(struct vcpu_svm *svm)
 	return vmcb_is_intercept(&svm->nested.ctl, INTERCEPT_NMI);
 }
 
-int enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb_gpa,
-			 struct vmcb *nested_vmcb);
+int enter_svm_guest_mode(struct kvm_vcpu *vcpu, u64 vmcb_gpa, struct vmcb *vmcb12);
 void svm_leave_nested(struct vcpu_svm *svm);
 void svm_free_nested(struct vcpu_svm *svm);
 int svm_allocate_nested(struct vcpu_svm *svm);
-int nested_svm_vmrun(struct vcpu_svm *svm);
+int nested_svm_vmrun(struct kvm_vcpu *vcpu);
 void nested_svm_vmloadsave(struct vmcb *from_vmcb, struct vmcb *to_vmcb);
 int nested_svm_vmexit(struct vcpu_svm *svm);
+
+static inline int nested_svm_simple_vmexit(struct vcpu_svm *svm, u32 exit_code)
+{
+	svm->vmcb->control.exit_code   = exit_code;
+	svm->vmcb->control.exit_info_1 = 0;
+	svm->vmcb->control.exit_info_2 = 0;
+	return nested_svm_vmexit(svm);
+}
+
 int nested_svm_exit_handled(struct vcpu_svm *svm);
-int nested_svm_check_permissions(struct vcpu_svm *svm);
+int nested_svm_check_permissions(struct kvm_vcpu *vcpu);
 int nested_svm_check_exception(struct vcpu_svm *svm, unsigned nr,
 			       bool has_error_code, u32 error_code);
 int nested_svm_exit_special(struct vcpu_svm *svm);
-void sync_nested_vmcb_control(struct vcpu_svm *svm);
+void nested_sync_control_from_vmcb02(struct vcpu_svm *svm);
+void nested_vmcb02_compute_g_pat(struct vcpu_svm *svm);
+void svm_switch_vmcb(struct vcpu_svm *svm, struct kvm_vmcb_info *target_vmcb);
 
 extern struct kvm_x86_nested_ops svm_nested_ops;
 
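svm_switch_vmcb(), declared above, is defined in svm.c outside this excerpt; its essence is simply to retarget current_vmcb and the vmcb shorthand, roughly:

void svm_switch_vmcb(struct vcpu_svm *svm, struct kvm_vmcb_info *target_vmcb)
{
	svm->current_vmcb = target_vmcb;
	svm->vmcb = target_vmcb->ptr;
}
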
@@ -492,8 +503,8 @@ void avic_vm_destroy(struct kvm *kvm);
 int avic_vm_init(struct kvm *kvm);
 void avic_init_vmcb(struct vcpu_svm *svm);
 void svm_toggle_avic_for_irq_window(struct kvm_vcpu *vcpu, bool activate);
-int avic_incomplete_ipi_interception(struct vcpu_svm *svm);
-int avic_unaccelerated_access_interception(struct vcpu_svm *svm);
+int avic_incomplete_ipi_interception(struct kvm_vcpu *vcpu);
+int avic_unaccelerated_access_interception(struct kvm_vcpu *vcpu);
 int avic_init_vcpu(struct vcpu_svm *svm);
 void avic_vcpu_load(struct kvm_vcpu *vcpu, int cpu);
 void avic_vcpu_put(struct kvm_vcpu *vcpu);
@@ -551,22 +562,20 @@ void svm_vcpu_unblocking(struct kvm_vcpu *vcpu);
 
 extern unsigned int max_sev_asid;
 
-static inline bool svm_sev_enabled(void)
-{
-	return IS_ENABLED(CONFIG_KVM_AMD_SEV) ? max_sev_asid : 0;
-}
-
 void sev_vm_destroy(struct kvm *kvm);
 int svm_mem_enc_op(struct kvm *kvm, void __user *argp);
 int svm_register_enc_region(struct kvm *kvm,
 			    struct kvm_enc_region *range);
 int svm_unregister_enc_region(struct kvm *kvm,
 			      struct kvm_enc_region *range);
+int svm_vm_copy_asid_from(struct kvm *kvm, unsigned int source_fd);
 void pre_sev_run(struct vcpu_svm *svm, int cpu);
+void __init sev_set_cpu_caps(void);
 void __init sev_hardware_setup(void);
 void sev_hardware_teardown(void);
+int sev_cpu_init(struct svm_cpu_data *sd);
 void sev_free_vcpu(struct kvm_vcpu *vcpu);
-int sev_handle_vmgexit(struct vcpu_svm *svm);
+int sev_handle_vmgexit(struct kvm_vcpu *vcpu);
 int sev_es_string_io(struct vcpu_svm *svm, int size, unsigned int port, int in);
 void sev_es_init_vmcb(struct vcpu_svm *svm);
 void sev_es_create_vcpu(struct vcpu_svm *svm);
diff --git a/arch/x86/kvm/svm/vmenter.S b/arch/x86/kvm/svm/vmenter.S
index 6feb8c08f45a..4fa17df123cd 100644
--- a/arch/x86/kvm/svm/vmenter.S
+++ b/arch/x86/kvm/svm/vmenter.S
@@ -79,28 +79,10 @@ SYM_FUNC_START(__svm_vcpu_run)
 
 	/* Enter guest mode */
 	sti
-1:	vmload %_ASM_AX
-	jmp 3f
-2:	cmpb $0, kvm_rebooting
-	jne 3f
-	ud2
-	_ASM_EXTABLE(1b, 2b)
 
-3:	vmrun %_ASM_AX
-	jmp 5f
-4:	cmpb $0, kvm_rebooting
-	jne 5f
-	ud2
-	_ASM_EXTABLE(3b, 4b)
+1:	vmrun %_ASM_AX
 
-5:	vmsave %_ASM_AX
-	jmp 7f
-6:	cmpb $0, kvm_rebooting
-	jne 7f
-	ud2
-	_ASM_EXTABLE(5b, 6b)
-7:
-	cli
+2:	cli
 
 #ifdef CONFIG_RETPOLINE
 	/* IMPORTANT: Stuff the RSB immediately after VM-Exit, before RET! */
@@ -167,6 +149,13 @@ SYM_FUNC_START(__svm_vcpu_run)
 #endif
 	pop %_ASM_BP
 	ret
+
+3:	cmpb $0, kvm_rebooting
+	jne 2b
+	ud2
+
+	_ASM_EXTABLE(1b, 3b)
+
 SYM_FUNC_END(__svm_vcpu_run)
 
 /**
@@ -186,18 +175,15 @@ SYM_FUNC_START(__svm_sev_es_vcpu_run)
 #endif
 	push %_ASM_BX
 
-	/* Enter guest mode */
+	/* Move @vmcb to RAX. */
 	mov %_ASM_ARG1, %_ASM_AX
+
+	/* Enter guest mode */
 	sti
 
 1:	vmrun %_ASM_AX
-	jmp 3f
-2:	cmpb $0, kvm_rebooting
-	jne 3f
-	ud2
-	_ASM_EXTABLE(1b, 2b)
 
-3:	cli
+2:	cli
 
 #ifdef CONFIG_RETPOLINE
 	/* IMPORTANT: Stuff the RSB immediately after VM-Exit, before RET! */
@@ -217,4 +203,11 @@ SYM_FUNC_START(__svm_sev_es_vcpu_run)
 #endif
 	pop %_ASM_BP
 	ret
+
+3:	cmpb $0, kvm_rebooting
+	jne 2b
+	ud2
+
+	_ASM_EXTABLE(1b, 3b)
+
 SYM_FUNC_END(__svm_sev_es_vcpu_run)
diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index 1e069aac7410..bced76637823 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -11,6 +11,7 @@
 #include "mmu.h"
 #include "nested.h"
 #include "pmu.h"
+#include "sgx.h"
 #include "trace.h"
 #include "vmx.h"
 #include "x86.h"
@@ -21,13 +22,7 @@ module_param_named(enable_shadow_vmcs, enable_shadow_vmcs, bool, S_IRUGO);
 static bool __read_mostly nested_early_check = 0;
 module_param(nested_early_check, bool, S_IRUGO);
 
-#define CC(consistency_check)						\
-({									\
-	bool failed = (consistency_check);				\
-	if (failed)							\
-		trace_kvm_nested_vmenter_failed(#consistency_check, 0);	\
-	failed;								\
-})
+#define CC KVM_NESTED_VMENTER_CONSISTENCY_CHECK
 
 /*
  * Hyper-V requires all of these, so mark them as supported even though
@@ -619,6 +614,7 @@ static inline bool nested_vmx_prepare_msr_bitmap(struct kvm_vcpu *vcpu,
 	}
 
 	/* KVM unconditionally exposes the FS/GS base MSRs to L1. */
+#ifdef CONFIG_X86_64
 	nested_vmx_disable_intercept_for_msr(msr_bitmap_l1, msr_bitmap_l0,
 					     MSR_FS_BASE, MSR_TYPE_RW);
 
@@ -627,6 +623,7 @@ static inline bool nested_vmx_prepare_msr_bitmap(struct kvm_vcpu *vcpu,
 
 	nested_vmx_disable_intercept_for_msr(msr_bitmap_l1, msr_bitmap_l0,
 					     MSR_KERNEL_GS_BASE, MSR_TYPE_RW);
+#endif
 
 	/*
 	 * Checking the L0->L1 bitmap is trying to verify two things:
@@ -2306,6 +2303,9 @@ static void prepare_vmcs02_early(struct vcpu_vmx *vmx, struct vmcs12 *vmcs12)
 		if (!nested_cpu_has2(vmcs12, SECONDARY_EXEC_UNRESTRICTED_GUEST))
 		    exec_control &= ~SECONDARY_EXEC_UNRESTRICTED_GUEST;
 
+		if (exec_control & SECONDARY_EXEC_ENCLS_EXITING)
+			vmx_write_encls_bitmap(&vmx->vcpu, vmcs12);
+
 		secondary_exec_controls_set(vmx, exec_control);
 	}
 
@@ -3453,6 +3453,8 @@ static int nested_vmx_run(struct kvm_vcpu *vcpu, bool launch)
 	u32 interrupt_shadow = vmx_get_interrupt_shadow(vcpu);
 	enum nested_evmptrld_status evmptrld_status;
 
+	++vcpu->stat.nested_run;
+
 	if (!nested_vmx_check_permission(vcpu))
 		return 1;
 
@@ -3810,9 +3812,15 @@ static int vmx_check_nested_events(struct kvm_vcpu *vcpu)
 
 	/*
 	 * Process any exceptions that are not debug traps before MTF.
+	 *
+	 * Note that only a pending nested run can block a pending exception.
+	 * Otherwise an injected NMI/interrupt should either be
+	 * lost or delivered to the nested hypervisor in the IDT_VECTORING_INFO,
+	 * while delivering the pending exception.
 	 */
+
 	if (vcpu->arch.exception.pending && !vmx_pending_dbg_trap(vcpu)) {
-		if (block_nested_events)
+		if (vmx->nested.nested_run_pending)
 			return -EBUSY;
 		if (!nested_vmx_check_exception(vcpu, &exit_qual))
 			goto no_vmexit;
@@ -3829,7 +3837,7 @@ static int vmx_check_nested_events(struct kvm_vcpu *vcpu)
 	}
 
 	if (vcpu->arch.exception.pending) {
-		if (block_nested_events)
+		if (vmx->nested.nested_run_pending)
 			return -EBUSY;
 		if (!nested_vmx_check_exception(vcpu, &exit_qual))
 			goto no_vmexit;
@@ -4105,6 +4113,8 @@ static void prepare_vmcs12(struct kvm_vcpu *vcpu, struct vmcs12 *vmcs12,
 {
 	/* update exit information fields: */
 	vmcs12->vm_exit_reason = vm_exit_reason;
+	if (to_vmx(vcpu)->exit_reason.enclave_mode)
+		vmcs12->vm_exit_reason |= VMX_EXIT_REASONS_SGX_ENCLAVE_MODE;
 	vmcs12->exit_qualification = exit_qualification;
 	vmcs12->vm_exit_intr_info = exit_intr_info;
 
@@ -4422,6 +4432,9 @@ void nested_vmx_vmexit(struct kvm_vcpu *vcpu, u32 vm_exit_reason,
 	/* trying to cancel vmlaunch/vmresume is a bug */
 	WARN_ON_ONCE(vmx->nested.nested_run_pending);
 
+	/* Similarly, triple faults in L2 should never escape. */
+	WARN_ON_ONCE(kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu));
+
 	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 
 	/* Service the TLB flush request for L2 before switching to L1. */
@@ -4558,6 +4571,11 @@ void nested_vmx_vmexit(struct kvm_vcpu *vcpu, u32 vm_exit_reason,
 	vmx->fail = 0;
 }
 
+static void nested_vmx_triple_fault(struct kvm_vcpu *vcpu)
+{
+	nested_vmx_vmexit(vcpu, EXIT_REASON_TRIPLE_FAULT, 0, 0);
+}
+
 /*
  * Decode the memory-address operand of a vmx instruction, as recorded on an
  * exit caused by such an instruction (run by a guest hypervisor).
@@ -5005,7 +5023,7 @@ static int handle_vmread(struct kvm_vcpu *vcpu)
 		return nested_vmx_failInvalid(vcpu);
 
 	/* Decode instruction info and find the field to read */
-	field = kvm_register_readl(vcpu, (((instr_info) >> 28) & 0xf));
+	field = kvm_register_read(vcpu, (((instr_info) >> 28) & 0xf));
 
 	offset = vmcs_field_to_offset(field);
 	if (offset < 0)
@@ -5023,7 +5041,7 @@ static int handle_vmread(struct kvm_vcpu *vcpu)
 	 * on the guest's mode (32 or 64 bit), not on the given field's length.
 	 */
 	if (instr_info & BIT(10)) {
-		kvm_register_writel(vcpu, (((instr_info) >> 3) & 0xf), value);
+		kvm_register_write(vcpu, (((instr_info) >> 3) & 0xf), value);
 	} else {
 		len = is_64_bit_mode(vcpu) ? 8 : 4;
 		if (get_vmx_mem_address(vcpu, exit_qualification,
@@ -5097,7 +5115,7 @@ static int handle_vmwrite(struct kvm_vcpu *vcpu)
 		return nested_vmx_failInvalid(vcpu);
 
 	if (instr_info & BIT(10))
-		value = kvm_register_readl(vcpu, (((instr_info) >> 3) & 0xf));
+		value = kvm_register_read(vcpu, (((instr_info) >> 3) & 0xf));
 	else {
 		len = is_64_bit_mode(vcpu) ? 8 : 4;
 		if (get_vmx_mem_address(vcpu, exit_qualification,
@@ -5108,7 +5126,7 @@ static int handle_vmwrite(struct kvm_vcpu *vcpu)
 			return kvm_handle_memory_failure(vcpu, r, &e);
 	}
 
-	field = kvm_register_readl(vcpu, (((instr_info) >> 28) & 0xf));
+	field = kvm_register_read(vcpu, (((instr_info) >> 28) & 0xf));
 
 	offset = vmcs_field_to_offset(field);
 	if (offset < 0)
@@ -5305,7 +5323,7 @@ static int handle_invept(struct kvm_vcpu *vcpu)
 		return 1;
 
 	vmx_instruction_info = vmcs_read32(VMX_INSTRUCTION_INFO);
-	type = kvm_register_readl(vcpu, (vmx_instruction_info >> 28) & 0xf);
+	type = kvm_register_read(vcpu, (vmx_instruction_info >> 28) & 0xf);
 
 	types = (vmx->nested.msrs.ept_caps >> VMX_EPT_EXTENT_SHIFT) & 6;
 
@@ -5385,7 +5403,7 @@ static int handle_invvpid(struct kvm_vcpu *vcpu)
 		return 1;
 
 	vmx_instruction_info = vmcs_read32(VMX_INSTRUCTION_INFO);
-	type = kvm_register_readl(vcpu, (vmx_instruction_info >> 28) & 0xf);
+	type = kvm_register_read(vcpu, (vmx_instruction_info >> 28) & 0xf);
 
 	types = (vmx->nested.msrs.vpid_caps &
 			VMX_VPID_EXTENT_SUPPORTED_MASK) >> 8;
@@ -5479,16 +5497,11 @@ static int nested_vmx_eptp_switching(struct kvm_vcpu *vcpu,
 		if (!nested_vmx_check_eptp(vcpu, new_eptp))
 			return 1;
 
-		kvm_mmu_unload(vcpu);
 		mmu->ept_ad = accessed_dirty;
 		mmu->mmu_role.base.ad_disabled = !accessed_dirty;
 		vmcs12->ept_pointer = new_eptp;
-		/*
-		 * TODO: Check what's the correct approach in case
-		 * mmu reload fails. Currently, we just let the next
-		 * reload potentially fail
-		 */
-		kvm_mmu_reload(vcpu);
+
+		kvm_make_request(KVM_REQ_MMU_RELOAD, vcpu);
 	}
 
 	return 0;
@@ -5646,7 +5659,7 @@ static bool nested_vmx_exit_handled_cr(struct kvm_vcpu *vcpu,
 	switch ((exit_qualification >> 4) & 3) {
 	case 0: /* mov to cr */
 		reg = (exit_qualification >> 8) & 15;
-		val = kvm_register_readl(vcpu, reg);
+		val = kvm_register_read(vcpu, reg);
 		switch (cr) {
 		case 0:
 			if (vmcs12->cr0_guest_host_mask &
@@ -5705,6 +5718,21 @@ static bool nested_vmx_exit_handled_cr(struct kvm_vcpu *vcpu,
 	return false;
 }
 
+static bool nested_vmx_exit_handled_encls(struct kvm_vcpu *vcpu,
+					  struct vmcs12 *vmcs12)
+{
+	u32 encls_leaf;
+
+	if (!guest_cpuid_has(vcpu, X86_FEATURE_SGX) ||
+	    !nested_cpu_has2(vmcs12, SECONDARY_EXEC_ENCLS_EXITING))
+		return false;
+
+	encls_leaf = kvm_rax_read(vcpu);
+	if (encls_leaf > 62)
+		encls_leaf = 63;
+	return vmcs12->encls_exiting_bitmap & BIT_ULL(encls_leaf);
+}
+
 static bool nested_vmx_exit_handled_vmcs_access(struct kvm_vcpu *vcpu,
 	struct vmcs12 *vmcs12, gpa_t bitmap)
 {
@@ -5801,9 +5829,6 @@ static bool nested_vmx_l0_wants_exit(struct kvm_vcpu *vcpu,
 	case EXIT_REASON_VMFUNC:
 		/* VM functions are emulated through L2->L0 vmexits. */
 		return true;
-	case EXIT_REASON_ENCLS:
-		/* SGX is never exposed to L1 */
-		return true;
 	default:
 		break;
 	}
@@ -5927,6 +5952,8 @@ static bool nested_vmx_l1_wants_exit(struct kvm_vcpu *vcpu,
 	case EXIT_REASON_TPAUSE:
 		return nested_cpu_has2(vmcs12,
 			SECONDARY_EXEC_ENABLE_USR_WAIT_PAUSE);
+	case EXIT_REASON_ENCLS:
+		return nested_vmx_exit_handled_encls(vcpu, vmcs12);
 	default:
 		return true;
 	}
@@ -6502,6 +6529,9 @@ void nested_vmx_setup_ctls_msrs(struct nested_vmx_msrs *msrs, u32 ept_caps)
 		msrs->secondary_ctls_high |=
 			SECONDARY_EXEC_VIRTUALIZE_APIC_ACCESSES;
 
+	if (enable_sgx)
+		msrs->secondary_ctls_high |= SECONDARY_EXEC_ENCLS_EXITING;
+
 	/* miscellaneous data */
 	rdmsr(MSR_IA32_VMX_MISC,
 		msrs->misc_low,
@@ -6599,6 +6629,7 @@ __init int nested_vmx_hardware_setup(int (*exit_handlers[])(struct kvm_vcpu *))
 struct kvm_x86_nested_ops vmx_nested_ops = {
 	.check_events = vmx_check_nested_events,
 	.hv_timer_pending = nested_vmx_preemption_timer_pending,
+	.triple_fault = nested_vmx_triple_fault,
 	.get_state = vmx_get_nested_state,
 	.set_state = vmx_set_nested_state,
 	.get_nested_state_pages = vmx_get_nested_state_pages,
diff --git a/arch/x86/kvm/vmx/nested.h b/arch/x86/kvm/vmx/nested.h
index 197148d76b8f..184418baeb3c 100644
--- a/arch/x86/kvm/vmx/nested.h
+++ b/arch/x86/kvm/vmx/nested.h
@@ -244,6 +244,11 @@ static inline bool nested_exit_on_intr(struct kvm_vcpu *vcpu)
 		PIN_BASED_EXT_INTR_MASK;
 }
 
+static inline bool nested_cpu_has_encls_exit(struct vmcs12 *vmcs12)
+{
+	return nested_cpu_has2(vmcs12, SECONDARY_EXEC_ENCLS_EXITING);
+}
+
 /*
  * if fixed0[i] == 1: val[i] must be 1
  * if fixed1[i] == 0: val[i] must be 0
diff --git a/arch/x86/kvm/vmx/sgx.c b/arch/x86/kvm/vmx/sgx.c
new file mode 100644
index 000000000000..6693ebdc0770
--- /dev/null
+++ b/arch/x86/kvm/vmx/sgx.c
@@ -0,0 +1,502 @@
+// SPDX-License-Identifier: GPL-2.0
+/*  Copyright(c) 2021 Intel Corporation. */
+
+#include <asm/sgx.h>
+
+#include "cpuid.h"
+#include "kvm_cache_regs.h"
+#include "nested.h"
+#include "sgx.h"
+#include "vmx.h"
+#include "x86.h"
+
+bool __read_mostly enable_sgx = 1;
+module_param_named(sgx, enable_sgx, bool, 0444);
+
+/* Initial value of guest's virtual SGX_LEPUBKEYHASHn MSRs */
+static u64 sgx_pubkey_hash[4] __ro_after_init;
+
+/*
+ * ENCLS's memory operands use a fixed segment (DS) and a fixed
+ * address size based on the mode.  Related prefixes are ignored.
+ */
+static int sgx_get_encls_gva(struct kvm_vcpu *vcpu, unsigned long offset,
+			     int size, int alignment, gva_t *gva)
+{
+	struct kvm_segment s;
+	bool fault;
+
+	/* Skip vmcs.GUEST_DS retrieval for 64-bit mode to avoid VMREADs. */
+	*gva = offset;
+	if (!is_long_mode(vcpu)) {
+		vmx_get_segment(vcpu, &s, VCPU_SREG_DS);
+		*gva += s.base;
+	}
+
+	if (!IS_ALIGNED(*gva, alignment)) {
+		fault = true;
+	} else if (likely(is_long_mode(vcpu))) {
+		fault = is_noncanonical_address(*gva, vcpu);
+	} else {
+		*gva &= 0xffffffff;
+		fault = (s.unusable) ||
+			(s.type != 2 && s.type != 3) ||
+			(*gva > s.limit) ||
+			((s.base != 0 || s.limit != 0xffffffff) &&
+			(((u64)*gva + size - 1) > s.limit + 1));
+	}
+	if (fault)
+		kvm_inject_gp(vcpu, 0);
+	return fault ? -EINVAL : 0;
+}
+
+static void sgx_handle_emulation_failure(struct kvm_vcpu *vcpu, u64 addr,
+					 unsigned int size)
+{
+	vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
+	vcpu->run->internal.suberror = KVM_INTERNAL_ERROR_EMULATION;
+	vcpu->run->internal.ndata = 2;
+	vcpu->run->internal.data[0] = addr;
+	vcpu->run->internal.data[1] = size;
+}
+
+static int sgx_read_hva(struct kvm_vcpu *vcpu, unsigned long hva, void *data,
+			unsigned int size)
+{
+	if (__copy_from_user(data, (void __user *)hva, size)) {
+		sgx_handle_emulation_failure(vcpu, hva, size);
+		return -EFAULT;
+	}
+
+	return 0;
+}
+
+static int sgx_gva_to_gpa(struct kvm_vcpu *vcpu, gva_t gva, bool write,
+			  gpa_t *gpa)
+{
+	struct x86_exception ex;
+
+	if (write)
+		*gpa = kvm_mmu_gva_to_gpa_write(vcpu, gva, &ex);
+	else
+		*gpa = kvm_mmu_gva_to_gpa_read(vcpu, gva, &ex);
+
+	if (*gpa == UNMAPPED_GVA) {
+		kvm_inject_emulated_page_fault(vcpu, &ex);
+		return -EFAULT;
+	}
+
+	return 0;
+}
+
+static int sgx_gpa_to_hva(struct kvm_vcpu *vcpu, gpa_t gpa, unsigned long *hva)
+{
+	*hva = kvm_vcpu_gfn_to_hva(vcpu, PFN_DOWN(gpa));
+	if (kvm_is_error_hva(*hva)) {
+		sgx_handle_emulation_failure(vcpu, gpa, 1);
+		return -EFAULT;
+	}
+
+	*hva |= gpa & ~PAGE_MASK;
+
+	return 0;
+}
+
+static int sgx_inject_fault(struct kvm_vcpu *vcpu, gva_t gva, int trapnr)
+{
+	struct x86_exception ex;
+
+	/*
+	 * A non-EPCM #PF indicates a bad userspace HVA.  This *should* check
+	 * for PFEC.SGX and not assume any #PF on SGX2 originated in the EPC,
+	 * but the error code isn't (yet) plumbed through the ENCLS helpers.
+	 */
+	if (trapnr == PF_VECTOR && !boot_cpu_has(X86_FEATURE_SGX2)) {
+		vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
+		vcpu->run->internal.suberror = KVM_INTERNAL_ERROR_EMULATION;
+		vcpu->run->internal.ndata = 0;
+		return 0;
+	}
+
+	/*
+	 * If the guest thinks it's running on SGX2 hardware, inject an SGX
+	 * #PF if the fault matches an EPCM fault signature (#GP on SGX1,
+	 * #PF on SGX2).  The assumption is that EPCM faults are much more
+	 * likely than a bad userspace address.
+	 */
+	if ((trapnr == PF_VECTOR || !boot_cpu_has(X86_FEATURE_SGX2)) &&
+	    guest_cpuid_has(vcpu, X86_FEATURE_SGX2)) {
+		memset(&ex, 0, sizeof(ex));
+		ex.vector = PF_VECTOR;
+		ex.error_code = PFERR_PRESENT_MASK | PFERR_WRITE_MASK |
+				PFERR_SGX_MASK;
+		ex.address = gva;
+		ex.error_code_valid = true;
+		ex.nested_page_fault = false;
+		kvm_inject_page_fault(vcpu, &ex);
+	} else {
+		kvm_inject_gp(vcpu, 0);
+	}
+	return 1;
+}
+
+static int __handle_encls_ecreate(struct kvm_vcpu *vcpu,
+				  struct sgx_pageinfo *pageinfo,
+				  unsigned long secs_hva,
+				  gva_t secs_gva)
+{
+	struct sgx_secs *contents = (struct sgx_secs *)pageinfo->contents;
+	struct kvm_cpuid_entry2 *sgx_12_0, *sgx_12_1;
+	u64 attributes, xfrm, size;
+	u32 miscselect;
+	u8 max_size_log2;
+	int trapnr, ret;
+
+	sgx_12_0 = kvm_find_cpuid_entry(vcpu, 0x12, 0);
+	sgx_12_1 = kvm_find_cpuid_entry(vcpu, 0x12, 1);
+	if (!sgx_12_0 || !sgx_12_1) {
+		vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
+		vcpu->run->internal.suberror = KVM_INTERNAL_ERROR_EMULATION;
+		vcpu->run->internal.ndata = 0;
+		return 0;
+	}
+
+	miscselect = contents->miscselect;
+	attributes = contents->attributes;
+	xfrm = contents->xfrm;
+	size = contents->size;
+
+	/* Enforce restriction of access to the PROVISIONKEY. */
+	if (!vcpu->kvm->arch.sgx_provisioning_allowed &&
+	    (attributes & SGX_ATTR_PROVISIONKEY)) {
+		if (sgx_12_1->eax & SGX_ATTR_PROVISIONKEY)
+			pr_warn_once("KVM: SGX PROVISIONKEY advertised but not allowed\n");
+		kvm_inject_gp(vcpu, 0);
+		return 1;
+	}
+
+	/* Enforce CPUID restrictions on MISCSELECT, ATTRIBUTES and XFRM. */
+	if ((u32)miscselect & ~sgx_12_0->ebx ||
+	    (u32)attributes & ~sgx_12_1->eax ||
+	    (u32)(attributes >> 32) & ~sgx_12_1->ebx ||
+	    (u32)xfrm & ~sgx_12_1->ecx ||
+	    (u32)(xfrm >> 32) & ~sgx_12_1->edx) {
+		kvm_inject_gp(vcpu, 0);
+		return 1;
+	}
+
+	/* Enforce CPUID restriction on max enclave size. */
+	max_size_log2 = (attributes & SGX_ATTR_MODE64BIT) ? sgx_12_0->edx >> 8 :
+							    sgx_12_0->edx;
+	if (size >= BIT_ULL(max_size_log2))
+		kvm_inject_gp(vcpu, 0);
+
+	/*
+	 * sgx_virt_ecreate() returns:
+	 *  1) 0:	ECREATE was successful
+	 *  2) -EFAULT:	ECREATE was run but faulted, and trapnr was set to the
+	 *		exception number.
+	 *  3) -EINVAL:	access_ok() on @secs_hva failed. This should never
+	 *		happen as KVM checks host addresses at memslot creation.
+	 *		sgx_virt_ecreate() has already warned in this case.
+	 */
+	ret = sgx_virt_ecreate(pageinfo, (void __user *)secs_hva, &trapnr);
+	if (!ret)
+		return kvm_skip_emulated_instruction(vcpu);
+	if (ret == -EFAULT)
+		return sgx_inject_fault(vcpu, secs_gva, trapnr);
+
+	return ret;
+}
+
+static int handle_encls_ecreate(struct kvm_vcpu *vcpu)
+{
+	gva_t pageinfo_gva, secs_gva;
+	gva_t metadata_gva, contents_gva;
+	gpa_t metadata_gpa, contents_gpa, secs_gpa;
+	unsigned long metadata_hva, contents_hva, secs_hva;
+	struct sgx_pageinfo pageinfo;
+	struct sgx_secs *contents;
+	struct x86_exception ex;
+	int r;
+
+	if (sgx_get_encls_gva(vcpu, kvm_rbx_read(vcpu), 32, 32, &pageinfo_gva) ||
+	    sgx_get_encls_gva(vcpu, kvm_rcx_read(vcpu), 4096, 4096, &secs_gva))
+		return 1;
+
+	/*
+	 * Copy the PAGEINFO to local memory, its pointers need to be
+	 * translated, i.e. we need to do a deep copy/translate.
+	 */
+	r = kvm_read_guest_virt(vcpu, pageinfo_gva, &pageinfo,
+				sizeof(pageinfo), &ex);
+	if (r == X86EMUL_PROPAGATE_FAULT) {
+		kvm_inject_emulated_page_fault(vcpu, &ex);
+		return 1;
+	} else if (r != X86EMUL_CONTINUE) {
+		sgx_handle_emulation_failure(vcpu, pageinfo_gva,
+					     sizeof(pageinfo));
+		return 0;
+	}
+
+	if (sgx_get_encls_gva(vcpu, pageinfo.metadata, 64, 64, &metadata_gva) ||
+	    sgx_get_encls_gva(vcpu, pageinfo.contents, 4096, 4096,
+			      &contents_gva))
+		return 1;
+
+	/*
+	 * Translate the SECINFO, SOURCE and SECS pointers from GVA to GPA.
+	 * Resume the guest on failure to inject a #PF.
+	 */
+	if (sgx_gva_to_gpa(vcpu, metadata_gva, false, &metadata_gpa) ||
+	    sgx_gva_to_gpa(vcpu, contents_gva, false, &contents_gpa) ||
+	    sgx_gva_to_gpa(vcpu, secs_gva, true, &secs_gpa))
+		return 1;
+
+	/*
+	 * ...and then to HVA.  The order of accesses isn't architectural, i.e.
+	 * KVM doesn't have to fully process one address at a time.  Exit to
+	 * userspace if a GPA is invalid.
+	 */
+	if (sgx_gpa_to_hva(vcpu, metadata_gpa, &metadata_hva) ||
+	    sgx_gpa_to_hva(vcpu, contents_gpa, &contents_hva) ||
+	    sgx_gpa_to_hva(vcpu, secs_gpa, &secs_hva))
+		return 0;
+
+	/*
+	 * Copy contents into kernel memory to prevent TOCTOU attack. E.g. the
+	 * guest could do ECREATE w/ SECS.SGX_ATTR_PROVISIONKEY=0, and
+	 * simultaneously set SGX_ATTR_PROVISIONKEY to bypass the check to
+	 * enforce restriction of access to the PROVISIONKEY.
+	 */
+	contents = (struct sgx_secs *)__get_free_page(GFP_KERNEL_ACCOUNT);
+	if (!contents)
+		return -ENOMEM;
+
+	/* Exit to userspace if copying from a host userspace address fails. */
+	if (sgx_read_hva(vcpu, contents_hva, (void *)contents, PAGE_SIZE)) {
+		free_page((unsigned long)contents);
+		return 0;
+	}
+
+	pageinfo.metadata = metadata_hva;
+	pageinfo.contents = (u64)contents;
+
+	r = __handle_encls_ecreate(vcpu, &pageinfo, secs_hva, secs_gva);
+
+	free_page((unsigned long)contents);
+
+	return r;
+}
+
+static int handle_encls_einit(struct kvm_vcpu *vcpu)
+{
+	unsigned long sig_hva, secs_hva, token_hva, rflags;
+	struct vcpu_vmx *vmx = to_vmx(vcpu);
+	gva_t sig_gva, secs_gva, token_gva;
+	gpa_t sig_gpa, secs_gpa, token_gpa;
+	int ret, trapnr;
+
+	if (sgx_get_encls_gva(vcpu, kvm_rbx_read(vcpu), 1808, 4096, &sig_gva) ||
+	    sgx_get_encls_gva(vcpu, kvm_rcx_read(vcpu), 4096, 4096, &secs_gva) ||
+	    sgx_get_encls_gva(vcpu, kvm_rdx_read(vcpu), 304, 512, &token_gva))
+		return 1;
+
+	/*
+	 * Translate the SIGSTRUCT, SECS and TOKEN pointers from GVA to GPA.
+	 * Resume the guest on failure to inject a #PF.
+	 */
+	if (sgx_gva_to_gpa(vcpu, sig_gva, false, &sig_gpa) ||
+	    sgx_gva_to_gpa(vcpu, secs_gva, true, &secs_gpa) ||
+	    sgx_gva_to_gpa(vcpu, token_gva, false, &token_gpa))
+		return 1;
+
+	/*
+	 * ...and then to HVA.  The order of accesses isn't architectural, i.e.
+	 * KVM doesn't have to fully process one address at a time.  Exit to
+	 * userspace if a GPA is invalid.  Note, all structures are aligned and
+	 * cannot split pages.
+	 */
+	if (sgx_gpa_to_hva(vcpu, sig_gpa, &sig_hva) ||
+	    sgx_gpa_to_hva(vcpu, secs_gpa, &secs_hva) ||
+	    sgx_gpa_to_hva(vcpu, token_gpa, &token_hva))
+		return 0;
+
+	ret = sgx_virt_einit((void __user *)sig_hva, (void __user *)token_hva,
+			     (void __user *)secs_hva,
+			     vmx->msr_ia32_sgxlepubkeyhash, &trapnr);
+
+	if (ret == -EFAULT)
+		return sgx_inject_fault(vcpu, secs_gva, trapnr);
+
+	/*
+	 * sgx_virt_einit() returns -EINVAL when access_ok() fails on @sig_hva,
+	 * @token_hva or @secs_hva. This should never happen as KVM checks host
+	 * addresses at memslot creation. sgx_virt_einit() has already warned
+	 * in this case, so just return.
+	 */
+	if (ret < 0)
+		return ret;
+
+	rflags = vmx_get_rflags(vcpu) & ~(X86_EFLAGS_CF | X86_EFLAGS_PF |
+					  X86_EFLAGS_AF | X86_EFLAGS_SF |
+					  X86_EFLAGS_OF);
+	if (ret)
+		rflags |= X86_EFLAGS_ZF;
+	else
+		rflags &= ~X86_EFLAGS_ZF;
+	vmx_set_rflags(vcpu, rflags);
+
+	kvm_rax_write(vcpu, ret);
+	return kvm_skip_emulated_instruction(vcpu);
+}
+
+static inline bool encls_leaf_enabled_in_guest(struct kvm_vcpu *vcpu, u32 leaf)
+{
+	if (!enable_sgx || !guest_cpuid_has(vcpu, X86_FEATURE_SGX))
+		return false;
+
+	if (leaf >= ECREATE && leaf <= ETRACK)
+		return guest_cpuid_has(vcpu, X86_FEATURE_SGX1);
+
+	if (leaf >= EAUG && leaf <= EMODT)
+		return guest_cpuid_has(vcpu, X86_FEATURE_SGX2);
+
+	return false;
+}
+
+static inline bool sgx_enabled_in_guest_bios(struct kvm_vcpu *vcpu)
+{
+	const u64 bits = FEAT_CTL_SGX_ENABLED | FEAT_CTL_LOCKED;
+
+	return (to_vmx(vcpu)->msr_ia32_feature_control & bits) == bits;
+}
+
+int handle_encls(struct kvm_vcpu *vcpu)
+{
+	u32 leaf = (u32)kvm_rax_read(vcpu);
+
+	if (!encls_leaf_enabled_in_guest(vcpu, leaf)) {
+		kvm_queue_exception(vcpu, UD_VECTOR);
+	} else if (!sgx_enabled_in_guest_bios(vcpu)) {
+		kvm_inject_gp(vcpu, 0);
+	} else {
+		if (leaf == ECREATE)
+			return handle_encls_ecreate(vcpu);
+		if (leaf == EINIT)
+			return handle_encls_einit(vcpu);
+		WARN(1, "KVM: unexpected exit on ENCLS[%u]", leaf);
+		vcpu->run->exit_reason = KVM_EXIT_UNKNOWN;
+		vcpu->run->hw.hardware_exit_reason = EXIT_REASON_ENCLS;
+		return 0;
+	}
+	return 1;
+}
+
+void setup_default_sgx_lepubkeyhash(void)
+{
+	/*
+	 * Use Intel's default value for Skylake hardware if Launch Control is
+	 * not supported, i.e. Intel's hash is hardcoded into silicon, or if
+	 * Launch Control is supported and enabled, i.e. mimic the reset value
+	 * and let the guest write the MSRs at will.  If Launch Control is
+	 * supported but disabled, then use the current MSR values as the hash
+	 * MSRs exist but are read-only (locked and not writable).
+	 */
+	if (!enable_sgx || boot_cpu_has(X86_FEATURE_SGX_LC) ||
+	    rdmsrl_safe(MSR_IA32_SGXLEPUBKEYHASH0, &sgx_pubkey_hash[0])) {
+		sgx_pubkey_hash[0] = 0xa6053e051270b7acULL;
+		sgx_pubkey_hash[1] = 0x6cfbe8ba8b3b413dULL;
+		sgx_pubkey_hash[2] = 0xc4916d99f2b3735dULL;
+		sgx_pubkey_hash[3] = 0xd4f8c05909f9bb3bULL;
+	} else {
+		/* MSR_IA32_SGXLEPUBKEYHASH0 is read above */
+		rdmsrl(MSR_IA32_SGXLEPUBKEYHASH1, sgx_pubkey_hash[1]);
+		rdmsrl(MSR_IA32_SGXLEPUBKEYHASH2, sgx_pubkey_hash[2]);
+		rdmsrl(MSR_IA32_SGXLEPUBKEYHASH3, sgx_pubkey_hash[3]);
+	}
+}
+
+void vcpu_setup_sgx_lepubkeyhash(struct kvm_vcpu *vcpu)
+{
+	struct vcpu_vmx *vmx = to_vmx(vcpu);
+
+	memcpy(vmx->msr_ia32_sgxlepubkeyhash, sgx_pubkey_hash,
+	       sizeof(sgx_pubkey_hash));
+}
+
+/*
+ * ECREATE must be intercepted to enforce MISCSELECT, ATTRIBUTES and XFRM
+ * restrictions if the guest's allowed-1 settings diverge from hardware.
+ */
+static bool sgx_intercept_encls_ecreate(struct kvm_vcpu *vcpu)
+{
+	struct kvm_cpuid_entry2 *guest_cpuid;
+	u32 eax, ebx, ecx, edx;
+
+	if (!vcpu->kvm->arch.sgx_provisioning_allowed)
+		return true;
+
+	guest_cpuid = kvm_find_cpuid_entry(vcpu, 0x12, 0);
+	if (!guest_cpuid)
+		return true;
+
+	cpuid_count(0x12, 0, &eax, &ebx, &ecx, &edx);
+	if (guest_cpuid->ebx != ebx || guest_cpuid->edx != edx)
+		return true;
+
+	guest_cpuid = kvm_find_cpuid_entry(vcpu, 0x12, 1);
+	if (!guest_cpuid)
+		return true;
+
+	cpuid_count(0x12, 1, &eax, &ebx, &ecx, &edx);
+	if (guest_cpuid->eax != eax || guest_cpuid->ebx != ebx ||
+	    guest_cpuid->ecx != ecx || guest_cpuid->edx != edx)
+		return true;
+
+	return false;
+}
+
+void vmx_write_encls_bitmap(struct kvm_vcpu *vcpu, struct vmcs12 *vmcs12)
+{
+	/*
+	 * There is no software enable bit for SGX that is virtualized by
+	 * hardware, e.g. there's no CR4.SGXE, so when SGX is disabled in the
+	 * guest (either by the host or by the guest's BIOS) but enabled in the
+	 * host, trap all ENCLS leafs and inject #UD/#GP as needed to emulate
+	 * the expected system behavior for ENCLS.
+	 */
+	u64 bitmap = -1ull;
+
+	/* Nothing to do if hardware doesn't support SGX */
+	if (!cpu_has_vmx_encls_vmexit())
+		return;
+
+	if (guest_cpuid_has(vcpu, X86_FEATURE_SGX) &&
+	    sgx_enabled_in_guest_bios(vcpu)) {
+		if (guest_cpuid_has(vcpu, X86_FEATURE_SGX1)) {
+			bitmap &= ~GENMASK_ULL(ETRACK, ECREATE);
+			if (sgx_intercept_encls_ecreate(vcpu))
+				bitmap |= (1 << ECREATE);
+		}
+
+		if (guest_cpuid_has(vcpu, X86_FEATURE_SGX2))
+			bitmap &= ~GENMASK_ULL(EMODT, EAUG);
+
+		/*
+		 * Trap and execute EINIT if launch control is enabled in the
+		 * host using the guest's values for launch control MSRs, even
+		 * if the guest's values are fixed to hardware default values.
+		 * The MSRs are not loaded/saved on VM-Enter/VM-Exit as writing
+		 * the MSRs is extraordinarily expensive.
+		 */
+		if (boot_cpu_has(X86_FEATURE_SGX_LC))
+			bitmap |= (1 << EINIT);
+
+		if (!vmcs12 && is_guest_mode(vcpu))
+			vmcs12 = get_vmcs12(vcpu);
+		if (vmcs12 && nested_cpu_has_encls_exit(vmcs12))
+			bitmap |= vmcs12->encls_exiting_bitmap;
+	}
+	vmcs_write64(ENCLS_EXITING_BITMAP, bitmap);
+}
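The GENMASK_ULL() ranges above depend on the ENCLS leaf numbering (enum sgx_encls_function in asm/sgx.h, following the SDM), where the SGX1 leaves occupy bits 0..12 of the exiting bitmap and the SGX2 leaves bits 13..15. The boundary values assumed here are:

enum sgx_encls_function {
	ECREATE	= 0x00,	/* first SGX1 leaf */
	EINIT	= 0x02,
	ETRACK	= 0x0C,	/* last SGX1 leaf */
	EAUG	= 0x0D,	/* first SGX2 leaf */
	EMODT	= 0x0F,	/* last SGX2 leaf */
};
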
diff --git a/arch/x86/kvm/vmx/sgx.h b/arch/x86/kvm/vmx/sgx.h
new file mode 100644
index 000000000000..a400888b376d
--- /dev/null
+++ b/arch/x86/kvm/vmx/sgx.h
@@ -0,0 +1,34 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef __KVM_X86_SGX_H
+#define __KVM_X86_SGX_H
+
+#include <linux/kvm_host.h>
+
+#include "capabilities.h"
+#include "vmx_ops.h"
+
+#ifdef CONFIG_X86_SGX_KVM
+extern bool __read_mostly enable_sgx;
+
+int handle_encls(struct kvm_vcpu *vcpu);
+
+void setup_default_sgx_lepubkeyhash(void);
+void vcpu_setup_sgx_lepubkeyhash(struct kvm_vcpu *vcpu);
+
+void vmx_write_encls_bitmap(struct kvm_vcpu *vcpu, struct vmcs12 *vmcs12);
+#else
+#define enable_sgx 0
+
+static inline void setup_default_sgx_lepubkeyhash(void) { }
+static inline void vcpu_setup_sgx_lepubkeyhash(struct kvm_vcpu *vcpu) { }
+
+static inline void vmx_write_encls_bitmap(struct kvm_vcpu *vcpu,
+					  struct vmcs12 *vmcs12)
+{
+	/* Nothing to do if hardware doesn't support SGX */
+	if (cpu_has_vmx_encls_vmexit())
+		vmcs_write64(ENCLS_EXITING_BITMAP, -1ull);
+}
+#endif
+
+#endif /* __KVM_X86_SGX_H */
diff --git a/arch/x86/kvm/vmx/vmcs12.c b/arch/x86/kvm/vmx/vmcs12.c
index c8e51c004f78..034adb6404dc 100644
--- a/arch/x86/kvm/vmx/vmcs12.c
+++ b/arch/x86/kvm/vmx/vmcs12.c
@@ -50,6 +50,7 @@ const unsigned short vmcs_field_to_offset_table[] = {
 	FIELD64(VMREAD_BITMAP, vmread_bitmap),
 	FIELD64(VMWRITE_BITMAP, vmwrite_bitmap),
 	FIELD64(XSS_EXIT_BITMAP, xss_exit_bitmap),
+	FIELD64(ENCLS_EXITING_BITMAP, encls_exiting_bitmap),
 	FIELD64(GUEST_PHYSICAL_ADDRESS, guest_physical_address),
 	FIELD64(VMCS_LINK_POINTER, vmcs_link_pointer),
 	FIELD64(GUEST_IA32_DEBUGCTL, guest_ia32_debugctl),
diff --git a/arch/x86/kvm/vmx/vmcs12.h b/arch/x86/kvm/vmx/vmcs12.h
index 80232daf00ff..13494956d0e9 100644
--- a/arch/x86/kvm/vmx/vmcs12.h
+++ b/arch/x86/kvm/vmx/vmcs12.h
@@ -69,7 +69,8 @@ struct __packed vmcs12 {
 	u64 vm_function_control;
 	u64 eptp_list_address;
 	u64 pml_address;
-	u64 padding64[3]; /* room for future expansion */
+	u64 encls_exiting_bitmap;
+	u64 padding64[2]; /* room for future expansion */
 	/*
 	 * To allow migration of L1 (complete with its L2 guests) between
 	 * machines of different natural widths (32 or 64 bit), we cannot have
@@ -256,6 +257,7 @@ static inline void vmx_check_vmcs12_offsets(void)
 	CHECK_OFFSET(vm_function_control, 296);
 	CHECK_OFFSET(eptp_list_address, 304);
 	CHECK_OFFSET(pml_address, 312);
+	CHECK_OFFSET(encls_exiting_bitmap, 320);
 	CHECK_OFFSET(cr0_guest_host_mask, 344);
 	CHECK_OFFSET(cr4_guest_host_mask, 352);
 	CHECK_OFFSET(cr0_read_shadow, 360);
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index bcbf0d2139e9..cbe0cdade38a 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -57,6 +57,7 @@
 #include "mmu.h"
 #include "nested.h"
 #include "pmu.h"
+#include "sgx.h"
 #include "trace.h"
 #include "vmcs.h"
 #include "vmcs12.h"
@@ -156,9 +157,11 @@ static u32 vmx_possible_passthrough_msrs[MAX_POSSIBLE_PASSTHROUGH_MSRS] = {
 	MSR_IA32_SPEC_CTRL,
 	MSR_IA32_PRED_CMD,
 	MSR_IA32_TSC,
+#ifdef CONFIG_X86_64
 	MSR_FS_BASE,
 	MSR_GS_BASE,
 	MSR_KERNEL_GS_BASE,
+#endif
 	MSR_IA32_SYSENTER_CS,
 	MSR_IA32_SYSENTER_ESP,
 	MSR_IA32_SYSENTER_EIP,
@@ -361,8 +364,6 @@ static const struct kernel_param_ops vmentry_l1d_flush_ops = {
 module_param_cb(vmentry_l1d_flush, &vmentry_l1d_flush_ops, NULL, 0644);
 
 static u32 vmx_segment_access_rights(struct kvm_segment *var);
-static __always_inline void vmx_disable_intercept_for_msr(struct kvm_vcpu *vcpu,
-							  u32 msr, int type);
 
 void vmx_vmexit(void);
 
@@ -472,26 +473,6 @@ static const u32 vmx_uret_msrs_list[] = {
 static bool __read_mostly enlightened_vmcs = true;
 module_param(enlightened_vmcs, bool, 0444);
 
-/* check_ept_pointer() should be under protection of ept_pointer_lock. */
-static void check_ept_pointer_match(struct kvm *kvm)
-{
-	struct kvm_vcpu *vcpu;
-	u64 tmp_eptp = INVALID_PAGE;
-	int i;
-
-	kvm_for_each_vcpu(i, vcpu, kvm) {
-		if (!VALID_PAGE(tmp_eptp)) {
-			tmp_eptp = to_vmx(vcpu)->ept_pointer;
-		} else if (tmp_eptp != to_vmx(vcpu)->ept_pointer) {
-			to_kvm_vmx(kvm)->ept_pointers_match
-				= EPT_POINTERS_MISMATCH;
-			return;
-		}
-	}
-
-	to_kvm_vmx(kvm)->ept_pointers_match = EPT_POINTERS_MATCH;
-}
-
 static int kvm_fill_hv_flush_list_func(struct hv_guest_mapping_flush_list *flush,
 		void *data)
 {
@@ -501,47 +482,70 @@ static int kvm_fill_hv_flush_list_func(struct hv_guest_mapping_flush_list *flush
 			range->pages);
 }
 
-static inline int __hv_remote_flush_tlb_with_range(struct kvm *kvm,
-		struct kvm_vcpu *vcpu, struct kvm_tlb_range *range)
+static inline int hv_remote_flush_root_ept(hpa_t root_ept,
+					   struct kvm_tlb_range *range)
 {
-	u64 ept_pointer = to_vmx(vcpu)->ept_pointer;
-
-	/*
-	 * FLUSH_GUEST_PHYSICAL_ADDRESS_SPACE hypercall needs address
-	 * of the base of EPT PML4 table, strip off EPT configuration
-	 * information.
-	 */
 	if (range)
-		return hyperv_flush_guest_mapping_range(ept_pointer & PAGE_MASK,
+		return hyperv_flush_guest_mapping_range(root_ept,
 				kvm_fill_hv_flush_list_func, (void *)range);
 	else
-		return hyperv_flush_guest_mapping(ept_pointer & PAGE_MASK);
+		return hyperv_flush_guest_mapping(root_ept);
 }
 
 static int hv_remote_flush_tlb_with_range(struct kvm *kvm,
 		struct kvm_tlb_range *range)
 {
+	struct kvm_vmx *kvm_vmx = to_kvm_vmx(kvm);
 	struct kvm_vcpu *vcpu;
-	int ret = 0, i;
+	int ret = 0, i, nr_unique_valid_roots;
+	hpa_t root;
 
-	spin_lock(&to_kvm_vmx(kvm)->ept_pointer_lock);
+	spin_lock(&kvm_vmx->hv_root_ept_lock);
 
-	if (to_kvm_vmx(kvm)->ept_pointers_match == EPT_POINTERS_CHECK)
-		check_ept_pointer_match(kvm);
+	if (!VALID_PAGE(kvm_vmx->hv_root_ept)) {
+		nr_unique_valid_roots = 0;
 
-	if (to_kvm_vmx(kvm)->ept_pointers_match != EPT_POINTERS_MATCH) {
+		/*
+		 * Flush all valid roots, and see if all vCPUs have converged
+		 * on a common root, in which case future flushes can skip the
+		 * loop and flush the common root.
+		 */
 		kvm_for_each_vcpu(i, vcpu, kvm) {
-			/* If ept_pointer is invalid pointer, bypass flush request. */
-			if (VALID_PAGE(to_vmx(vcpu)->ept_pointer))
-				ret |= __hv_remote_flush_tlb_with_range(
-					kvm, vcpu, range);
+			root = to_vmx(vcpu)->hv_root_ept;
+			if (!VALID_PAGE(root) || root == kvm_vmx->hv_root_ept)
+				continue;
+
+			/*
+			 * Set the tracked root to the first valid root.  Keep
+			 * this root for the entirety of the loop even if more
+			 * roots are encountered as a low effort optimization
+			 * to avoid flushing the same (first) root again.
+			 */
+			if (++nr_unique_valid_roots == 1)
+				kvm_vmx->hv_root_ept = root;
+
+			if (!ret)
+				ret = hv_remote_flush_root_ept(root, range);
+
+			/*
+			 * Stop processing roots if a failure occurred and
+			 * multiple valid roots have already been detected.
+			 */
+			if (ret && nr_unique_valid_roots > 1)
+				break;
 		}
+
+		/*
+		 * The optimized flush of a single root can't be used if there
+		 * are multiple valid roots (obviously).
+		 */
+		if (nr_unique_valid_roots > 1)
+			kvm_vmx->hv_root_ept = INVALID_PAGE;
 	} else {
-		ret = __hv_remote_flush_tlb_with_range(kvm,
-				kvm_get_vcpu(kvm, 0), range);
+		ret = hv_remote_flush_root_ept(kvm_vmx->hv_root_ept, range);
 	}
 
-	spin_unlock(&to_kvm_vmx(kvm)->ept_pointer_lock);
+	spin_unlock(&kvm_vmx->hv_root_ept_lock);
 	return ret;
 }
 static int hv_remote_flush_tlb(struct kvm *kvm)
@@ -559,7 +563,7 @@ static int hv_enable_direct_tlbflush(struct kvm_vcpu *vcpu)
 	 * evmcs in single VM shares same assist page.
 	 */
 	if (!*p_hv_pa_pg)
-		*p_hv_pa_pg = kzalloc(PAGE_SIZE, GFP_KERNEL);
+		*p_hv_pa_pg = kzalloc(PAGE_SIZE, GFP_KERNEL_ACCOUNT);
 
 	if (!*p_hv_pa_pg)
 		return -ENOMEM;
@@ -576,6 +580,21 @@ static int hv_enable_direct_tlbflush(struct kvm_vcpu *vcpu)
 
 #endif /* IS_ENABLED(CONFIG_HYPERV) */
 
+static void hv_track_root_ept(struct kvm_vcpu *vcpu, hpa_t root_ept)
+{
+#if IS_ENABLED(CONFIG_HYPERV)
+	struct kvm_vmx *kvm_vmx = to_kvm_vmx(vcpu->kvm);
+
+	if (kvm_x86_ops.tlb_remote_flush == hv_remote_flush_tlb) {
+		spin_lock(&kvm_vmx->hv_root_ept_lock);
+		to_vmx(vcpu)->hv_root_ept = root_ept;
+		if (root_ept != kvm_vmx->hv_root_ept)
+			kvm_vmx->hv_root_ept = INVALID_PAGE;
+		spin_unlock(&kvm_vmx->hv_root_ept_lock);
+	}
+#endif
+}
+
 /*
  * Comment's format: document - errata name - stepping - processor name.
  * Refer from
@@ -1570,12 +1589,25 @@ static int vmx_rtit_ctl_check(struct kvm_vcpu *vcpu, u64 data)
 
 static bool vmx_can_emulate_instruction(struct kvm_vcpu *vcpu, void *insn, int insn_len)
 {
+	/*
+	 * Emulation of instructions in SGX enclaves is impossible as RIP does
+	 * not point at the failing instruction, and even if it did, the code
+	 * stream is inaccessible.  Inject #UD instead of exiting to userspace
+	 * so that guest userspace can't DoS the guest simply by triggering
+	 * emulation (enclaves are CPL3 only).
+	 */
+	if (to_vmx(vcpu)->exit_reason.enclave_mode) {
+		kvm_queue_exception(vcpu, UD_VECTOR);
+		return false;
+	}
 	return true;
 }
 
 static int skip_emulated_instruction(struct kvm_vcpu *vcpu)
 {
+	union vmx_exit_reason exit_reason = to_vmx(vcpu)->exit_reason;
 	unsigned long rip, orig_rip;
+	u32 instr_len;
 
 	/*
 	 * Using VMCS.VM_EXIT_INSTRUCTION_LEN on EPT misconfig depends on
@@ -1586,9 +1618,33 @@ static int skip_emulated_instruction(struct kvm_vcpu *vcpu)
 	 * i.e. we end up advancing IP with some random value.
 	 */
 	if (!static_cpu_has(X86_FEATURE_HYPERVISOR) ||
-	    to_vmx(vcpu)->exit_reason.basic != EXIT_REASON_EPT_MISCONFIG) {
+	    exit_reason.basic != EXIT_REASON_EPT_MISCONFIG) {
+		instr_len = vmcs_read32(VM_EXIT_INSTRUCTION_LEN);
+
+		/*
+		 * Emulating an enclave's instructions isn't supported as KVM
+		 * cannot access the enclave's memory or its true RIP, e.g. the
+		 * vmcs.GUEST_RIP points at the exit point of the enclave, not
+		 * the RIP that actually triggered the VM-Exit.  But, because
+		 * most instructions that cause VM-Exit will #UD in an enclave,
+		 * most instruction-based VM-Exits simply do not occur.
+		 *
+		 * There are a few exceptions, notably the debug instructions
+		 * INT1ICEBRK and INT3, as they are allowed in debug enclaves
+		 * and generate #DB/#BP as expected, which KVM might intercept.
+		 * But again, the CPU does the dirty work and saves an instr
+		 * length of zero so VMMs don't shoot themselves in the foot.
+		 * WARN if KVM tries to skip a non-zero length instruction on
+		 * a VM-Exit from an enclave.
+		 */
+		if (!instr_len)
+			goto rip_updated;
+
+		WARN(exit_reason.enclave_mode,
+		     "KVM: skipping instruction after SGX enclave VM-Exit");
+
 		orig_rip = kvm_rip_read(vcpu);
-		rip = orig_rip + vmcs_read32(VM_EXIT_INSTRUCTION_LEN);
+		rip = orig_rip + instr_len;
 #ifdef CONFIG_X86_64
 		/*
 		 * We need to mask out the high 32 bits of RIP if not in 64-bit
@@ -1604,6 +1660,7 @@ static int skip_emulated_instruction(struct kvm_vcpu *vcpu)
 			return 0;
 	}
 
+rip_updated:
 	/* skipping an emulated instruction also counts */
 	vmx_set_interrupt_shadow(vcpu, 0);
 
@@ -1865,6 +1922,13 @@ static int vmx_get_msr(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
 	case MSR_IA32_FEAT_CTL:
 		msr_info->data = vmx->msr_ia32_feature_control;
 		break;
+	case MSR_IA32_SGXLEPUBKEYHASH0 ... MSR_IA32_SGXLEPUBKEYHASH3:
+		if (!msr_info->host_initiated &&
+		    !guest_cpuid_has(vcpu, X86_FEATURE_SGX_LC))
+			return 1;
+		msr_info->data = to_vmx(vcpu)->msr_ia32_sgxlepubkeyhash
+			[msr_info->index - MSR_IA32_SGXLEPUBKEYHASH0];
+		break;
 	case MSR_IA32_VMX_BASIC ... MSR_IA32_VMX_VMFUNC:
 		if (!nested_vmx_allowed(vcpu))
 			return 1;
@@ -2158,6 +2222,29 @@ static int vmx_set_msr(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
 		vmx->msr_ia32_feature_control = data;
 		if (msr_info->host_initiated && data == 0)
 			vmx_leave_nested(vcpu);
+
+		/* SGX may be enabled/disabled by guest's firmware */
+		vmx_write_encls_bitmap(vcpu, NULL);
+		break;
+	case MSR_IA32_SGXLEPUBKEYHASH0 ... MSR_IA32_SGXLEPUBKEYHASH3:
+		/*
+		 * On real hardware, the LE hash MSRs are writable before
+		 * the firmware sets bit 0 in MSR 0x7a ("activating" SGX),
+		 * at which point SGX related bits in IA32_FEATURE_CONTROL
+		 * become writable.
+		 *
+		 * KVM does not emulate SGX activation for simplicity, so
+		 * allow writes to the LE hash MSRs if IA32_FEATURE_CONTROL
+		 * is unlocked.  This is technically not architectural
+		 * behavior, but it's close enough.
+		 */
+		if (!msr_info->host_initiated &&
+		    (!guest_cpuid_has(vcpu, X86_FEATURE_SGX_LC) ||
+		    ((vmx->msr_ia32_feature_control & FEAT_CTL_LOCKED) &&
+		    !(vmx->msr_ia32_feature_control & FEAT_CTL_SGX_LC_ENABLED))))
+			return 1;
+		vmx->msr_ia32_sgxlepubkeyhash
+			[msr_index - MSR_IA32_SGXLEPUBKEYHASH0] = data;
 		break;
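/*
 * Editor's note (illustrative helper, not part of the patch): the gate
 * above boils down to the predicate below.  Guest writes to the
 * SGXLEPUBKEYHASH MSRs are accepted only while the hashes would still be
 * writable on real hardware, i.e. SGX_LC is enumerated and
 * IA32_FEATURE_CONTROL is either still unlocked or was locked with the
 * SGX_LC enable bit set.
 */
static bool sgx_lepubkeyhash_writable(struct kvm_vcpu *vcpu, bool host_initiated)
{
	struct vcpu_vmx *vmx = to_vmx(vcpu);

	if (host_initiated)
		return true;

	if (!guest_cpuid_has(vcpu, X86_FEATURE_SGX_LC))
		return false;

	return !(vmx->msr_ia32_feature_control & FEAT_CTL_LOCKED) ||
	       (vmx->msr_ia32_feature_control & FEAT_CTL_SGX_LC_ENABLED);
}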
 	case MSR_IA32_VMX_BASIC ... MSR_IA32_VMX_VMFUNC:
 		if (!msr_info->host_initiated)
@@ -3088,8 +3175,7 @@ static int vmx_get_max_tdp_level(void)
 	return 4;
 }
 
-u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa,
-		   int root_level)
+u64 construct_eptp(struct kvm_vcpu *vcpu, hpa_t root_hpa, int root_level)
 {
 	u64 eptp = VMX_EPTP_MT_WB;
 
@@ -3098,13 +3184,13 @@ u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa,
 	if (enable_ept_ad_bits &&
 	    (!is_guest_mode(vcpu) || nested_ept_ad_enabled(vcpu)))
 		eptp |= VMX_EPTP_AD_ENABLE_BIT;
-	eptp |= (root_hpa & PAGE_MASK);
+	eptp |= root_hpa;
 
 	return eptp;
 }
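/*
 * Editor's sketch (SDM-defined layout, standalone for illustration): the
 * EPTP built by construct_eptp() packs the memory type in bits 2:0
 * (6 = write-back), the page-walk length minus one in bits 5:3, the
 * accessed/dirty enable in bit 6, and the page-aligned root HPA above
 * bit 11 -- which is why masking root_hpa with PAGE_MASK is no longer
 * needed.  The EPTP_* names below are invented for the example; the
 * kernel uses its VMX_EPTP_* defines.
 */
#include <stdint.h>

#define EPTP_MT_WB		0x6ULL				/* bits 2:0 */
#define EPTP_PWL(levels)	((uint64_t)((levels) - 1) << 3)	/* bits 5:3 */
#define EPTP_AD_ENABLE		(1ULL << 6)			/* bit 6    */

static uint64_t make_eptp(uint64_t root_hpa, int root_level, int enable_ad)
{
	uint64_t eptp = EPTP_MT_WB | EPTP_PWL(root_level);

	if (enable_ad)
		eptp |= EPTP_AD_ENABLE;

	/* root_hpa is page aligned, so bits 11:0 are already clear. */
	return eptp | root_hpa;
}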
 
-static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long pgd,
-			     int pgd_level)
+static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa,
+			     int root_level)
 {
 	struct kvm *kvm = vcpu->kvm;
 	bool update_guest_cr3 = true;
@@ -3112,16 +3198,10 @@ static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long pgd,
 	u64 eptp;
 
 	if (enable_ept) {
-		eptp = construct_eptp(vcpu, pgd, pgd_level);
+		eptp = construct_eptp(vcpu, root_hpa, root_level);
 		vmcs_write64(EPT_POINTER, eptp);
 
-		if (kvm_x86_ops.tlb_remote_flush) {
-			spin_lock(&to_kvm_vmx(kvm)->ept_pointer_lock);
-			to_vmx(vcpu)->ept_pointer = eptp;
-			to_kvm_vmx(kvm)->ept_pointers_match
-				= EPT_POINTERS_CHECK;
-			spin_unlock(&to_kvm_vmx(kvm)->ept_pointer_lock);
-		}
+		hv_track_root_ept(vcpu, root_hpa);
 
 		if (!enable_unrestricted_guest && !is_paging(vcpu))
 			guest_cr3 = to_kvm_vmx(kvm)->ept_identity_map_addr;
@@ -3131,7 +3211,7 @@ static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long pgd,
 			update_guest_cr3 = false;
 		vmx_ept_load_pdptrs(vcpu);
 	} else {
-		guest_cr3 = pgd;
+		guest_cr3 = root_hpa | kvm_get_active_pcid(vcpu);
 	}
 
 	if (update_guest_cr3)
@@ -3738,8 +3818,7 @@ static void vmx_set_msr_bitmap_write(ulong *msr_bitmap, u32 msr)
 		__set_bit(msr & 0x1fff, msr_bitmap + 0xc00 / f);
 }
 
-static __always_inline void vmx_disable_intercept_for_msr(struct kvm_vcpu *vcpu,
-							  u32 msr, int type)
+void vmx_disable_intercept_for_msr(struct kvm_vcpu *vcpu, u32 msr, int type)
 {
 	struct vcpu_vmx *vmx = to_vmx(vcpu);
 	unsigned long *msr_bitmap = vmx->vmcs01.msr_bitmap;
@@ -3784,8 +3863,7 @@ static __always_inline void vmx_disable_intercept_for_msr(struct kvm_vcpu *vcpu,
 		vmx_clear_msr_bitmap_write(msr_bitmap, msr);
 }
 
-static __always_inline void vmx_enable_intercept_for_msr(struct kvm_vcpu *vcpu,
-							 u32 msr, int type)
+void vmx_enable_intercept_for_msr(struct kvm_vcpu *vcpu, u32 msr, int type)
 {
 	struct vcpu_vmx *vmx = to_vmx(vcpu);
 	unsigned long *msr_bitmap = vmx->vmcs01.msr_bitmap;
@@ -3818,15 +3896,6 @@ static __always_inline void vmx_enable_intercept_for_msr(struct kvm_vcpu *vcpu,
 		vmx_set_msr_bitmap_write(msr_bitmap, msr);
 }
 
-void vmx_set_intercept_for_msr(struct kvm_vcpu *vcpu,
-						      u32 msr, int type, bool value)
-{
-	if (value)
-		vmx_enable_intercept_for_msr(vcpu, msr, type);
-	else
-		vmx_disable_intercept_for_msr(vcpu, msr, type);
-}
-
 static u8 vmx_msr_bitmap_mode(struct kvm_vcpu *vcpu)
 {
 	u8 mode = 0;
@@ -4314,15 +4383,6 @@ static void vmx_compute_secondary_exec_control(struct vcpu_vmx *vmx)
 	vmx->secondary_exec_control = exec_control;
 }
 
-static void ept_set_mmio_spte_mask(void)
-{
-	/*
-	 * EPT Misconfigurations can be generated if the value of bits 2:0
-	 * of an EPT paging-structure entry is 110b (write/execute).
-	 */
-	kvm_mmu_set_mmio_spte_mask(VMX_EPT_MISCONFIG_WX_VALUE, 0);
-}
-
 #define VMX_XSS_EXIT_BITMAP 0
 
 /*
@@ -4410,8 +4470,7 @@ static void init_vmcs(struct vcpu_vmx *vmx)
 		vmcs_write16(GUEST_PML_INDEX, PML_ENTITY_NUM - 1);
 	}
 
-	if (cpu_has_vmx_encls_vmexit())
-		vmcs_write64(ENCLS_EXITING_BITMAP, -1ull);
+	vmx_write_encls_bitmap(&vmx->vcpu, NULL);
 
 	if (vmx_pt_mode_is_host_guest()) {
 		memset(&vmx->pt_desc, 0, sizeof(vmx->pt_desc));
@@ -5020,7 +5079,7 @@ static int handle_cr(struct kvm_vcpu *vcpu)
 	reg = (exit_qualification >> 8) & 15;
 	switch ((exit_qualification >> 4) & 3) {
 	case 0: /* mov to cr */
-		val = kvm_register_readl(vcpu, reg);
+		val = kvm_register_read(vcpu, reg);
 		trace_kvm_cr_write(cr, val);
 		switch (cr) {
 		case 0:
@@ -5143,7 +5202,7 @@ static int handle_dr(struct kvm_vcpu *vcpu)
 		kvm_register_write(vcpu, reg, val);
 		err = 0;
 	} else {
-		err = kvm_set_dr(vcpu, dr, kvm_register_readl(vcpu, reg));
+		err = kvm_set_dr(vcpu, dr, kvm_register_read(vcpu, reg));
 	}
 
 out:
@@ -5184,17 +5243,6 @@ static int handle_interrupt_window(struct kvm_vcpu *vcpu)
 	return 1;
 }
 
-static int handle_vmcall(struct kvm_vcpu *vcpu)
-{
-	return kvm_emulate_hypercall(vcpu);
-}
-
-static int handle_invd(struct kvm_vcpu *vcpu)
-{
-	/* Treat an INVD instruction as a NOP and just skip it. */
-	return kvm_skip_emulated_instruction(vcpu);
-}
-
 static int handle_invlpg(struct kvm_vcpu *vcpu)
 {
 	unsigned long exit_qualification = vmx_get_exit_qual(vcpu);
@@ -5203,28 +5251,6 @@ static int handle_invlpg(struct kvm_vcpu *vcpu)
 	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int handle_rdpmc(struct kvm_vcpu *vcpu)
-{
-	int err;
-
-	err = kvm_rdpmc(vcpu);
-	return kvm_complete_insn_gp(vcpu, err);
-}
-
-static int handle_wbinvd(struct kvm_vcpu *vcpu)
-{
-	return kvm_emulate_wbinvd(vcpu);
-}
-
-static int handle_xsetbv(struct kvm_vcpu *vcpu)
-{
-	u64 new_bv = kvm_read_edx_eax(vcpu);
-	u32 index = kvm_rcx_read(vcpu);
-
-	int err = kvm_set_xcr(vcpu, index, new_bv);
-	return kvm_complete_insn_gp(vcpu, err);
-}
-
 static int handle_apic_access(struct kvm_vcpu *vcpu)
 {
 	if (likely(fasteoi)) {
@@ -5361,7 +5387,7 @@ static int handle_ept_violation(struct kvm_vcpu *vcpu)
 			EPT_VIOLATION_EXECUTABLE))
 		      ? PFERR_PRESENT_MASK : 0;
 
-	error_code |= (exit_qualification & 0x100) != 0 ?
+	error_code |= (exit_qualification & EPT_VIOLATION_GVA_TRANSLATED) != 0 ?
 	       PFERR_GUEST_FINAL_MASK : PFERR_GUEST_PAGE_MASK;
 
 	vcpu->arch.exit_qualification = exit_qualification;
@@ -5384,6 +5410,9 @@ static int handle_ept_misconfig(struct kvm_vcpu *vcpu)
 {
 	gpa_t gpa;
 
+	if (!vmx_can_emulate_instruction(vcpu, NULL, 0))
+		return 1;
+
 	/*
 	 * A nested guest cannot optimize MMIO vmexits, because we have an
 	 * nGPA here instead of the required GPA.
@@ -5485,18 +5514,6 @@ static void shrink_ple_window(struct kvm_vcpu *vcpu)
 	}
 }
 
-static void vmx_enable_tdp(void)
-{
-	kvm_mmu_set_mask_ptes(VMX_EPT_READABLE_MASK,
-		enable_ept_ad_bits ? VMX_EPT_ACCESS_BIT : 0ull,
-		enable_ept_ad_bits ? VMX_EPT_DIRTY_BIT : 0ull,
-		0ull, VMX_EPT_EXECUTABLE_MASK,
-		cpu_has_vmx_ept_execute_only() ? 0ull : VMX_EPT_READABLE_MASK,
-		VMX_EPT_RWX_MASK, 0ull);
-
-	ept_set_mmio_spte_mask();
-}
-
 /*
  * Indicate a busy-waiting vcpu in spinlock. We do not enable the PAUSE
  * exiting, so only get here on cpu with PAUSE-Loop-Exiting.
@@ -5516,34 +5533,11 @@ static int handle_pause(struct kvm_vcpu *vcpu)
 	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int handle_nop(struct kvm_vcpu *vcpu)
-{
-	return kvm_skip_emulated_instruction(vcpu);
-}
-
-static int handle_mwait(struct kvm_vcpu *vcpu)
-{
-	printk_once(KERN_WARNING "kvm: MWAIT instruction emulated as NOP!\n");
-	return handle_nop(vcpu);
-}
-
-static int handle_invalid_op(struct kvm_vcpu *vcpu)
-{
-	kvm_queue_exception(vcpu, UD_VECTOR);
-	return 1;
-}
-
 static int handle_monitor_trap(struct kvm_vcpu *vcpu)
 {
 	return 1;
 }
 
-static int handle_monitor(struct kvm_vcpu *vcpu)
-{
-	printk_once(KERN_WARNING "kvm: MONITOR instruction emulated as NOP!\n");
-	return handle_nop(vcpu);
-}
-
 static int handle_invpcid(struct kvm_vcpu *vcpu)
 {
 	u32 vmx_instruction_info;
@@ -5560,7 +5554,7 @@ static int handle_invpcid(struct kvm_vcpu *vcpu)
 	}
 
 	vmx_instruction_info = vmcs_read32(VMX_INSTRUCTION_INFO);
-	type = kvm_register_readl(vcpu, (vmx_instruction_info >> 28) & 0xf);
+	type = kvm_register_read(vcpu, (vmx_instruction_info >> 28) & 0xf);
 
 	if (type > 3) {
 		kvm_inject_gp(vcpu, 0);
@@ -5632,16 +5626,18 @@ static int handle_vmx_instruction(struct kvm_vcpu *vcpu)
 	return 1;
 }
 
+#ifndef CONFIG_X86_SGX_KVM
 static int handle_encls(struct kvm_vcpu *vcpu)
 {
 	/*
-	 * SGX virtualization is not yet supported.  There is no software
-	 * enable bit for SGX, so we have to trap ENCLS and inject a #UD
-	 * to prevent the guest from executing ENCLS.
+	 * SGX virtualization is disabled.  There is no software enable bit for
+	 * SGX, so KVM intercepts all ENCLS leafs and injects a #UD to prevent
+	 * the guest from executing ENCLS (when SGX is supported by hardware).
 	 */
 	kvm_queue_exception(vcpu, UD_VECTOR);
 	return 1;
 }
+#endif /* CONFIG_X86_SGX_KVM */
 
 static int handle_bus_lock_vmexit(struct kvm_vcpu *vcpu)
 {
@@ -5668,10 +5664,10 @@ static int (*kvm_vmx_exit_handlers[])(struct kvm_vcpu *vcpu) = {
 	[EXIT_REASON_MSR_WRITE]               = kvm_emulate_wrmsr,
 	[EXIT_REASON_INTERRUPT_WINDOW]        = handle_interrupt_window,
 	[EXIT_REASON_HLT]                     = kvm_emulate_halt,
-	[EXIT_REASON_INVD]		      = handle_invd,
+	[EXIT_REASON_INVD]		      = kvm_emulate_invd,
 	[EXIT_REASON_INVLPG]		      = handle_invlpg,
-	[EXIT_REASON_RDPMC]                   = handle_rdpmc,
-	[EXIT_REASON_VMCALL]                  = handle_vmcall,
+	[EXIT_REASON_RDPMC]                   = kvm_emulate_rdpmc,
+	[EXIT_REASON_VMCALL]                  = kvm_emulate_hypercall,
 	[EXIT_REASON_VMCLEAR]		      = handle_vmx_instruction,
 	[EXIT_REASON_VMLAUNCH]		      = handle_vmx_instruction,
 	[EXIT_REASON_VMPTRLD]		      = handle_vmx_instruction,
@@ -5685,8 +5681,8 @@ static int (*kvm_vmx_exit_handlers[])(struct kvm_vcpu *vcpu) = {
 	[EXIT_REASON_APIC_ACCESS]             = handle_apic_access,
 	[EXIT_REASON_APIC_WRITE]              = handle_apic_write,
 	[EXIT_REASON_EOI_INDUCED]             = handle_apic_eoi_induced,
-	[EXIT_REASON_WBINVD]                  = handle_wbinvd,
-	[EXIT_REASON_XSETBV]                  = handle_xsetbv,
+	[EXIT_REASON_WBINVD]                  = kvm_emulate_wbinvd,
+	[EXIT_REASON_XSETBV]                  = kvm_emulate_xsetbv,
 	[EXIT_REASON_TASK_SWITCH]             = handle_task_switch,
 	[EXIT_REASON_MCE_DURING_VMENTRY]      = handle_machine_check,
 	[EXIT_REASON_GDTR_IDTR]		      = handle_desc,
@@ -5694,13 +5690,13 @@ static int (*kvm_vmx_exit_handlers[])(struct kvm_vcpu *vcpu) = {
 	[EXIT_REASON_EPT_VIOLATION]	      = handle_ept_violation,
 	[EXIT_REASON_EPT_MISCONFIG]           = handle_ept_misconfig,
 	[EXIT_REASON_PAUSE_INSTRUCTION]       = handle_pause,
-	[EXIT_REASON_MWAIT_INSTRUCTION]	      = handle_mwait,
+	[EXIT_REASON_MWAIT_INSTRUCTION]	      = kvm_emulate_mwait,
 	[EXIT_REASON_MONITOR_TRAP_FLAG]       = handle_monitor_trap,
-	[EXIT_REASON_MONITOR_INSTRUCTION]     = handle_monitor,
+	[EXIT_REASON_MONITOR_INSTRUCTION]     = kvm_emulate_monitor,
 	[EXIT_REASON_INVEPT]                  = handle_vmx_instruction,
 	[EXIT_REASON_INVVPID]                 = handle_vmx_instruction,
-	[EXIT_REASON_RDRAND]                  = handle_invalid_op,
-	[EXIT_REASON_RDSEED]                  = handle_invalid_op,
+	[EXIT_REASON_RDRAND]                  = kvm_handle_invalid_op,
+	[EXIT_REASON_RDSEED]                  = kvm_handle_invalid_op,
 	[EXIT_REASON_PML_FULL]		      = handle_pml_full,
 	[EXIT_REASON_INVPCID]                 = handle_invpcid,
 	[EXIT_REASON_VMFUNC]		      = handle_vmx_instruction,
@@ -5787,12 +5783,23 @@ static void vmx_dump_dtsel(char *name, uint32_t limit)
 	       vmcs_readl(limit + GUEST_GDTR_BASE - GUEST_GDTR_LIMIT));
 }
 
-void dump_vmcs(void)
+static void vmx_dump_msrs(char *name, struct vmx_msrs *m)
+{
+	unsigned int i;
+	struct vmx_msr_entry *e;
+
+	pr_err("MSR %s:\n", name);
+	for (i = 0, e = m->val; i < m->nr; ++i, ++e)
+		pr_err("  %2d: msr=0x%08x value=0x%016llx\n", i, e->index, e->value);
+}
+
+void dump_vmcs(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_vmx *vmx = to_vmx(vcpu);
 	u32 vmentry_ctl, vmexit_ctl;
 	u32 cpu_based_exec_ctrl, pin_based_exec_ctrl, secondary_exec_control;
 	unsigned long cr4;
-	u64 efer;
+	int efer_slot;
 
 	if (!dump_invalid_vmcs) {
 		pr_warn_ratelimited("set kvm_intel.dump_invalid_vmcs=1 to dump internal KVM state.\n");
@@ -5804,7 +5811,6 @@ void dump_vmcs(void)
 	cpu_based_exec_ctrl = vmcs_read32(CPU_BASED_VM_EXEC_CONTROL);
 	pin_based_exec_ctrl = vmcs_read32(PIN_BASED_VM_EXEC_CONTROL);
 	cr4 = vmcs_readl(GUEST_CR4);
-	efer = vmcs_read64(GUEST_IA32_EFER);
 	secondary_exec_control = 0;
 	if (cpu_has_secondary_exec_ctrls())
 		secondary_exec_control = vmcs_read32(SECONDARY_VM_EXEC_CONTROL);
@@ -5816,9 +5822,7 @@ void dump_vmcs(void)
 	pr_err("CR4: actual=0x%016lx, shadow=0x%016lx, gh_mask=%016lx\n",
 	       cr4, vmcs_readl(CR4_READ_SHADOW), vmcs_readl(CR4_GUEST_HOST_MASK));
 	pr_err("CR3 = 0x%016lx\n", vmcs_readl(GUEST_CR3));
-	if ((secondary_exec_control & SECONDARY_EXEC_ENABLE_EPT) &&
-	    (cr4 & X86_CR4_PAE) && !(efer & EFER_LMA))
-	{
+	if (cpu_has_vmx_ept()) {
 		pr_err("PDPTR0 = 0x%016llx  PDPTR1 = 0x%016llx\n",
 		       vmcs_read64(GUEST_PDPTR0), vmcs_read64(GUEST_PDPTR1));
 		pr_err("PDPTR2 = 0x%016llx  PDPTR3 = 0x%016llx\n",
@@ -5841,10 +5845,20 @@ void dump_vmcs(void)
 	vmx_dump_sel("LDTR:", GUEST_LDTR_SELECTOR);
 	vmx_dump_dtsel("IDTR:", GUEST_IDTR_LIMIT);
 	vmx_dump_sel("TR:  ", GUEST_TR_SELECTOR);
-	if ((vmexit_ctl & (VM_EXIT_SAVE_IA32_PAT | VM_EXIT_SAVE_IA32_EFER)) ||
-	    (vmentry_ctl & (VM_ENTRY_LOAD_IA32_PAT | VM_ENTRY_LOAD_IA32_EFER)))
-		pr_err("EFER =     0x%016llx  PAT = 0x%016llx\n",
-		       efer, vmcs_read64(GUEST_IA32_PAT));
+	efer_slot = vmx_find_loadstore_msr_slot(&vmx->msr_autoload.guest, MSR_EFER);
+	if (vmentry_ctl & VM_ENTRY_LOAD_IA32_EFER)
+		pr_err("EFER= 0x%016llx\n", vmcs_read64(GUEST_IA32_EFER));
+	else if (efer_slot >= 0)
+		pr_err("EFER= 0x%016llx (autoload)\n",
+		       vmx->msr_autoload.guest.val[efer_slot].value);
+	else if (vmentry_ctl & VM_ENTRY_IA32E_MODE)
+		pr_err("EFER= 0x%016llx (effective)\n",
+		       vcpu->arch.efer | (EFER_LMA | EFER_LME));
+	else
+		pr_err("EFER= 0x%016llx (effective)\n",
+		       vcpu->arch.efer & ~(EFER_LMA | EFER_LME));
+	if (vmentry_ctl & VM_ENTRY_LOAD_IA32_PAT)
+		pr_err("PAT = 0x%016llx\n", vmcs_read64(GUEST_IA32_PAT));
 	pr_err("DebugCtl = 0x%016llx  DebugExceptions = 0x%016lx\n",
 	       vmcs_read64(GUEST_IA32_DEBUGCTL),
 	       vmcs_readl(GUEST_PENDING_DBG_EXCEPTIONS));
@@ -5860,6 +5874,10 @@ void dump_vmcs(void)
 	if (secondary_exec_control & SECONDARY_EXEC_VIRTUAL_INTR_DELIVERY)
 		pr_err("InterruptStatus = %04x\n",
 		       vmcs_read16(GUEST_INTR_STATUS));
+	if (vmcs_read32(VM_ENTRY_MSR_LOAD_COUNT) > 0)
+		vmx_dump_msrs("guest autoload", &vmx->msr_autoload.guest);
+	if (vmcs_read32(VM_EXIT_MSR_STORE_COUNT) > 0)
+		vmx_dump_msrs("guest autostore", &vmx->msr_autostore.guest);
 
 	pr_err("*** Host State ***\n");
 	pr_err("RIP = 0x%016lx  RSP = 0x%016lx\n",
@@ -5881,14 +5899,16 @@ void dump_vmcs(void)
 	       vmcs_readl(HOST_IA32_SYSENTER_ESP),
 	       vmcs_read32(HOST_IA32_SYSENTER_CS),
 	       vmcs_readl(HOST_IA32_SYSENTER_EIP));
-	if (vmexit_ctl & (VM_EXIT_LOAD_IA32_PAT | VM_EXIT_LOAD_IA32_EFER))
-		pr_err("EFER = 0x%016llx  PAT = 0x%016llx\n",
-		       vmcs_read64(HOST_IA32_EFER),
-		       vmcs_read64(HOST_IA32_PAT));
+	if (vmexit_ctl & VM_EXIT_LOAD_IA32_EFER)
+		pr_err("EFER= 0x%016llx\n", vmcs_read64(HOST_IA32_EFER));
+	if (vmexit_ctl & VM_EXIT_LOAD_IA32_PAT)
+		pr_err("PAT = 0x%016llx\n", vmcs_read64(HOST_IA32_PAT));
 	if (cpu_has_load_perf_global_ctrl() &&
 	    vmexit_ctl & VM_EXIT_LOAD_IA32_PERF_GLOBAL_CTRL)
 		pr_err("PerfGlobCtl = 0x%016llx\n",
 		       vmcs_read64(HOST_IA32_PERF_GLOBAL_CTRL));
+	if (vmcs_read32(VM_EXIT_MSR_LOAD_COUNT) > 0)
+		vmx_dump_msrs("host autoload", &vmx->msr_autoload.host);
 
 	pr_err("*** Control State ***\n");
 	pr_err("PinBased=%08x CPUBased=%08x SecondaryExec=%08x\n",
@@ -5997,7 +6017,7 @@ static int __vmx_handle_exit(struct kvm_vcpu *vcpu, fastpath_t exit_fastpath)
 	}
 
 	if (exit_reason.failed_vmentry) {
-		dump_vmcs();
+		dump_vmcs(vcpu);
 		vcpu->run->exit_reason = KVM_EXIT_FAIL_ENTRY;
 		vcpu->run->fail_entry.hardware_entry_failure_reason
 			= exit_reason.full;
@@ -6006,7 +6026,7 @@ static int __vmx_handle_exit(struct kvm_vcpu *vcpu, fastpath_t exit_fastpath)
 	}
 
 	if (unlikely(vmx->fail)) {
-		dump_vmcs();
+		dump_vmcs(vcpu);
 		vcpu->run->exit_reason = KVM_EXIT_FAIL_ENTRY;
 		vcpu->run->fail_entry.hardware_entry_failure_reason
 			= vmcs_read32(VM_INSTRUCTION_ERROR);
@@ -6092,7 +6112,7 @@ static int __vmx_handle_exit(struct kvm_vcpu *vcpu, fastpath_t exit_fastpath)
 unexpected_vmexit:
 	vcpu_unimpl(vcpu, "vmx: unexpected exit reason 0x%x\n",
 		    exit_reason.full);
-	dump_vmcs();
+	dump_vmcs(vcpu);
 	vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
 	vcpu->run->internal.suberror =
 			KVM_INTERNAL_ERROR_UNEXPECTED_EXIT_REASON;
@@ -6938,9 +6958,11 @@ static int vmx_create_vcpu(struct kvm_vcpu *vcpu)
 	bitmap_fill(vmx->shadow_msr_intercept.write, MAX_POSSIBLE_PASSTHROUGH_MSRS);
 
 	vmx_disable_intercept_for_msr(vcpu, MSR_IA32_TSC, MSR_TYPE_R);
+#ifdef CONFIG_X86_64
 	vmx_disable_intercept_for_msr(vcpu, MSR_FS_BASE, MSR_TYPE_RW);
 	vmx_disable_intercept_for_msr(vcpu, MSR_GS_BASE, MSR_TYPE_RW);
 	vmx_disable_intercept_for_msr(vcpu, MSR_KERNEL_GS_BASE, MSR_TYPE_RW);
+#endif
 	vmx_disable_intercept_for_msr(vcpu, MSR_IA32_SYSENTER_CS, MSR_TYPE_RW);
 	vmx_disable_intercept_for_msr(vcpu, MSR_IA32_SYSENTER_ESP, MSR_TYPE_RW);
 	vmx_disable_intercept_for_msr(vcpu, MSR_IA32_SYSENTER_EIP, MSR_TYPE_RW);
@@ -6976,6 +6998,8 @@ static int vmx_create_vcpu(struct kvm_vcpu *vcpu)
 	else
 		memset(&vmx->nested.msrs, 0, sizeof(vmx->nested.msrs));
 
+	vcpu_setup_sgx_lepubkeyhash(vcpu);
+
 	vmx->nested.posted_intr_nv = -1;
 	vmx->nested.current_vmptr = -1ull;
 
@@ -6989,8 +7013,9 @@ static int vmx_create_vcpu(struct kvm_vcpu *vcpu)
 	vmx->pi_desc.nv = POSTED_INTR_VECTOR;
 	vmx->pi_desc.sn = 1;
 
-	vmx->ept_pointer = INVALID_PAGE;
-
+#if IS_ENABLED(CONFIG_HYPERV)
+	vmx->hv_root_ept = INVALID_PAGE;
+#endif
 	return 0;
 
 free_vmcs:
@@ -7007,7 +7032,9 @@ free_vpid:
 
 static int vmx_vm_init(struct kvm *kvm)
 {
-	spin_lock_init(&to_kvm_vmx(kvm)->ept_pointer_lock);
+#if IS_ENABLED(CONFIG_HYPERV)
+	spin_lock_init(&to_kvm_vmx(kvm)->hv_root_ept_lock);
+#endif
 
 	if (!ple_gap)
 		kvm->arch.pause_in_guest = true;
@@ -7302,6 +7329,19 @@ static void vmx_vcpu_after_set_cpuid(struct kvm_vcpu *vcpu)
 
 	set_cr4_guest_host_mask(vmx);
 
+	vmx_write_encls_bitmap(vcpu, NULL);
+	if (guest_cpuid_has(vcpu, X86_FEATURE_SGX))
+		vmx->msr_ia32_feature_control_valid_bits |= FEAT_CTL_SGX_ENABLED;
+	else
+		vmx->msr_ia32_feature_control_valid_bits &= ~FEAT_CTL_SGX_ENABLED;
+
+	if (guest_cpuid_has(vcpu, X86_FEATURE_SGX_LC))
+		vmx->msr_ia32_feature_control_valid_bits |=
+			FEAT_CTL_SGX_LC_ENABLED;
+	else
+		vmx->msr_ia32_feature_control_valid_bits &=
+			~FEAT_CTL_SGX_LC_ENABLED;
+
 	/* Refresh #PF interception to account for MAXPHYADDR changes. */
 	vmx_update_exception_bitmap(vcpu);
 }
@@ -7322,6 +7362,13 @@ static __init void vmx_set_cpu_caps(void)
 	if (vmx_pt_mode_is_host_guest())
 		kvm_cpu_cap_check_and_set(X86_FEATURE_INTEL_PT);
 
+	if (!enable_sgx) {
+		kvm_cpu_cap_clear(X86_FEATURE_SGX);
+		kvm_cpu_cap_clear(X86_FEATURE_SGX_LC);
+		kvm_cpu_cap_clear(X86_FEATURE_SGX1);
+		kvm_cpu_cap_clear(X86_FEATURE_SGX2);
+	}
+
 	if (vmx_umip_emulated())
 		kvm_cpu_cap_set(X86_FEATURE_UMIP);
 
@@ -7848,7 +7895,8 @@ static __init int hardware_setup(void)
 	set_bit(0, vmx_vpid_bitmap); /* 0 is reserved for host */
 
 	if (enable_ept)
-		vmx_enable_tdp();
+		kvm_mmu_set_ept_masks(enable_ept_ad_bits,
+				      cpu_has_vmx_ept_execute_only());
 
 	if (!enable_ept)
 		ept_lpage_level = 0;
@@ -7909,6 +7957,8 @@ static __init int hardware_setup(void)
 	if (!enable_ept || !cpu_has_vmx_intel_pt())
 		pt_mode = PT_MODE_SYSTEM;
 
+	setup_default_sgx_lepubkeyhash();
+
 	if (nested) {
 		nested_vmx_setup_ctls_msrs(&vmcs_config.nested,
 					   vmx_capability.ept);
diff --git a/arch/x86/kvm/vmx/vmx.h b/arch/x86/kvm/vmx/vmx.h
index 89da5e1251f1..008cb87ff088 100644
--- a/arch/x86/kvm/vmx/vmx.h
+++ b/arch/x86/kvm/vmx/vmx.h
@@ -325,7 +325,12 @@ struct vcpu_vmx {
 	 */
 	u64 msr_ia32_feature_control;
 	u64 msr_ia32_feature_control_valid_bits;
-	u64 ept_pointer;
+	/* SGX Launch Control public key hash */
+	u64 msr_ia32_sgxlepubkeyhash[4];
+
+#if IS_ENABLED(CONFIG_HYPERV)
+	u64 hv_root_ept;
+#endif
 
 	struct pt_desc pt_desc;
 	struct lbr_desc lbr_desc;
@@ -338,12 +343,6 @@ struct vcpu_vmx {
 	} shadow_msr_intercept;
 };
 
-enum ept_pointers_status {
-	EPT_POINTERS_CHECK = 0,
-	EPT_POINTERS_MATCH = 1,
-	EPT_POINTERS_MISMATCH = 2
-};
-
 struct kvm_vmx {
 	struct kvm kvm;
 
@@ -351,8 +350,10 @@ struct kvm_vmx {
 	bool ept_identity_pagetable_done;
 	gpa_t ept_identity_map_addr;
 
-	enum ept_pointers_status ept_pointers_match;
-	spinlock_t ept_pointer_lock;
+#if IS_ENABLED(CONFIG_HYPERV)
+	hpa_t hv_root_ept;
+	spinlock_t hv_root_ept_lock;
+#endif
 };
 
 bool nested_vmx_allowed(struct kvm_vcpu *vcpu);
@@ -376,8 +377,7 @@ void set_cr4_guest_host_mask(struct vcpu_vmx *vmx);
 void ept_save_pdptrs(struct kvm_vcpu *vcpu);
 void vmx_get_segment(struct kvm_vcpu *vcpu, struct kvm_segment *var, int seg);
 void vmx_set_segment(struct kvm_vcpu *vcpu, struct kvm_segment *var, int seg);
-u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa,
-		   int root_level);
+u64 construct_eptp(struct kvm_vcpu *vcpu, hpa_t root_hpa, int root_level);
 
 void vmx_update_exception_bitmap(struct kvm_vcpu *vcpu);
 void vmx_update_msr_bitmap(struct kvm_vcpu *vcpu);
@@ -392,8 +392,19 @@ void vmx_update_host_rsp(struct vcpu_vmx *vmx, unsigned long host_rsp);
 bool __vmx_vcpu_run(struct vcpu_vmx *vmx, unsigned long *regs, bool launched);
 int vmx_find_loadstore_msr_slot(struct vmx_msrs *m, u32 msr);
 void vmx_ept_load_pdptrs(struct kvm_vcpu *vcpu);
-void vmx_set_intercept_for_msr(struct kvm_vcpu *vcpu,
-	u32 msr, int type, bool value);
+
+void vmx_disable_intercept_for_msr(struct kvm_vcpu *vcpu, u32 msr, int type);
+void vmx_enable_intercept_for_msr(struct kvm_vcpu *vcpu, u32 msr, int type);
+
+static inline void vmx_set_intercept_for_msr(struct kvm_vcpu *vcpu, u32 msr,
+					     int type, bool value)
+{
+	if (value)
+		vmx_enable_intercept_for_msr(vcpu, msr, type);
+	else
+		vmx_disable_intercept_for_msr(vcpu, msr, type);
+}
+
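/*
 * Editor's note (illustrative caller, not a quote of any call site): the
 * helper above lets callers express "intercept unless the guest is
 * allowed direct access" as a single statement, e.g.:
 *
 *	vmx_set_intercept_for_msr(vcpu, MSR_IA32_SPEC_CTRL, MSR_TYPE_RW,
 *				  !guest_owns_spec_ctrl);
 */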
 void vmx_update_cpu_dirty_logging(struct kvm_vcpu *vcpu);
 
 static inline u8 vmx_get_rvi(void)
@@ -543,6 +554,6 @@ static inline bool vmx_guest_state_valid(struct kvm_vcpu *vcpu)
 	return is_unrestricted_guest(vcpu) || __vmx_guest_state_valid(vcpu);
 }
 
-void dump_vmcs(void);
+void dump_vmcs(struct kvm_vcpu *vcpu);
 
 #endif /* __KVM_X86_VMX_H */
diff --git a/arch/x86/kvm/vmx/vmx_ops.h b/arch/x86/kvm/vmx/vmx_ops.h
index 692b0c31c9c8..164b64f65a8f 100644
--- a/arch/x86/kvm/vmx/vmx_ops.h
+++ b/arch/x86/kvm/vmx/vmx_ops.h
@@ -37,6 +37,10 @@ static __always_inline void vmcs_check32(unsigned long field)
 {
 	BUILD_BUG_ON_MSG(__builtin_constant_p(field) && ((field) & 0x6000) == 0,
 			 "32-bit accessor invalid for 16-bit field");
+	BUILD_BUG_ON_MSG(__builtin_constant_p(field) && ((field) & 0x6001) == 0x2000,
+			 "32-bit accessor invalid for 64-bit field");
+	BUILD_BUG_ON_MSG(__builtin_constant_p(field) && ((field) & 0x6001) == 0x2001,
+			 "32-bit accessor invalid for 64-bit high field");
 	BUILD_BUG_ON_MSG(__builtin_constant_p(field) && ((field) & 0x6000) == 0x6000,
 			 "32-bit accessor invalid for natural width field");
 }
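/*
 * Editor's sketch (illustration only): the new BUILD_BUG_ON_MSG checks
 * key off the architectural VMCS field encoding -- bits 14:13 give the
 * field width (0 = 16-bit, 1 = 64-bit, 2 = 32-bit, 3 = natural width)
 * and bit 0 selects the high half of a 64-bit field.  KVM has similar
 * helpers for shadow VMCS handling; the standalone versions below are
 * simplified for the example.
 */
static inline int vmcs_field_width_of(unsigned long field)
{
	return (field >> 13) & 0x3;	/* 0=16-bit, 1=64-bit, 2=32-bit, 3=natural */
}

static inline int vmcs_field_is_high(unsigned long field)
{
	return field & 0x1;		/* high 32 bits of a 64-bit field */
}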
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index efc7a82ab140..cebdaa1e3cf5 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -75,6 +75,7 @@
 #include <asm/tlbflush.h>
 #include <asm/intel_pt.h>
 #include <asm/emulate_prefix.h>
+#include <asm/sgx.h>
 #include <clocksource/hyperv_timer.h>
 
 #define CREATE_TRACE_POINTS
@@ -245,6 +246,9 @@ struct kvm_stats_debugfs_item debugfs_entries[] = {
 	VCPU_STAT("l1d_flush", l1d_flush),
 	VCPU_STAT("halt_poll_success_ns", halt_poll_success_ns),
 	VCPU_STAT("halt_poll_fail_ns", halt_poll_fail_ns),
+	VCPU_STAT("nested_run", nested_run),
+	VCPU_STAT("directed_yield_attempted", directed_yield_attempted),
+	VCPU_STAT("directed_yield_successful", directed_yield_successful),
 	VM_STAT("mmu_shadow_zapped", mmu_shadow_zapped),
 	VM_STAT("mmu_pte_write", mmu_pte_write),
 	VM_STAT("mmu_pde_zapped", mmu_pde_zapped),
@@ -543,8 +547,6 @@ static void kvm_multiple_exception(struct kvm_vcpu *vcpu,
 
 	if (!vcpu->arch.exception.pending && !vcpu->arch.exception.injected) {
 	queue:
-		if (has_error && !is_protmode(vcpu))
-			has_error = false;
 		if (reinject) {
 			/*
 			 * On vmentry, vcpu->arch.exception.pending is only
@@ -983,14 +985,17 @@ static int __kvm_set_xcr(struct kvm_vcpu *vcpu, u32 index, u64 xcr)
 	return 0;
 }
 
-int kvm_set_xcr(struct kvm_vcpu *vcpu, u32 index, u64 xcr)
+int kvm_emulate_xsetbv(struct kvm_vcpu *vcpu)
 {
-	if (static_call(kvm_x86_get_cpl)(vcpu) == 0)
-		return __kvm_set_xcr(vcpu, index, xcr);
+	if (static_call(kvm_x86_get_cpl)(vcpu) != 0 ||
+	    __kvm_set_xcr(vcpu, kvm_rcx_read(vcpu), kvm_read_edx_eax(vcpu))) {
+		kvm_inject_gp(vcpu, 0);
+		return 1;
+	}
 
-	return 1;
+	return kvm_skip_emulated_instruction(vcpu);
 }
-EXPORT_SYMBOL_GPL(kvm_set_xcr);
+EXPORT_SYMBOL_GPL(kvm_emulate_xsetbv);
 
 bool kvm_is_valid_cr4(struct kvm_vcpu *vcpu, unsigned long cr4)
 {
@@ -1072,10 +1077,15 @@ int kvm_set_cr3(struct kvm_vcpu *vcpu, unsigned long cr3)
 		return 0;
 	}
 
-	if (is_long_mode(vcpu) && kvm_vcpu_is_illegal_gpa(vcpu, cr3))
+	/*
+	 * Do not condition the GPA check on long mode, this helper is used to
+	 * stuff CR3, e.g. for RSM emulation, and there is no guarantee that
+	 * the current vCPU mode is accurate.
+	 */
+	if (kvm_vcpu_is_illegal_gpa(vcpu, cr3))
 		return 1;
-	else if (is_pae_paging(vcpu) &&
-		 !load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3))
+
+	if (is_pae_paging(vcpu) && !load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3))
 		return 1;
 
 	kvm_mmu_new_pgd(vcpu, cr3, skip_tlb_flush, skip_tlb_flush);
@@ -1191,20 +1201,21 @@ void kvm_get_dr(struct kvm_vcpu *vcpu, int dr, unsigned long *val)
 }
 EXPORT_SYMBOL_GPL(kvm_get_dr);
 
-bool kvm_rdpmc(struct kvm_vcpu *vcpu)
+int kvm_emulate_rdpmc(struct kvm_vcpu *vcpu)
 {
 	u32 ecx = kvm_rcx_read(vcpu);
 	u64 data;
-	int err;
 
-	err = kvm_pmu_rdpmc(vcpu, ecx, &data);
-	if (err)
-		return err;
+	if (kvm_pmu_rdpmc(vcpu, ecx, &data)) {
+		kvm_inject_gp(vcpu, 0);
+		return 1;
+	}
+
 	kvm_rax_write(vcpu, (u32)data);
 	kvm_rdx_write(vcpu, data >> 32);
-	return err;
+	return kvm_skip_emulated_instruction(vcpu);
 }
-EXPORT_SYMBOL_GPL(kvm_rdpmc);
+EXPORT_SYMBOL_GPL(kvm_emulate_rdpmc);
 
 /*
  * List of msr numbers which we expose to userspace through KVM_GET_MSRS
@@ -1791,6 +1802,40 @@ int kvm_emulate_wrmsr(struct kvm_vcpu *vcpu)
 }
 EXPORT_SYMBOL_GPL(kvm_emulate_wrmsr);
 
+int kvm_emulate_as_nop(struct kvm_vcpu *vcpu)
+{
+	return kvm_skip_emulated_instruction(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_as_nop);
+
+int kvm_emulate_invd(struct kvm_vcpu *vcpu)
+{
+	/* Treat an INVD instruction as a NOP and just skip it. */
+	return kvm_emulate_as_nop(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_invd);
+
+int kvm_emulate_mwait(struct kvm_vcpu *vcpu)
+{
+	pr_warn_once("kvm: MWAIT instruction emulated as NOP!\n");
+	return kvm_emulate_as_nop(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_mwait);
+
+int kvm_handle_invalid_op(struct kvm_vcpu *vcpu)
+{
+	kvm_queue_exception(vcpu, UD_VECTOR);
+	return 1;
+}
+EXPORT_SYMBOL_GPL(kvm_handle_invalid_op);
+
+int kvm_emulate_monitor(struct kvm_vcpu *vcpu)
+{
+	pr_warn_once("kvm: MONITOR instruction emulated as NOP!\n");
+	return kvm_emulate_as_nop(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_monitor);
+
 static inline bool kvm_vcpu_exit_request(struct kvm_vcpu *vcpu)
 {
 	xfer_to_guest_mode_prepare();
@@ -3382,6 +3427,12 @@ int kvm_get_msr_common(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
 		msr_info->data = 0;
 		break;
 	case MSR_F15H_PERF_CTL0 ... MSR_F15H_PERF_CTR5:
+		if (kvm_pmu_is_valid_msr(vcpu, msr_info->index))
+			return kvm_pmu_get_msr(vcpu, msr_info);
+		if (!msr_info->host_initiated)
+			return 1;
+		msr_info->data = 0;
+		break;
 	case MSR_K7_EVNTSEL0 ... MSR_K7_EVNTSEL3:
 	case MSR_K7_PERFCTR0 ... MSR_K7_PERFCTR3:
 	case MSR_P6_PERFCTR0 ... MSR_P6_PERFCTR1:
@@ -3771,8 +3822,14 @@ int kvm_vm_ioctl_check_extension(struct kvm *kvm, long ext)
 	case KVM_CAP_X86_USER_SPACE_MSR:
 	case KVM_CAP_X86_MSR_FILTER:
 	case KVM_CAP_ENFORCE_PV_FEATURE_CPUID:
+#ifdef CONFIG_X86_SGX_KVM
+	case KVM_CAP_SGX_ATTRIBUTE:
+#endif
+	case KVM_CAP_VM_COPY_ENC_CONTEXT_FROM:
 		r = 1;
 		break;
+	case KVM_CAP_SET_GUEST_DEBUG2:
+		return KVM_GUESTDBG_VALID_MASK;
 #ifdef CONFIG_KVM_XEN
 	case KVM_CAP_XEN_HVM:
 		r = KVM_XEN_HVM_CONFIG_HYPERCALL_MSR |
@@ -4673,7 +4730,6 @@ static int kvm_vcpu_ioctl_enable_cap(struct kvm_vcpu *vcpu,
 			kvm_update_pv_runtime(vcpu);
 
 		return 0;
-
 	default:
 		return -EINVAL;
 	}
@@ -5355,6 +5411,28 @@ split_irqchip_unlock:
 			kvm->arch.bus_lock_detection_enabled = true;
 		r = 0;
 		break;
+#ifdef CONFIG_X86_SGX_KVM
+	case KVM_CAP_SGX_ATTRIBUTE: {
+		unsigned long allowed_attributes = 0;
+
+		r = sgx_set_attribute(&allowed_attributes, cap->args[0]);
+		if (r)
+			break;
+
+		/* KVM only supports the PROVISIONKEY privileged attribute. */
+		if ((allowed_attributes & SGX_ATTR_PROVISIONKEY) &&
+		    !(allowed_attributes & ~SGX_ATTR_PROVISIONKEY))
+			kvm->arch.sgx_provisioning_allowed = true;
+		else
+			r = -EINVAL;
+		break;
+	}
+#endif
+	case KVM_CAP_VM_COPY_ENC_CONTEXT_FROM:
+		r = -EINVAL;
+		if (kvm_x86_ops.vm_copy_enc_context_from)
+			r = kvm_x86_ops.vm_copy_enc_context_from(kvm, cap->args[0]);
+		return r;
 	default:
 		r = -EINVAL;
 		break;
@@ -5999,6 +6077,7 @@ gpa_t kvm_mmu_gva_to_gpa_read(struct kvm_vcpu *vcpu, gva_t gva,
 	u32 access = (static_call(kvm_x86_get_cpl)(vcpu) == 3) ? PFERR_USER_MASK : 0;
 	return vcpu->arch.walk_mmu->gva_to_gpa(vcpu, gva, access, exception);
 }
+EXPORT_SYMBOL_GPL(kvm_mmu_gva_to_gpa_read);
 
  gpa_t kvm_mmu_gva_to_gpa_fetch(struct kvm_vcpu *vcpu, gva_t gva,
 				struct x86_exception *exception)
@@ -6015,6 +6094,7 @@ gpa_t kvm_mmu_gva_to_gpa_write(struct kvm_vcpu *vcpu, gva_t gva,
 	access |= PFERR_WRITE_MASK;
 	return vcpu->arch.walk_mmu->gva_to_gpa(vcpu, gva, access, exception);
 }
+EXPORT_SYMBOL_GPL(kvm_mmu_gva_to_gpa_write);
 
 /* uses this to access any guest's mapped memory without checking CPL */
 gpa_t kvm_mmu_gva_to_gpa_system(struct kvm_vcpu *vcpu, gva_t gva,
@@ -6934,12 +7014,12 @@ static bool emulator_guest_has_fxsr(struct x86_emulate_ctxt *ctxt)
 
 static ulong emulator_read_gpr(struct x86_emulate_ctxt *ctxt, unsigned reg)
 {
-	return kvm_register_read(emul_to_vcpu(ctxt), reg);
+	return kvm_register_read_raw(emul_to_vcpu(ctxt), reg);
 }
 
 static void emulator_write_gpr(struct x86_emulate_ctxt *ctxt, unsigned reg, ulong val)
 {
-	kvm_register_write(emul_to_vcpu(ctxt), reg, val);
+	kvm_register_write_raw(emul_to_vcpu(ctxt), reg, val);
 }
 
 static void emulator_set_nmi_mask(struct x86_emulate_ctxt *ctxt, bool masked)
@@ -8043,9 +8123,6 @@ int kvm_arch_init(void *opaque)
 	if (r)
 		goto out_free_percpu;
 
-	kvm_mmu_set_mask_ptes(PT_USER_MASK, PT_ACCESSED_MASK,
-			PT_DIRTY_MASK, PT64_NX_MASK, 0,
-			PT_PRESENT_MASK, 0, sme_me_mask);
 	kvm_timer_init();
 
 	perf_register_guest_info_callbacks(&kvm_guest_cbs);
@@ -8205,21 +8282,35 @@ void kvm_apicv_init(struct kvm *kvm, bool enable)
 }
 EXPORT_SYMBOL_GPL(kvm_apicv_init);
 
-static void kvm_sched_yield(struct kvm *kvm, unsigned long dest_id)
+static void kvm_sched_yield(struct kvm_vcpu *vcpu, unsigned long dest_id)
 {
 	struct kvm_vcpu *target = NULL;
 	struct kvm_apic_map *map;
 
+	vcpu->stat.directed_yield_attempted++;
+
 	rcu_read_lock();
-	map = rcu_dereference(kvm->arch.apic_map);
+	map = rcu_dereference(vcpu->kvm->arch.apic_map);
 
 	if (likely(map) && dest_id <= map->max_apic_id && map->phys_map[dest_id])
 		target = map->phys_map[dest_id]->vcpu;
 
 	rcu_read_unlock();
 
-	if (target && READ_ONCE(target->ready))
-		kvm_vcpu_yield_to(target);
+	if (!target || !READ_ONCE(target->ready))
+		goto no_yield;
+
+	/* Ignore requests to yield to self */
+	if (vcpu == target)
+		goto no_yield;
+
+	if (kvm_vcpu_yield_to(target) <= 0)
+		goto no_yield;
+
+	vcpu->stat.directed_yield_successful++;
+
+no_yield:
+	return;
 }
 
 int kvm_emulate_hypercall(struct kvm_vcpu *vcpu)
@@ -8266,7 +8357,7 @@ int kvm_emulate_hypercall(struct kvm_vcpu *vcpu)
 			break;
 
 		kvm_pv_kick_cpu_op(vcpu->kvm, a0, a1);
-		kvm_sched_yield(vcpu->kvm, a1);
+		kvm_sched_yield(vcpu, a1);
 		ret = 0;
 		break;
 #ifdef CONFIG_X86_64
@@ -8284,7 +8375,7 @@ int kvm_emulate_hypercall(struct kvm_vcpu *vcpu)
 		if (!guest_pv_has(vcpu, KVM_FEATURE_PV_SCHED_YIELD))
 			break;
 
-		kvm_sched_yield(vcpu->kvm, a0);
+		kvm_sched_yield(vcpu, a0);
 		ret = 0;
 		break;
 	default:
@@ -8367,6 +8458,27 @@ static void update_cr8_intercept(struct kvm_vcpu *vcpu)
 	static_call(kvm_x86_update_cr8_intercept)(vcpu, tpr, max_irr);
 }
 
+
+int kvm_check_nested_events(struct kvm_vcpu *vcpu)
+{
+	if (WARN_ON_ONCE(!is_guest_mode(vcpu)))
+		return -EIO;
+
+	if (kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu)) {
+		kvm_x86_ops.nested_ops->triple_fault(vcpu);
+		return 1;
+	}
+
+	return kvm_x86_ops.nested_ops->check_events(vcpu);
+}
+
+static void kvm_inject_exception(struct kvm_vcpu *vcpu)
+{
+	if (vcpu->arch.exception.error_code && !is_protmode(vcpu))
+		vcpu->arch.exception.error_code = false;
+	static_call(kvm_x86_queue_exception)(vcpu);
+}
+
 static void inject_pending_event(struct kvm_vcpu *vcpu, bool *req_immediate_exit)
 {
 	int r;
@@ -8375,7 +8487,7 @@ static void inject_pending_event(struct kvm_vcpu *vcpu, bool *req_immediate_exit
 	/* try to reinject previous events if any */
 
 	if (vcpu->arch.exception.injected) {
-		static_call(kvm_x86_queue_exception)(vcpu);
+		kvm_inject_exception(vcpu);
 		can_inject = false;
 	}
 	/*
@@ -8412,7 +8524,7 @@ static void inject_pending_event(struct kvm_vcpu *vcpu, bool *req_immediate_exit
 	 * from L2 to L1.
 	 */
 	if (is_guest_mode(vcpu)) {
-		r = kvm_x86_ops.nested_ops->check_events(vcpu);
+		r = kvm_check_nested_events(vcpu);
 		if (r < 0)
 			goto busy;
 	}
@@ -8438,7 +8550,7 @@ static void inject_pending_event(struct kvm_vcpu *vcpu, bool *req_immediate_exit
 			}
 		}
 
-		static_call(kvm_x86_queue_exception)(vcpu);
+		kvm_inject_exception(vcpu);
 		can_inject = false;
 	}
 
@@ -8587,7 +8699,7 @@ static void enter_smm_save_state_32(struct kvm_vcpu *vcpu, char *buf)
 	put_smstate(u32, buf, 0x7ff0, kvm_rip_read(vcpu));
 
 	for (i = 0; i < 8; i++)
-		put_smstate(u32, buf, 0x7fd0 + i * 4, kvm_register_read(vcpu, i));
+		put_smstate(u32, buf, 0x7fd0 + i * 4, kvm_register_read_raw(vcpu, i));
 
 	kvm_get_dr(vcpu, 6, &val);
 	put_smstate(u32, buf, 0x7fcc, (u32)val);
@@ -8633,7 +8745,7 @@ static void enter_smm_save_state_64(struct kvm_vcpu *vcpu, char *buf)
 	int i;
 
 	for (i = 0; i < 16; i++)
-		put_smstate(u64, buf, 0x7ff8 - i * 8, kvm_register_read(vcpu, i));
+		put_smstate(u64, buf, 0x7ff8 - i * 8, kvm_register_read_raw(vcpu, i));
 
 	put_smstate(u64, buf, 0x7f78, kvm_rip_read(vcpu));
 	put_smstate(u32, buf, 0x7f70, kvm_get_rflags(vcpu));
@@ -8975,10 +9087,14 @@ static int vcpu_enter_guest(struct kvm_vcpu *vcpu)
 			goto out;
 		}
 		if (kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu)) {
-			vcpu->run->exit_reason = KVM_EXIT_SHUTDOWN;
-			vcpu->mmio_needed = 0;
-			r = 0;
-			goto out;
+			if (is_guest_mode(vcpu)) {
+				kvm_x86_ops.nested_ops->triple_fault(vcpu);
+			} else {
+				vcpu->run->exit_reason = KVM_EXIT_SHUTDOWN;
+				vcpu->mmio_needed = 0;
+				r = 0;
+				goto out;
+			}
 		}
 		if (kvm_check_request(KVM_REQ_APF_HALT, vcpu)) {
 			/* Page is swapped out. Do synthetic halt */
@@ -9276,7 +9392,7 @@ static inline int vcpu_block(struct kvm *kvm, struct kvm_vcpu *vcpu)
 static inline bool kvm_vcpu_running(struct kvm_vcpu *vcpu)
 {
 	if (is_guest_mode(vcpu))
-		kvm_x86_ops.nested_ops->check_events(vcpu);
+		kvm_check_nested_events(vcpu);
 
 	return (vcpu->arch.mp_state == KVM_MP_STATE_RUNNABLE &&
 		!vcpu->arch.apf.halted);
@@ -11002,6 +11118,14 @@ int kvm_arch_vcpu_runnable(struct kvm_vcpu *vcpu)
 	return kvm_vcpu_running(vcpu) || kvm_vcpu_has_events(vcpu);
 }
 
+bool kvm_arch_dy_has_pending_interrupt(struct kvm_vcpu *vcpu)
+{
+	if (vcpu->arch.apicv_active && static_call(kvm_x86_dy_apicv_has_pending_interrupt)(vcpu))
+		return true;
+
+	return false;
+}
+
 bool kvm_arch_dy_runnable(struct kvm_vcpu *vcpu)
 {
 	if (READ_ONCE(vcpu->arch.pv.pv_unhalted))
@@ -11012,14 +11136,14 @@ bool kvm_arch_dy_runnable(struct kvm_vcpu *vcpu)
 		 kvm_test_request(KVM_REQ_EVENT, vcpu))
 		return true;
 
-	if (vcpu->arch.apicv_active && static_call(kvm_x86_dy_apicv_has_pending_interrupt)(vcpu))
-		return true;
-
-	return false;
+	return kvm_arch_dy_has_pending_interrupt(vcpu);
 }
 
 bool kvm_arch_vcpu_in_kernel(struct kvm_vcpu *vcpu)
 {
+	if (vcpu->arch.guest_state_protected)
+		return true;
+
 	return vcpu->arch.preempted_in_kernel;
 }
 
@@ -11290,7 +11414,7 @@ bool kvm_arch_can_dequeue_async_page_present(struct kvm_vcpu *vcpu)
 	if (!kvm_pv_async_pf_enabled(vcpu))
 		return true;
 	else
-		return apf_pageready_slot_free(vcpu);
+		return kvm_lapic_enabled(vcpu) && apf_pageready_slot_free(vcpu);
 }
 
 void kvm_arch_start_assignment(struct kvm *kvm)
@@ -11539,7 +11663,7 @@ int kvm_handle_invpcid(struct kvm_vcpu *vcpu, unsigned long type, gva_t gva)
 
 		fallthrough;
 	case INVPCID_TYPE_ALL_INCL_GLOBAL:
-		kvm_mmu_unload(vcpu);
+		kvm_make_request(KVM_REQ_MMU_RELOAD, vcpu);
 		return kvm_skip_emulated_instruction(vcpu);
 
 	default:
diff --git a/arch/x86/kvm/x86.h b/arch/x86/kvm/x86.h
index 9035e34aa156..8ddd38146525 100644
--- a/arch/x86/kvm/x86.h
+++ b/arch/x86/kvm/x86.h
@@ -8,6 +8,14 @@
 #include "kvm_cache_regs.h"
 #include "kvm_emulate.h"
 
+#define KVM_NESTED_VMENTER_CONSISTENCY_CHECK(consistency_check)		\
+({									\
+	bool failed = (consistency_check);				\
+	if (failed)							\
+		trace_kvm_nested_vmenter_failed(#consistency_check, 0);	\
+	failed;								\
+})
+
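/*
 * Editor's note (illustrative usage): nested VM-Enter code typically
 * aliases the macro to a short name so that each consistency check
 * traces the literal expression that failed, along the lines of:
 *
 *	#define CC KVM_NESTED_VMENTER_CONSISTENCY_CHECK
 *
 *	if (CC(kvm_vcpu_is_illegal_gpa(vcpu, vmcb12_gpa)))
 *		return -EINVAL;
 *
 * The alias and the specific check shown are examples, not quotes of a
 * particular call site.
 */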
 #define KVM_DEFAULT_PLE_GAP		128
 #define KVM_VMX_DEFAULT_PLE_WINDOW	4096
 #define KVM_DEFAULT_PLE_WINDOW_GROW	2
@@ -48,6 +56,8 @@ static inline unsigned int __shrink_ple_window(unsigned int val,
 
 #define MSR_IA32_CR_PAT_DEFAULT  0x0007040600070406ULL
 
+int kvm_check_nested_events(struct kvm_vcpu *vcpu);
+
 static inline void kvm_clear_exception_queue(struct kvm_vcpu *vcpu)
 {
 	vcpu->arch.exception.pending = false;
@@ -222,19 +232,19 @@ static inline bool vcpu_match_mmio_gpa(struct kvm_vcpu *vcpu, gpa_t gpa)
 	return false;
 }
 
-static inline unsigned long kvm_register_readl(struct kvm_vcpu *vcpu, int reg)
+static inline unsigned long kvm_register_read(struct kvm_vcpu *vcpu, int reg)
 {
-	unsigned long val = kvm_register_read(vcpu, reg);
+	unsigned long val = kvm_register_read_raw(vcpu, reg);
 
 	return is_64_bit_mode(vcpu) ? val : (u32)val;
 }
 
-static inline void kvm_register_writel(struct kvm_vcpu *vcpu,
+static inline void kvm_register_write(struct kvm_vcpu *vcpu,
 				       int reg, unsigned long val)
 {
 	if (!is_64_bit_mode(vcpu))
 		val = (u32)val;
-	return kvm_register_write(vcpu, reg, val);
+	return kvm_register_write_raw(vcpu, reg, val);
 }
 
 static inline bool kvm_check_has_quirk(struct kvm *kvm, u64 quirk)
diff --git a/arch/x86/mm/mem_encrypt.c b/arch/x86/mm/mem_encrypt.c
index f633f9e23b8f..ff08dc463634 100644
--- a/arch/x86/mm/mem_encrypt.c
+++ b/arch/x86/mm/mem_encrypt.c
@@ -45,8 +45,6 @@ EXPORT_SYMBOL(sme_me_mask);
 DEFINE_STATIC_KEY_FALSE(sev_enable_key);
 EXPORT_SYMBOL_GPL(sev_enable_key);
 
-bool sev_enabled __section(".data");
-
 /* Buffer used for early in-place encryption by BSP, no locking needed */
 static char sme_early_buffer[PAGE_SIZE] __initdata __aligned(PAGE_SIZE);
 
@@ -374,14 +372,14 @@ int __init early_set_memory_encrypted(unsigned long vaddr, unsigned long size)
  * up under SME the trampoline area cannot be encrypted, whereas under SEV
  * the trampoline area must be encrypted.
  */
-bool sme_active(void)
+bool sev_active(void)
 {
-	return sme_me_mask && !sev_enabled;
+	return sev_status & MSR_AMD64_SEV_ENABLED;
 }
 
-bool sev_active(void)
+bool sme_active(void)
 {
-	return sev_status & MSR_AMD64_SEV_ENABLED;
+	return sme_me_mask && !sev_active();
 }
 EXPORT_SYMBOL_GPL(sev_active);
 
diff --git a/arch/x86/mm/mem_encrypt_identity.c b/arch/x86/mm/mem_encrypt_identity.c
index a19374d26101..04aba7e80a36 100644
--- a/arch/x86/mm/mem_encrypt_identity.c
+++ b/arch/x86/mm/mem_encrypt_identity.c
@@ -548,7 +548,6 @@ void __init sme_enable(struct boot_params *bp)
 	} else {
 		/* SEV state cannot be controlled by a command line option */
 		sme_me_mask = me_mask;
-		sev_enabled = true;
 		physical_mask &= ~sme_me_mask;
 		return;
 	}