author     Linus Torvalds <torvalds@linux-foundation.org>  2022-06-01 13:49:15 -0700
committer  Linus Torvalds <torvalds@linux-foundation.org>  2022-06-01 13:49:15 -0700
commit     176882156ae6d63a81fe7f01ea6fe65ab6b52105 (patch)
tree       6c5fffb2f2ad707d2abc0b8479bf703483b08f8f
parent     8171acb8bc9b33f3ed827f0615b24f7a06495cd0 (diff)
parent     421cfe6596f6cb316991c02bf30a93bd81092853 (diff)
download   linux-176882156ae6d63a81fe7f01ea6fe65ab6b52105.tar.gz
Merge tag 'vfio-v5.19-rc1' of https://github.com/awilliam/linux-vfio
Pull vfio updates from Alex Williamson:

 - Improvements to mlx5 vfio-pci variant driver, including support for
   parallel migration per PF (Yishai Hadas)

 - Remove redundant iommu_present() check (Robin Murphy)

 - Ongoing refactoring to consolidate the VFIO driver facing API to use
   vfio_device (Jason Gunthorpe)

 - Use drvdata to store vfio_device among all vfio-pci and variant
   drivers (Jason Gunthorpe)

 - Remove redundant code now that IOMMU core manages group DMA ownership
   (Jason Gunthorpe)

 - Remove vfio_group from external API handling struct file ownership
   (Jason Gunthorpe)

 - Correct typo in uapi comments (Thomas Huth)

 - Fix coccicheck detected deadlock (Wan Jiabing)

 - Use rwsem to remove races and simplify code around container and kvm
   association to groups (Jason Gunthorpe)

 - Harden access to devices in low power states and use runtime PM to
   enable d3cold support for unused devices (Abhishek Sahu)

 - Fix dma_owner handling of fake IOMMU groups (Jason Gunthorpe)

 - Set driver_managed_dma on vfio-pci variant drivers (Jason Gunthorpe)

 - Pass KVM pointer directly rather than via notifier (Matthew Rosato)

* tag 'vfio-v5.19-rc1' of https://github.com/awilliam/linux-vfio: (38 commits)
  vfio: remove VFIO_GROUP_NOTIFY_SET_KVM
  vfio/pci: Add driver_managed_dma to the new vfio_pci drivers
  vfio: Do not manipulate iommu dma_owner for fake iommu groups
  vfio/pci: Move the unused device into low power state with runtime PM
  vfio/pci: Virtualize PME related registers bits and initialize to zero
  vfio/pci: Change the PF power state to D0 before enabling VFs
  vfio/pci: Invalidate mmaps and block the access in D3hot power state
  vfio: Change struct vfio_group::container_users to a non-atomic int
  vfio: Simplify the life cycle of the group FD
  vfio: Fully lock struct vfio_group::container
  vfio: Split up vfio_group_get_device_fd()
  vfio: Change struct vfio_group::opened from an atomic to bool
  vfio: Add missing locking for struct vfio_group::kvm
  kvm/vfio: Fix potential deadlock problem in vfio
  include/uapi/linux/vfio.h: Fix trivial typo - _IORW should be _IOWR instead
  vfio/pci: Use the struct file as the handle not the vfio_group
  kvm/vfio: Remove vfio_group from kvm
  vfio: Change vfio_group_set_kvm() to vfio_file_set_kvm()
  vfio: Change vfio_external_check_extension() to vfio_file_enforced_coherent()
  vfio: Remove vfio_external_group_match_file()
  ...
-rw-r--r--  Documentation/driver-api/vfio-mediated-device.rst  |   4
-rw-r--r--  drivers/gpu/drm/i915/gvt/gtt.c                      |   4
-rw-r--r--  drivers/gpu/drm/i915/gvt/gvt.h                      |   8
-rw-r--r--  drivers/gpu/drm/i915/gvt/kvmgt.c                    | 115
-rw-r--r--  drivers/net/ethernet/mellanox/mlx5/core/sriov.c     |  65
-rw-r--r--  drivers/s390/cio/vfio_ccw_cp.c                      |  47
-rw-r--r--  drivers/s390/cio/vfio_ccw_cp.h                      |   4
-rw-r--r--  drivers/s390/cio/vfio_ccw_fsm.c                     |   3
-rw-r--r--  drivers/s390/cio/vfio_ccw_ops.c                     |   7
-rw-r--r--  drivers/s390/crypto/vfio_ap_ops.c                   |  50
-rw-r--r--  drivers/s390/crypto/vfio_ap_private.h               |   3
-rw-r--r--  drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c      |  16
-rw-r--r--  drivers/vfio/pci/mlx5/cmd.c                         | 236
-rw-r--r--  drivers/vfio/pci/mlx5/cmd.h                         |  52
-rw-r--r--  drivers/vfio/pci/mlx5/main.c                        | 136
-rw-r--r--  drivers/vfio/pci/vfio_pci.c                         |   6
-rw-r--r--  drivers/vfio/pci/vfio_pci_config.c                  |  56
-rw-r--r--  drivers/vfio/pci/vfio_pci_core.c                    | 254
-rw-r--r--  drivers/vfio/vfio.c                                 | 781
-rw-r--r--  include/linux/mlx5/driver.h                         |  12
-rw-r--r--  include/linux/vfio.h                                |  44
-rw-r--r--  include/linux/vfio_pci_core.h                       |   3
-rw-r--r--  include/uapi/linux/vfio.h                           |   4
-rw-r--r--  virt/kvm/vfio.c                                     | 329
24 files changed, 1095 insertions, 1144 deletions
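
Taken together, the refactoring in this pull replaces group-based handles (struct device / struct vfio_group) with the vfio_device itself across the driver-facing API, and delivers the KVM pointer through struct vfio_device instead of a VFIO_GROUP_NOTIFY_SET_KVM notifier. A minimal sketch of what a mediated driver's open_device path looks like against the reworked API, as reflected in the diffs below (the function and variable names here are illustrative, not part of the series):

	#include <linux/vfio.h>
	#include <linux/iommu.h>

	static int my_mdev_open_device(struct vfio_device *vdev)
	{
		unsigned long gfn = 0x1000;	/* example guest pfn */
		unsigned long hpfn;
		int ret;

		/* KVM association now arrives via vdev->kvm, not a group notifier */
		if (!vdev->kvm)
			return -EINVAL;

		/* The pin/unpin handle is the vfio_device; no dev/vfio_group lookup */
		ret = vfio_pin_pages(vdev, &gfn, 1, IOMMU_READ | IOMMU_WRITE, &hpfn);
		if (ret != 1)
			return ret < 0 ? ret : -EFAULT;

		/* ... program the device with hpfn ... */

		vfio_unpin_pages(vdev, &gfn, 1);
		return 0;
	}
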
diff --git a/Documentation/driver-api/vfio-mediated-device.rst b/Documentation/driver-api/vfio-mediated-device.rst
index 784bbeb22adc..eded8719180f 100644
--- a/Documentation/driver-api/vfio-mediated-device.rst
+++ b/Documentation/driver-api/vfio-mediated-device.rst
@@ -262,10 +262,10 @@ Translation APIs for Mediated Devices
 The following APIs are provided for translating user pfn to host pfn in a VFIO
 driver::
 
-	extern int vfio_pin_pages(struct device *dev, unsigned long *user_pfn,
+	int vfio_pin_pages(struct vfio_device *device, unsigned long *user_pfn,
 				  int npage, int prot, unsigned long *phys_pfn);
 
-	extern int vfio_unpin_pages(struct device *dev, unsigned long *user_pfn,
+	int vfio_unpin_pages(struct vfio_device *device, unsigned long *user_pfn,
 				    int npage);
 
 These functions call back into the back-end IOMMU module by using the pin_pages
diff --git a/drivers/gpu/drm/i915/gvt/gtt.c b/drivers/gpu/drm/i915/gvt/gtt.c
index 9c5cc2800975..b4f69364f9a1 100644
--- a/drivers/gpu/drm/i915/gvt/gtt.c
+++ b/drivers/gpu/drm/i915/gvt/gtt.c
@@ -51,7 +51,7 @@ static int preallocated_oos_pages = 8192;
 
 static bool intel_gvt_is_valid_gfn(struct intel_vgpu *vgpu, unsigned long gfn)
 {
-	struct kvm *kvm = vgpu->kvm;
+	struct kvm *kvm = vgpu->vfio_device.kvm;
 	int idx;
 	bool ret;
 
@@ -1185,7 +1185,7 @@ static int is_2MB_gtt_possible(struct intel_vgpu *vgpu,
 
 	if (!vgpu->attached)
 		return -EINVAL;
-	pfn = gfn_to_pfn(vgpu->kvm, ops->get_pfn(entry));
+	pfn = gfn_to_pfn(vgpu->vfio_device.kvm, ops->get_pfn(entry));
 	if (is_error_noslot_pfn(pfn))
 		return -EINVAL;
 	return PageTransHuge(pfn_to_page(pfn));
diff --git a/drivers/gpu/drm/i915/gvt/gvt.h b/drivers/gpu/drm/i915/gvt/gvt.h
index 03ecffc2ba56..aee1a45da74b 100644
--- a/drivers/gpu/drm/i915/gvt/gvt.h
+++ b/drivers/gpu/drm/i915/gvt/gvt.h
@@ -227,11 +227,7 @@ struct intel_vgpu {
 	struct mutex cache_lock;
 
 	struct notifier_block iommu_notifier;
-	struct notifier_block group_notifier;
-	struct kvm *kvm;
-	struct work_struct release_work;
 	atomic_t released;
-	struct vfio_group *vfio_group;
 
 	struct kvm_page_track_notifier_node track_node;
 #define NR_BKT (1 << 18)
@@ -732,7 +728,7 @@ static inline int intel_gvt_read_gpa(struct intel_vgpu *vgpu, unsigned long gpa,
 {
 	if (!vgpu->attached)
 		return -ESRCH;
-	return vfio_dma_rw(vgpu->vfio_group, gpa, buf, len, false);
+	return vfio_dma_rw(&vgpu->vfio_device, gpa, buf, len, false);
 }
 
 /**
@@ -750,7 +746,7 @@ static inline int intel_gvt_write_gpa(struct intel_vgpu *vgpu,
 {
 	if (!vgpu->attached)
 		return -ESRCH;
-	return vfio_dma_rw(vgpu->vfio_group, gpa, buf, len, true);
+	return vfio_dma_rw(&vgpu->vfio_device, gpa, buf, len, true);
 }
 
 void intel_gvt_debugfs_remove_vgpu(struct intel_vgpu *vgpu);
diff --git a/drivers/gpu/drm/i915/gvt/kvmgt.c b/drivers/gpu/drm/i915/gvt/kvmgt.c
index 0787ba5c301f..e2f6c56ab342 100644
--- a/drivers/gpu/drm/i915/gvt/kvmgt.c
+++ b/drivers/gpu/drm/i915/gvt/kvmgt.c
@@ -228,8 +228,6 @@ static void intel_gvt_cleanup_vgpu_type_groups(struct intel_gvt *gvt)
 	}
 }
 
-static void intel_vgpu_release_work(struct work_struct *work);
-
 static void gvt_unpin_guest_page(struct intel_vgpu *vgpu, unsigned long gfn,
 		unsigned long size)
 {
@@ -243,7 +241,7 @@ static void gvt_unpin_guest_page(struct intel_vgpu *vgpu, unsigned long gfn,
 	for (npage = 0; npage < total_pages; npage++) {
 		unsigned long cur_gfn = gfn + npage;
 
-		ret = vfio_group_unpin_pages(vgpu->vfio_group, &cur_gfn, 1);
+		ret = vfio_unpin_pages(&vgpu->vfio_device, &cur_gfn, 1);
 		drm_WARN_ON(&i915->drm, ret != 1);
 	}
 }
@@ -266,8 +264,8 @@ static int gvt_pin_guest_page(struct intel_vgpu *vgpu, unsigned long gfn,
 		unsigned long cur_gfn = gfn + npage;
 		unsigned long pfn;
 
-		ret = vfio_group_pin_pages(vgpu->vfio_group, &cur_gfn, 1,
-					   IOMMU_READ | IOMMU_WRITE, &pfn);
+		ret = vfio_pin_pages(&vgpu->vfio_device, &cur_gfn, 1,
+				     IOMMU_READ | IOMMU_WRITE, &pfn);
 		if (ret != 1) {
 			gvt_vgpu_err("vfio_pin_pages failed for gfn 0x%lx, ret %d\n",
 				     cur_gfn, ret);
@@ -761,23 +759,6 @@ static int intel_vgpu_iommu_notifier(struct notifier_block *nb,
 	return NOTIFY_OK;
 }
 
-static int intel_vgpu_group_notifier(struct notifier_block *nb,
-				     unsigned long action, void *data)
-{
-	struct intel_vgpu *vgpu =
-		container_of(nb, struct intel_vgpu, group_notifier);
-
-	/* the only action we care about */
-	if (action == VFIO_GROUP_NOTIFY_SET_KVM) {
-		vgpu->kvm = data;
-
-		if (!data)
-			schedule_work(&vgpu->release_work);
-	}
-
-	return NOTIFY_OK;
-}
-
 static bool __kvmgt_vgpu_exist(struct intel_vgpu *vgpu)
 {
 	struct intel_vgpu *itr;
@@ -789,7 +770,7 @@ static bool __kvmgt_vgpu_exist(struct intel_vgpu *vgpu)
 		if (!itr->attached)
 			continue;
 
-		if (vgpu->kvm == itr->kvm) {
+		if (vgpu->vfio_device.kvm == itr->vfio_device.kvm) {
 			ret = true;
 			goto out;
 		}
@@ -804,61 +785,44 @@ static int intel_vgpu_open_device(struct vfio_device *vfio_dev)
 	struct intel_vgpu *vgpu = vfio_dev_to_vgpu(vfio_dev);
 	unsigned long events;
 	int ret;
-	struct vfio_group *vfio_group;
 
 	vgpu->iommu_notifier.notifier_call = intel_vgpu_iommu_notifier;
-	vgpu->group_notifier.notifier_call = intel_vgpu_group_notifier;
 
 	events = VFIO_IOMMU_NOTIFY_DMA_UNMAP;
-	ret = vfio_register_notifier(vfio_dev->dev, VFIO_IOMMU_NOTIFY, &events,
-				&vgpu->iommu_notifier);
+	ret = vfio_register_notifier(vfio_dev, VFIO_IOMMU_NOTIFY, &events,
+				     &vgpu->iommu_notifier);
 	if (ret != 0) {
 		gvt_vgpu_err("vfio_register_notifier for iommu failed: %d\n",
 			ret);
 		goto out;
 	}
 
-	events = VFIO_GROUP_NOTIFY_SET_KVM;
-	ret = vfio_register_notifier(vfio_dev->dev, VFIO_GROUP_NOTIFY, &events,
-				&vgpu->group_notifier);
-	if (ret != 0) {
-		gvt_vgpu_err("vfio_register_notifier for group failed: %d\n",
-			ret);
-		goto undo_iommu;
-	}
-
-	vfio_group =
-		vfio_group_get_external_user_from_dev(vgpu->vfio_device.dev);
-	if (IS_ERR_OR_NULL(vfio_group)) {
-		ret = !vfio_group ? -EFAULT : PTR_ERR(vfio_group);
-		gvt_vgpu_err("vfio_group_get_external_user_from_dev failed\n");
-		goto undo_register;
-	}
-	vgpu->vfio_group = vfio_group;
-
 	ret = -EEXIST;
 	if (vgpu->attached)
-		goto undo_group;
+		goto undo_iommu;
 
 	ret = -ESRCH;
-	if (!vgpu->kvm || vgpu->kvm->mm != current->mm) {
+	if (!vgpu->vfio_device.kvm ||
+	    vgpu->vfio_device.kvm->mm != current->mm) {
 		gvt_vgpu_err("KVM is required to use Intel vGPU\n");
-		goto undo_group;
+		goto undo_iommu;
 	}
 
+	kvm_get_kvm(vgpu->vfio_device.kvm);
+
 	ret = -EEXIST;
 	if (__kvmgt_vgpu_exist(vgpu))
-		goto undo_group;
+		goto undo_iommu;
 
 	vgpu->attached = true;
-	kvm_get_kvm(vgpu->kvm);
 
 	kvmgt_protect_table_init(vgpu);
 	gvt_cache_init(vgpu);
 
 	vgpu->track_node.track_write = kvmgt_page_track_write;
 	vgpu->track_node.track_flush_slot = kvmgt_page_track_flush_slot;
-	kvm_page_track_register_notifier(vgpu->kvm, &vgpu->track_node);
+	kvm_page_track_register_notifier(vgpu->vfio_device.kvm,
+					 &vgpu->track_node);
 
 	debugfs_create_ulong(KVMGT_DEBUGFS_FILENAME, 0444, vgpu->debugfs,
 			     &vgpu->nr_cache_entries);
@@ -868,17 +832,9 @@ static int intel_vgpu_open_device(struct vfio_device *vfio_dev)
 	atomic_set(&vgpu->released, 0);
 	return 0;
 
-undo_group:
-	vfio_group_put_external_user(vgpu->vfio_group);
-	vgpu->vfio_group = NULL;
-
-undo_register:
-	vfio_unregister_notifier(vfio_dev->dev, VFIO_GROUP_NOTIFY,
-					&vgpu->group_notifier);
-
 undo_iommu:
-	vfio_unregister_notifier(vfio_dev->dev, VFIO_IOMMU_NOTIFY,
-					&vgpu->iommu_notifier);
+	vfio_unregister_notifier(vfio_dev, VFIO_IOMMU_NOTIFY,
+				 &vgpu->iommu_notifier);
 out:
 	return ret;
 }
@@ -894,8 +850,9 @@ static void intel_vgpu_release_msi_eventfd_ctx(struct intel_vgpu *vgpu)
 	}
 }
 
-static void __intel_vgpu_release(struct intel_vgpu *vgpu)
+static void intel_vgpu_close_device(struct vfio_device *vfio_dev)
 {
+	struct intel_vgpu *vgpu = vfio_dev_to_vgpu(vfio_dev);
 	struct drm_i915_private *i915 = vgpu->gvt->gt->i915;
 	int ret;
 
@@ -907,41 +864,24 @@ static void __intel_vgpu_release(struct intel_vgpu *vgpu)
 
 	intel_gvt_release_vgpu(vgpu);
 
-	ret = vfio_unregister_notifier(vgpu->vfio_device.dev, VFIO_IOMMU_NOTIFY,
-					&vgpu->iommu_notifier);
+	ret = vfio_unregister_notifier(&vgpu->vfio_device, VFIO_IOMMU_NOTIFY,
+				       &vgpu->iommu_notifier);
 	drm_WARN(&i915->drm, ret,
 		 "vfio_unregister_notifier for iommu failed: %d\n", ret);
 
-	ret = vfio_unregister_notifier(vgpu->vfio_device.dev, VFIO_GROUP_NOTIFY,
-					&vgpu->group_notifier);
-	drm_WARN(&i915->drm, ret,
-		 "vfio_unregister_notifier for group failed: %d\n", ret);
-
 	debugfs_remove(debugfs_lookup(KVMGT_DEBUGFS_FILENAME, vgpu->debugfs));
 
-	kvm_page_track_unregister_notifier(vgpu->kvm, &vgpu->track_node);
-	kvm_put_kvm(vgpu->kvm);
+	kvm_page_track_unregister_notifier(vgpu->vfio_device.kvm,
+					   &vgpu->track_node);
 	kvmgt_protect_table_destroy(vgpu);
 	gvt_cache_destroy(vgpu);
 
 	intel_vgpu_release_msi_eventfd_ctx(vgpu);
-	vfio_group_put_external_user(vgpu->vfio_group);
 
-	vgpu->kvm = NULL;
 	vgpu->attached = false;
-}
-
-static void intel_vgpu_close_device(struct vfio_device *vfio_dev)
-{
-	__intel_vgpu_release(vfio_dev_to_vgpu(vfio_dev));
-}
-
-static void intel_vgpu_release_work(struct work_struct *work)
-{
-	struct intel_vgpu *vgpu =
-		container_of(work, struct intel_vgpu, release_work);
 
-	__intel_vgpu_release(vgpu);
+	if (vgpu->vfio_device.kvm)
+		kvm_put_kvm(vgpu->vfio_device.kvm);
 }
 
 static u64 intel_vgpu_get_bar_addr(struct intel_vgpu *vgpu, int bar)
@@ -1690,7 +1630,6 @@ static int intel_vgpu_probe(struct mdev_device *mdev)
 		return PTR_ERR(vgpu);
 	}
 
-	INIT_WORK(&vgpu->release_work, intel_vgpu_release_work);
 	vfio_init_group_dev(&vgpu->vfio_device, &mdev->dev,
 			    &intel_vgpu_dev_ops);
 
@@ -1728,7 +1667,7 @@ static struct mdev_driver intel_vgpu_mdev_driver = {
 
 int intel_gvt_page_track_add(struct intel_vgpu *info, u64 gfn)
 {
-	struct kvm *kvm = info->kvm;
+	struct kvm *kvm = info->vfio_device.kvm;
 	struct kvm_memory_slot *slot;
 	int idx;
 
@@ -1758,7 +1697,7 @@ out:
 
 int intel_gvt_page_track_remove(struct intel_vgpu *info, u64 gfn)
 {
-	struct kvm *kvm = info->kvm;
+	struct kvm *kvm = info->vfio_device.kvm;
 	struct kvm_memory_slot *slot;
 	int idx;
 
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/sriov.c b/drivers/net/ethernet/mellanox/mlx5/core/sriov.c
index 887ee0f729d1..2935614f6fa9 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/sriov.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/sriov.c
@@ -87,6 +87,11 @@ static int mlx5_device_enable_sriov(struct mlx5_core_dev *dev, int num_vfs)
 enable_vfs_hca:
 	num_msix_count = mlx5_get_default_msix_vec_count(dev, num_vfs);
 	for (vf = 0; vf < num_vfs; vf++) {
+		/* Notify the VF before its enablement to let it set
+		 * some stuff.
+		 */
+		blocking_notifier_call_chain(&sriov->vfs_ctx[vf].notifier,
+					     MLX5_PF_NOTIFY_ENABLE_VF, dev);
 		err = mlx5_core_enable_hca(dev, vf + 1);
 		if (err) {
 			mlx5_core_warn(dev, "failed to enable VF %d (%d)\n", vf, err);
@@ -127,6 +132,11 @@ mlx5_device_disable_sriov(struct mlx5_core_dev *dev, int num_vfs, bool clear_vf)
 	for (vf = num_vfs - 1; vf >= 0; vf--) {
 		if (!sriov->vfs_ctx[vf].enabled)
 			continue;
+		/* Notify the VF before its disablement to let it clean
+		 * some resources.
+		 */
+		blocking_notifier_call_chain(&sriov->vfs_ctx[vf].notifier,
+					     MLX5_PF_NOTIFY_DISABLE_VF, dev);
 		err = mlx5_core_disable_hca(dev, vf + 1);
 		if (err) {
 			mlx5_core_warn(dev, "failed to disable VF %d\n", vf);
@@ -257,7 +267,7 @@ int mlx5_sriov_init(struct mlx5_core_dev *dev)
 {
 	struct mlx5_core_sriov *sriov = &dev->priv.sriov;
 	struct pci_dev *pdev = dev->pdev;
-	int total_vfs;
+	int total_vfs, i;
 
 	if (!mlx5_core_is_pf(dev))
 		return 0;
@@ -269,6 +279,9 @@ int mlx5_sriov_init(struct mlx5_core_dev *dev)
 	if (!sriov->vfs_ctx)
 		return -ENOMEM;
 
+	for (i = 0; i < total_vfs; i++)
+		BLOCKING_INIT_NOTIFIER_HEAD(&sriov->vfs_ctx[i].notifier);
+
 	return 0;
 }
 
@@ -281,3 +294,53 @@ void mlx5_sriov_cleanup(struct mlx5_core_dev *dev)
 
 	kfree(sriov->vfs_ctx);
 }
+
+/**
+ * mlx5_sriov_blocking_notifier_unregister - Unregister a VF from
+ * a notification block chain.
+ *
+ * @mdev: The mlx5 core device.
+ * @vf_id: The VF id.
+ * @nb: The notifier block to be unregistered.
+ */
+void mlx5_sriov_blocking_notifier_unregister(struct mlx5_core_dev *mdev,
+					     int vf_id,
+					     struct notifier_block *nb)
+{
+	struct mlx5_vf_context *vfs_ctx;
+	struct mlx5_core_sriov *sriov;
+
+	sriov = &mdev->priv.sriov;
+	if (WARN_ON(vf_id < 0 || vf_id >= sriov->num_vfs))
+		return;
+
+	vfs_ctx = &sriov->vfs_ctx[vf_id];
+	blocking_notifier_chain_unregister(&vfs_ctx->notifier, nb);
+}
+EXPORT_SYMBOL(mlx5_sriov_blocking_notifier_unregister);
+
+/**
+ * mlx5_sriov_blocking_notifier_register - Register a VF notification
+ * block chain.
+ *
+ * @mdev: The mlx5 core device.
+ * @vf_id: The VF id.
+ * @nb: The notifier block to be called upon the VF events.
+ *
+ * Returns 0 on success or an error code.
+ */
+int mlx5_sriov_blocking_notifier_register(struct mlx5_core_dev *mdev,
+					  int vf_id,
+					  struct notifier_block *nb)
+{
+	struct mlx5_vf_context *vfs_ctx;
+	struct mlx5_core_sriov *sriov;
+
+	sriov = &mdev->priv.sriov;
+	if (vf_id < 0 || vf_id >= sriov->num_vfs)
+		return -EINVAL;
+
+	vfs_ctx = &sriov->vfs_ctx[vf_id];
+	return blocking_notifier_chain_register(&vfs_ctx->notifier, nb);
+}
+EXPORT_SYMBOL(mlx5_sriov_blocking_notifier_register);
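
The per-VF blocking notifier added above gives a VF-bound driver a hook around PF-driven enablement and disablement of its HCA; the real consumer is the mlx5 vfio-pci variant driver further down (drivers/vfio/pci/mlx5/cmd.c). A rough sketch of a subscriber, with illustrative names for everything outside the new mlx5 API:

	static struct notifier_block my_vf_nb;

	static int my_vf_event(struct notifier_block *nb, unsigned long event,
			       void *data)
	{
		switch (event) {
		case MLX5_PF_NOTIFY_ENABLE_VF:
			/* PF is about to enable this VF's HCA */
			break;
		case MLX5_PF_NOTIFY_DISABLE_VF:
			/* PF is about to disable this VF's HCA; quiesce first */
			break;
		}
		return 0;
	}

	/* during VF driver setup, with mdev from mlx5_vf_get_core_dev(pdev) */
	my_vf_nb.notifier_call = my_vf_event;
	err = mlx5_sriov_blocking_notifier_register(mdev, vf_id, &my_vf_nb);
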
diff --git a/drivers/s390/cio/vfio_ccw_cp.c b/drivers/s390/cio/vfio_ccw_cp.c
index 8d1b2771c1aa..0c2be9421ab7 100644
--- a/drivers/s390/cio/vfio_ccw_cp.c
+++ b/drivers/s390/cio/vfio_ccw_cp.c
@@ -16,6 +16,7 @@
 #include <asm/idals.h>
 
 #include "vfio_ccw_cp.h"
+#include "vfio_ccw_private.h"
 
 struct pfn_array {
 	/* Starting guest physical I/O address. */
@@ -98,17 +99,17 @@ static int pfn_array_alloc(struct pfn_array *pa, u64 iova, unsigned int len)
  * If the pin request partially succeeds, or fails completely,
  * all pages are left unpinned and a negative error value is returned.
  */
-static int pfn_array_pin(struct pfn_array *pa, struct device *mdev)
+static int pfn_array_pin(struct pfn_array *pa, struct vfio_device *vdev)
 {
 	int ret = 0;
 
-	ret = vfio_pin_pages(mdev, pa->pa_iova_pfn, pa->pa_nr,
+	ret = vfio_pin_pages(vdev, pa->pa_iova_pfn, pa->pa_nr,
 			     IOMMU_READ | IOMMU_WRITE, pa->pa_pfn);
 
 	if (ret < 0) {
 		goto err_out;
 	} else if (ret > 0 && ret != pa->pa_nr) {
-		vfio_unpin_pages(mdev, pa->pa_iova_pfn, ret);
+		vfio_unpin_pages(vdev, pa->pa_iova_pfn, ret);
 		ret = -EINVAL;
 		goto err_out;
 	}
@@ -122,11 +123,11 @@ err_out:
 }
 
 /* Unpin the pages before releasing the memory. */
-static void pfn_array_unpin_free(struct pfn_array *pa, struct device *mdev)
+static void pfn_array_unpin_free(struct pfn_array *pa, struct vfio_device *vdev)
 {
 	/* Only unpin if any pages were pinned to begin with */
 	if (pa->pa_nr)
-		vfio_unpin_pages(mdev, pa->pa_iova_pfn, pa->pa_nr);
+		vfio_unpin_pages(vdev, pa->pa_iova_pfn, pa->pa_nr);
 	pa->pa_nr = 0;
 	kfree(pa->pa_iova_pfn);
 }
@@ -190,8 +191,7 @@ static void convert_ccw0_to_ccw1(struct ccw1 *source, unsigned long len)
  * Within the domain (@mdev), copy @n bytes from a guest physical
  * address (@iova) to a host physical address (@to).
  */
-static long copy_from_iova(struct device *mdev,
-			   void *to, u64 iova,
+static long copy_from_iova(struct vfio_device *vdev, void *to, u64 iova,
 			   unsigned long n)
 {
 	struct pfn_array pa = {0};
@@ -203,9 +203,9 @@ static long copy_from_iova(struct device *mdev,
 	if (ret < 0)
 		return ret;
 
-	ret = pfn_array_pin(&pa, mdev);
+	ret = pfn_array_pin(&pa, vdev);
 	if (ret < 0) {
-		pfn_array_unpin_free(&pa, mdev);
+		pfn_array_unpin_free(&pa, vdev);
 		return ret;
 	}
 
@@ -226,7 +226,7 @@ static long copy_from_iova(struct device *mdev,
 			break;
 	}
 
-	pfn_array_unpin_free(&pa, mdev);
+	pfn_array_unpin_free(&pa, vdev);
 
 	return l;
 }
@@ -423,11 +423,13 @@ static int ccwchain_loop_tic(struct ccwchain *chain,
 
 static int ccwchain_handle_ccw(u32 cda, struct channel_program *cp)
 {
+	struct vfio_device *vdev =
+		&container_of(cp, struct vfio_ccw_private, cp)->vdev;
 	struct ccwchain *chain;
 	int len, ret;
 
 	/* Copy 2K (the most we support today) of possible CCWs */
-	len = copy_from_iova(cp->mdev, cp->guest_cp, cda,
+	len = copy_from_iova(vdev, cp->guest_cp, cda,
 			     CCWCHAIN_LEN_MAX * sizeof(struct ccw1));
 	if (len)
 		return len;
@@ -508,6 +510,8 @@ static int ccwchain_fetch_direct(struct ccwchain *chain,
 				 int idx,
 				 struct channel_program *cp)
 {
+	struct vfio_device *vdev =
+		&container_of(cp, struct vfio_ccw_private, cp)->vdev;
 	struct ccw1 *ccw;
 	struct pfn_array *pa;
 	u64 iova;
@@ -526,7 +530,7 @@ static int ccwchain_fetch_direct(struct ccwchain *chain,
 	if (ccw_is_idal(ccw)) {
 		/* Read first IDAW to see if it's 4K-aligned or not. */
 		/* All subsequent IDAws will be 4K-aligned. */
-		ret = copy_from_iova(cp->mdev, &iova, ccw->cda, sizeof(iova));
+		ret = copy_from_iova(vdev, &iova, ccw->cda, sizeof(iova));
 		if (ret)
 			return ret;
 	} else {
@@ -555,7 +559,7 @@ static int ccwchain_fetch_direct(struct ccwchain *chain,
 
 	if (ccw_is_idal(ccw)) {
 		/* Copy guest IDAL into host IDAL */
-		ret = copy_from_iova(cp->mdev, idaws, ccw->cda, idal_len);
+		ret = copy_from_iova(vdev, idaws, ccw->cda, idal_len);
 		if (ret)
 			goto out_unpin;
 
@@ -574,7 +578,7 @@ static int ccwchain_fetch_direct(struct ccwchain *chain,
 	}
 
 	if (ccw_does_data_transfer(ccw)) {
-		ret = pfn_array_pin(pa, cp->mdev);
+		ret = pfn_array_pin(pa, vdev);
 		if (ret < 0)
 			goto out_unpin;
 	} else {
@@ -590,7 +594,7 @@ static int ccwchain_fetch_direct(struct ccwchain *chain,
 	return 0;
 
 out_unpin:
-	pfn_array_unpin_free(pa, cp->mdev);
+	pfn_array_unpin_free(pa, vdev);
 out_free_idaws:
 	kfree(idaws);
 out_init:
@@ -632,8 +636,10 @@ static int ccwchain_fetch_one(struct ccwchain *chain,
  * Returns:
  *   %0 on success and a negative error value on failure.
  */
-int cp_init(struct channel_program *cp, struct device *mdev, union orb *orb)
+int cp_init(struct channel_program *cp, union orb *orb)
 {
+	struct vfio_device *vdev =
+		&container_of(cp, struct vfio_ccw_private, cp)->vdev;
 	/* custom ratelimit used to avoid flood during guest IPL */
 	static DEFINE_RATELIMIT_STATE(ratelimit_state, 5 * HZ, 1);
 	int ret;
@@ -650,11 +656,12 @@ int cp_init(struct channel_program *cp, struct device *mdev, union orb *orb)
 	 * the problem if something does break.
 	 */
 	if (!orb->cmd.pfch && __ratelimit(&ratelimit_state))
-		dev_warn(mdev, "Prefetching channel program even though prefetch not specified in ORB");
+		dev_warn(
+			vdev->dev,
+			"Prefetching channel program even though prefetch not specified in ORB");
 
 	INIT_LIST_HEAD(&cp->ccwchain_list);
 	memcpy(&cp->orb, orb, sizeof(*orb));
-	cp->mdev = mdev;
 
 	/* Build a ccwchain for the first CCW segment */
 	ret = ccwchain_handle_ccw(orb->cmd.cpa, cp);
@@ -682,6 +689,8 @@ int cp_init(struct channel_program *cp, struct device *mdev, union orb *orb)
  */
 void cp_free(struct channel_program *cp)
 {
+	struct vfio_device *vdev =
+		&container_of(cp, struct vfio_ccw_private, cp)->vdev;
 	struct ccwchain *chain, *temp;
 	int i;
 
@@ -691,7 +700,7 @@ void cp_free(struct channel_program *cp)
 	cp->initialized = false;
 	list_for_each_entry_safe(chain, temp, &cp->ccwchain_list, next) {
 		for (i = 0; i < chain->ch_len; i++) {
-			pfn_array_unpin_free(chain->ch_pa + i, cp->mdev);
+			pfn_array_unpin_free(chain->ch_pa + i, vdev);
 			ccwchain_cda_free(chain, i);
 		}
 		ccwchain_free(chain);
diff --git a/drivers/s390/cio/vfio_ccw_cp.h b/drivers/s390/cio/vfio_ccw_cp.h
index ba31240ce965..e4c436199b4c 100644
--- a/drivers/s390/cio/vfio_ccw_cp.h
+++ b/drivers/s390/cio/vfio_ccw_cp.h
@@ -37,13 +37,11 @@
 struct channel_program {
 	struct list_head ccwchain_list;
 	union orb orb;
-	struct device *mdev;
 	bool initialized;
 	struct ccw1 *guest_cp;
 };
 
-extern int cp_init(struct channel_program *cp, struct device *mdev,
-		   union orb *orb);
+extern int cp_init(struct channel_program *cp, union orb *orb);
 extern void cp_free(struct channel_program *cp);
 extern int cp_prefetch(struct channel_program *cp);
 extern union orb *cp_get_orb(struct channel_program *cp, u32 intparm, u8 lpm);
diff --git a/drivers/s390/cio/vfio_ccw_fsm.c b/drivers/s390/cio/vfio_ccw_fsm.c
index e435a9cd92da..8483a266051c 100644
--- a/drivers/s390/cio/vfio_ccw_fsm.c
+++ b/drivers/s390/cio/vfio_ccw_fsm.c
@@ -262,8 +262,7 @@ static void fsm_io_request(struct vfio_ccw_private *private,
 			errstr = "transport mode";
 			goto err_out;
 		}
-		io_region->ret_code = cp_init(&private->cp, mdev_dev(mdev),
-					      orb);
+		io_region->ret_code = cp_init(&private->cp, orb);
 		if (io_region->ret_code) {
 			VFIO_CCW_MSG_EVENT(2,
 					   "%pUl (%x.%x.%04x): cp_init=%d\n",
diff --git a/drivers/s390/cio/vfio_ccw_ops.c b/drivers/s390/cio/vfio_ccw_ops.c
index c4d60cdbf247..b49e2e9db2dc 100644
--- a/drivers/s390/cio/vfio_ccw_ops.c
+++ b/drivers/s390/cio/vfio_ccw_ops.c
@@ -183,7 +183,7 @@ static int vfio_ccw_mdev_open_device(struct vfio_device *vdev)
 
 	private->nb.notifier_call = vfio_ccw_mdev_notifier;
 
-	ret = vfio_register_notifier(vdev->dev, VFIO_IOMMU_NOTIFY,
+	ret = vfio_register_notifier(vdev, VFIO_IOMMU_NOTIFY,
 				     &events, &private->nb);
 	if (ret)
 		return ret;
@@ -204,8 +204,7 @@ static int vfio_ccw_mdev_open_device(struct vfio_device *vdev)
 
 out_unregister:
 	vfio_ccw_unregister_dev_regions(private);
-	vfio_unregister_notifier(vdev->dev, VFIO_IOMMU_NOTIFY,
-				 &private->nb);
+	vfio_unregister_notifier(vdev, VFIO_IOMMU_NOTIFY, &private->nb);
 	return ret;
 }
 
@@ -223,7 +222,7 @@ static void vfio_ccw_mdev_close_device(struct vfio_device *vdev)
 
 	cp_free(&private->cp);
 	vfio_ccw_unregister_dev_regions(private);
-	vfio_unregister_notifier(vdev->dev, VFIO_IOMMU_NOTIFY, &private->nb);
+	vfio_unregister_notifier(vdev, VFIO_IOMMU_NOTIFY, &private->nb);
 }
 
 static ssize_t vfio_ccw_mdev_read_io_region(struct vfio_ccw_private *private,
diff --git a/drivers/s390/crypto/vfio_ap_ops.c b/drivers/s390/crypto/vfio_ap_ops.c
index ee0a3bf8f476..a7d2a95796d3 100644
--- a/drivers/s390/crypto/vfio_ap_ops.c
+++ b/drivers/s390/crypto/vfio_ap_ops.c
@@ -124,8 +124,7 @@ static void vfio_ap_free_aqic_resources(struct vfio_ap_queue *q)
 		q->saved_isc = VFIO_AP_ISC_INVALID;
 	}
 	if (q->saved_pfn && !WARN_ON(!q->matrix_mdev)) {
-		vfio_unpin_pages(mdev_dev(q->matrix_mdev->mdev),
-				 &q->saved_pfn, 1);
+		vfio_unpin_pages(&q->matrix_mdev->vdev, &q->saved_pfn, 1);
 		q->saved_pfn = 0;
 	}
 }
@@ -258,7 +257,7 @@ static struct ap_queue_status vfio_ap_irq_enable(struct vfio_ap_queue *q,
 		return status;
 	}
 
-	ret = vfio_pin_pages(mdev_dev(q->matrix_mdev->mdev), &g_pfn, 1,
+	ret = vfio_pin_pages(&q->matrix_mdev->vdev, &g_pfn, 1,
 			     IOMMU_READ | IOMMU_WRITE, &h_pfn);
 	switch (ret) {
 	case 1:
@@ -301,7 +300,7 @@ static struct ap_queue_status vfio_ap_irq_enable(struct vfio_ap_queue *q,
 		break;
 	case AP_RESPONSE_OTHERWISE_CHANGED:
 		/* We could not modify IRQ setings: clear new configuration */
-		vfio_unpin_pages(mdev_dev(q->matrix_mdev->mdev), &g_pfn, 1);
+		vfio_unpin_pages(&q->matrix_mdev->vdev, &g_pfn, 1);
 		kvm_s390_gisc_unregister(kvm, isc);
 		break;
 	default:
@@ -1250,7 +1249,7 @@ static int vfio_ap_mdev_iommu_notifier(struct notifier_block *nb,
 		struct vfio_iommu_type1_dma_unmap *unmap = data;
 		unsigned long g_pfn = unmap->iova >> PAGE_SHIFT;
 
-		vfio_unpin_pages(mdev_dev(matrix_mdev->mdev), &g_pfn, 1);
+		vfio_unpin_pages(&matrix_mdev->vdev, &g_pfn, 1);
 		return NOTIFY_OK;
 	}
 
@@ -1285,25 +1284,6 @@ static void vfio_ap_mdev_unset_kvm(struct ap_matrix_mdev *matrix_mdev)
 	}
 }
 
-static int vfio_ap_mdev_group_notifier(struct notifier_block *nb,
-				       unsigned long action, void *data)
-{
-	int notify_rc = NOTIFY_OK;
-	struct ap_matrix_mdev *matrix_mdev;
-
-	if (action != VFIO_GROUP_NOTIFY_SET_KVM)
-		return NOTIFY_OK;
-
-	matrix_mdev = container_of(nb, struct ap_matrix_mdev, group_notifier);
-
-	if (!data)
-		vfio_ap_mdev_unset_kvm(matrix_mdev);
-	else if (vfio_ap_mdev_set_kvm(matrix_mdev, data))
-		notify_rc = NOTIFY_DONE;
-
-	return notify_rc;
-}
-
 static struct vfio_ap_queue *vfio_ap_find_queue(int apqn)
 {
 	struct device *dev;
@@ -1403,25 +1383,23 @@ static int vfio_ap_mdev_open_device(struct vfio_device *vdev)
 	unsigned long events;
 	int ret;
 
-	matrix_mdev->group_notifier.notifier_call = vfio_ap_mdev_group_notifier;
-	events = VFIO_GROUP_NOTIFY_SET_KVM;
+	if (!vdev->kvm)
+		return -EINVAL;
 
-	ret = vfio_register_notifier(vdev->dev, VFIO_GROUP_NOTIFY,
-				     &events, &matrix_mdev->group_notifier);
+	ret = vfio_ap_mdev_set_kvm(matrix_mdev, vdev->kvm);
 	if (ret)
 		return ret;
 
 	matrix_mdev->iommu_notifier.notifier_call = vfio_ap_mdev_iommu_notifier;
 	events = VFIO_IOMMU_NOTIFY_DMA_UNMAP;
-	ret = vfio_register_notifier(vdev->dev, VFIO_IOMMU_NOTIFY,
-				     &events, &matrix_mdev->iommu_notifier);
+	ret = vfio_register_notifier(vdev, VFIO_IOMMU_NOTIFY, &events,
+				     &matrix_mdev->iommu_notifier);
 	if (ret)
-		goto out_unregister_group;
+		goto err_kvm;
 	return 0;
 
-out_unregister_group:
-	vfio_unregister_notifier(vdev->dev, VFIO_GROUP_NOTIFY,
-				 &matrix_mdev->group_notifier);
+err_kvm:
+	vfio_ap_mdev_unset_kvm(matrix_mdev);
 	return ret;
 }
 
@@ -1430,10 +1408,8 @@ static void vfio_ap_mdev_close_device(struct vfio_device *vdev)
 	struct ap_matrix_mdev *matrix_mdev =
 		container_of(vdev, struct ap_matrix_mdev, vdev);
 
-	vfio_unregister_notifier(vdev->dev, VFIO_IOMMU_NOTIFY,
+	vfio_unregister_notifier(vdev, VFIO_IOMMU_NOTIFY,
 				 &matrix_mdev->iommu_notifier);
-	vfio_unregister_notifier(vdev->dev, VFIO_GROUP_NOTIFY,
-				 &matrix_mdev->group_notifier);
 	vfio_ap_mdev_unset_kvm(matrix_mdev);
 }
 
diff --git a/drivers/s390/crypto/vfio_ap_private.h b/drivers/s390/crypto/vfio_ap_private.h
index 648fcaf8104a..a26efd804d0d 100644
--- a/drivers/s390/crypto/vfio_ap_private.h
+++ b/drivers/s390/crypto/vfio_ap_private.h
@@ -81,8 +81,6 @@ struct ap_matrix {
  * @node:	allows the ap_matrix_mdev struct to be added to a list
  * @matrix:	the adapters, usage domains and control domains assigned to the
  *		mediated matrix device.
- * @group_notifier: notifier block used for specifying callback function for
- *		    handling the VFIO_GROUP_NOTIFY_SET_KVM event
  * @iommu_notifier: notifier block used for specifying callback function for
  *		    handling the VFIO_IOMMU_NOTIFY_DMA_UNMAP even
  * @kvm:	the struct holding guest's state
@@ -94,7 +92,6 @@ struct ap_matrix_mdev {
 	struct vfio_device vdev;
 	struct list_head node;
 	struct ap_matrix matrix;
-	struct notifier_block group_notifier;
 	struct notifier_block iommu_notifier;
 	struct kvm *kvm;
 	crypto_hook pqap_hook;
diff --git a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
index 767b5d47631a..4def43f5f7b6 100644
--- a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
+++ b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
@@ -337,6 +337,14 @@ static int vf_qm_cache_wb(struct hisi_qm *qm)
 	return 0;
 }
 
+static struct hisi_acc_vf_core_device *hssi_acc_drvdata(struct pci_dev *pdev)
+{
+	struct vfio_pci_core_device *core_device = dev_get_drvdata(&pdev->dev);
+
+	return container_of(core_device, struct hisi_acc_vf_core_device,
+			    core_device);
+}
+
 static void vf_qm_fun_reset(struct hisi_acc_vf_core_device *hisi_acc_vdev,
 			    struct hisi_qm *qm)
 {
@@ -962,7 +970,7 @@ hisi_acc_vfio_pci_get_device_state(struct vfio_device *vdev,
 
 static void hisi_acc_vf_pci_aer_reset_done(struct pci_dev *pdev)
 {
-	struct hisi_acc_vf_core_device *hisi_acc_vdev = dev_get_drvdata(&pdev->dev);
+	struct hisi_acc_vf_core_device *hisi_acc_vdev = hssi_acc_drvdata(pdev);
 
 	if (hisi_acc_vdev->core_device.vdev.migration_flags !=
 				VFIO_MIGRATION_STOP_COPY)
@@ -1274,11 +1282,10 @@ static int hisi_acc_vfio_pci_probe(struct pci_dev *pdev, const struct pci_device
 					  &hisi_acc_vfio_pci_ops);
 	}
 
+	dev_set_drvdata(&pdev->dev, &hisi_acc_vdev->core_device);
 	ret = vfio_pci_core_register_device(&hisi_acc_vdev->core_device);
 	if (ret)
 		goto out_free;
-
-	dev_set_drvdata(&pdev->dev, hisi_acc_vdev);
 	return 0;
 
 out_free:
@@ -1289,7 +1296,7 @@ out_free:
 
 static void hisi_acc_vfio_pci_remove(struct pci_dev *pdev)
 {
-	struct hisi_acc_vf_core_device *hisi_acc_vdev = dev_get_drvdata(&pdev->dev);
+	struct hisi_acc_vf_core_device *hisi_acc_vdev = hssi_acc_drvdata(pdev);
 
 	vfio_pci_core_unregister_device(&hisi_acc_vdev->core_device);
 	vfio_pci_core_uninit_device(&hisi_acc_vdev->core_device);
@@ -1316,6 +1323,7 @@ static struct pci_driver hisi_acc_vfio_pci_driver = {
 	.probe = hisi_acc_vfio_pci_probe,
 	.remove = hisi_acc_vfio_pci_remove,
 	.err_handler = &hisi_acc_vf_err_handlers,
+	.driver_managed_dma = true,
 };
 
 module_pci_driver(hisi_acc_vfio_pci_driver);
diff --git a/drivers/vfio/pci/mlx5/cmd.c b/drivers/vfio/pci/mlx5/cmd.c
index 5c9f9218cc1d..9b9f33ca270a 100644
--- a/drivers/vfio/pci/mlx5/cmd.c
+++ b/drivers/vfio/pci/mlx5/cmd.c
@@ -5,89 +5,157 @@
 
 #include "cmd.h"
 
-int mlx5vf_cmd_suspend_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod)
+static int mlx5vf_cmd_get_vhca_id(struct mlx5_core_dev *mdev, u16 function_id,
+				  u16 *vhca_id);
+
+int mlx5vf_cmd_suspend_vhca(struct mlx5vf_pci_core_device *mvdev, u16 op_mod)
 {
-	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
 	u32 out[MLX5_ST_SZ_DW(suspend_vhca_out)] = {};
 	u32 in[MLX5_ST_SZ_DW(suspend_vhca_in)] = {};
-	int ret;
 
-	if (!mdev)
+	lockdep_assert_held(&mvdev->state_mutex);
+	if (mvdev->mdev_detach)
 		return -ENOTCONN;
 
 	MLX5_SET(suspend_vhca_in, in, opcode, MLX5_CMD_OP_SUSPEND_VHCA);
-	MLX5_SET(suspend_vhca_in, in, vhca_id, vhca_id);
+	MLX5_SET(suspend_vhca_in, in, vhca_id, mvdev->vhca_id);
 	MLX5_SET(suspend_vhca_in, in, op_mod, op_mod);
 
-	ret = mlx5_cmd_exec_inout(mdev, suspend_vhca, in, out);
-	mlx5_vf_put_core_dev(mdev);
-	return ret;
+	return mlx5_cmd_exec_inout(mvdev->mdev, suspend_vhca, in, out);
 }
 
-int mlx5vf_cmd_resume_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod)
+int mlx5vf_cmd_resume_vhca(struct mlx5vf_pci_core_device *mvdev, u16 op_mod)
 {
-	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
 	u32 out[MLX5_ST_SZ_DW(resume_vhca_out)] = {};
 	u32 in[MLX5_ST_SZ_DW(resume_vhca_in)] = {};
-	int ret;
 
-	if (!mdev)
+	lockdep_assert_held(&mvdev->state_mutex);
+	if (mvdev->mdev_detach)
 		return -ENOTCONN;
 
 	MLX5_SET(resume_vhca_in, in, opcode, MLX5_CMD_OP_RESUME_VHCA);
-	MLX5_SET(resume_vhca_in, in, vhca_id, vhca_id);
+	MLX5_SET(resume_vhca_in, in, vhca_id, mvdev->vhca_id);
 	MLX5_SET(resume_vhca_in, in, op_mod, op_mod);
 
-	ret = mlx5_cmd_exec_inout(mdev, resume_vhca, in, out);
-	mlx5_vf_put_core_dev(mdev);
-	return ret;
+	return mlx5_cmd_exec_inout(mvdev->mdev, resume_vhca, in, out);
 }
 
-int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
+int mlx5vf_cmd_query_vhca_migration_state(struct mlx5vf_pci_core_device *mvdev,
 					  size_t *state_size)
 {
-	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
 	u32 out[MLX5_ST_SZ_DW(query_vhca_migration_state_out)] = {};
 	u32 in[MLX5_ST_SZ_DW(query_vhca_migration_state_in)] = {};
 	int ret;
 
-	if (!mdev)
+	lockdep_assert_held(&mvdev->state_mutex);
+	if (mvdev->mdev_detach)
 		return -ENOTCONN;
 
 	MLX5_SET(query_vhca_migration_state_in, in, opcode,
 		 MLX5_CMD_OP_QUERY_VHCA_MIGRATION_STATE);
-	MLX5_SET(query_vhca_migration_state_in, in, vhca_id, vhca_id);
+	MLX5_SET(query_vhca_migration_state_in, in, vhca_id, mvdev->vhca_id);
 	MLX5_SET(query_vhca_migration_state_in, in, op_mod, 0);
 
-	ret = mlx5_cmd_exec_inout(mdev, query_vhca_migration_state, in, out);
+	ret = mlx5_cmd_exec_inout(mvdev->mdev, query_vhca_migration_state, in,
+				  out);
 	if (ret)
-		goto end;
+		return ret;
 
 	*state_size = MLX5_GET(query_vhca_migration_state_out, out,
 			       required_umem_size);
+	return 0;
+}
+
+static int mlx5fv_vf_event(struct notifier_block *nb,
+			   unsigned long event, void *data)
+{
+	struct mlx5vf_pci_core_device *mvdev =
+		container_of(nb, struct mlx5vf_pci_core_device, nb);
+
+	mutex_lock(&mvdev->state_mutex);
+	switch (event) {
+	case MLX5_PF_NOTIFY_ENABLE_VF:
+		mvdev->mdev_detach = false;
+		break;
+	case MLX5_PF_NOTIFY_DISABLE_VF:
+		mlx5vf_disable_fds(mvdev);
+		mvdev->mdev_detach = true;
+		break;
+	default:
+		break;
+	}
+	mlx5vf_state_mutex_unlock(mvdev);
+	return 0;
+}
+
+void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev)
+{
+	if (!mvdev->migrate_cap)
+		return;
+
+	mlx5_sriov_blocking_notifier_unregister(mvdev->mdev, mvdev->vf_id,
+						&mvdev->nb);
+	destroy_workqueue(mvdev->cb_wq);
+}
+
+void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev)
+{
+	struct pci_dev *pdev = mvdev->core_device.pdev;
+	int ret;
+
+	if (!pdev->is_virtfn)
+		return;
+
+	mvdev->mdev = mlx5_vf_get_core_dev(pdev);
+	if (!mvdev->mdev)
+		return;
+
+	if (!MLX5_CAP_GEN(mvdev->mdev, migration))
+		goto end;
+
+	mvdev->vf_id = pci_iov_vf_id(pdev);
+	if (mvdev->vf_id < 0)
+		goto end;
+
+	if (mlx5vf_cmd_get_vhca_id(mvdev->mdev, mvdev->vf_id + 1,
+				   &mvdev->vhca_id))
+		goto end;
+
+	mvdev->cb_wq = alloc_ordered_workqueue("mlx5vf_wq", 0);
+	if (!mvdev->cb_wq)
+		goto end;
+
+	mutex_init(&mvdev->state_mutex);
+	spin_lock_init(&mvdev->reset_lock);
+	mvdev->nb.notifier_call = mlx5fv_vf_event;
+	ret = mlx5_sriov_blocking_notifier_register(mvdev->mdev, mvdev->vf_id,
+						    &mvdev->nb);
+	if (ret) {
+		destroy_workqueue(mvdev->cb_wq);
+		goto end;
+	}
+
+	mvdev->migrate_cap = 1;
+	mvdev->core_device.vdev.migration_flags =
+		VFIO_MIGRATION_STOP_COPY |
+		VFIO_MIGRATION_P2P;
 
 end:
-	mlx5_vf_put_core_dev(mdev);
-	return ret;
+	mlx5_vf_put_core_dev(mvdev->mdev);
 }
 
-int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id)
+static int mlx5vf_cmd_get_vhca_id(struct mlx5_core_dev *mdev, u16 function_id,
+				  u16 *vhca_id)
 {
-	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
 	u32 in[MLX5_ST_SZ_DW(query_hca_cap_in)] = {};
 	int out_size;
 	void *out;
 	int ret;
 
-	if (!mdev)
-		return -ENOTCONN;
-
 	out_size = MLX5_ST_SZ_BYTES(query_hca_cap_out);
 	out = kzalloc(out_size, GFP_KERNEL);
-	if (!out) {
-		ret = -ENOMEM;
-		goto end;
-	}
+	if (!out)
+		return -ENOMEM;
 
 	MLX5_SET(query_hca_cap_in, in, opcode, MLX5_CMD_OP_QUERY_HCA_CAP);
 	MLX5_SET(query_hca_cap_in, in, other_function, 1);
@@ -105,8 +173,6 @@ int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id)
 
 err_exec:
 	kfree(out);
-end:
-	mlx5_vf_put_core_dev(mdev);
 	return ret;
 }
 
@@ -151,21 +217,68 @@ static int _create_state_mkey(struct mlx5_core_dev *mdev, u32 pdn,
 	return err;
 }
 
-int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
+void mlx5vf_mig_file_cleanup_cb(struct work_struct *_work)
+{
+	struct mlx5vf_async_data *async_data = container_of(_work,
+		struct mlx5vf_async_data, work);
+	struct mlx5_vf_migration_file *migf = container_of(async_data,
+		struct mlx5_vf_migration_file, async_data);
+	struct mlx5_core_dev *mdev = migf->mvdev->mdev;
+
+	mutex_lock(&migf->lock);
+	if (async_data->status) {
+		migf->is_err = true;
+		wake_up_interruptible(&migf->poll_wait);
+	}
+	mutex_unlock(&migf->lock);
+
+	mlx5_core_destroy_mkey(mdev, async_data->mkey);
+	dma_unmap_sgtable(mdev->device, &migf->table.sgt, DMA_FROM_DEVICE, 0);
+	mlx5_core_dealloc_pd(mdev, async_data->pdn);
+	kvfree(async_data->out);
+	fput(migf->filp);
+}
+
+static void mlx5vf_save_callback(int status, struct mlx5_async_work *context)
+{
+	struct mlx5vf_async_data *async_data = container_of(context,
+			struct mlx5vf_async_data, cb_work);
+	struct mlx5_vf_migration_file *migf = container_of(async_data,
+			struct mlx5_vf_migration_file, async_data);
+
+	if (!status) {
+		WRITE_ONCE(migf->total_length,
+			   MLX5_GET(save_vhca_state_out, async_data->out,
+				    actual_image_size));
+		wake_up_interruptible(&migf->poll_wait);
+	}
+
+	/*
+	 * The error and the cleanup flows can't run from an
+	 * interrupt context
+	 */
+	async_data->status = status;
+	queue_work(migf->mvdev->cb_wq, &async_data->work);
+}
+
+int mlx5vf_cmd_save_vhca_state(struct mlx5vf_pci_core_device *mvdev,
 			       struct mlx5_vf_migration_file *migf)
 {
-	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
-	u32 out[MLX5_ST_SZ_DW(save_vhca_state_out)] = {};
+	u32 out_size = MLX5_ST_SZ_BYTES(save_vhca_state_out);
 	u32 in[MLX5_ST_SZ_DW(save_vhca_state_in)] = {};
+	struct mlx5vf_async_data *async_data;
+	struct mlx5_core_dev *mdev;
 	u32 pdn, mkey;
 	int err;
 
-	if (!mdev)
+	lockdep_assert_held(&mvdev->state_mutex);
+	if (mvdev->mdev_detach)
 		return -ENOTCONN;
 
+	mdev = mvdev->mdev;
 	err = mlx5_core_alloc_pd(mdev, &pdn);
 	if (err)
-		goto end;
+		return err;
 
 	err = dma_map_sgtable(mdev->device, &migf->table.sgt, DMA_FROM_DEVICE,
 			      0);
@@ -179,45 +292,54 @@ int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
 	MLX5_SET(save_vhca_state_in, in, opcode,
 		 MLX5_CMD_OP_SAVE_VHCA_STATE);
 	MLX5_SET(save_vhca_state_in, in, op_mod, 0);
-	MLX5_SET(save_vhca_state_in, in, vhca_id, vhca_id);
+	MLX5_SET(save_vhca_state_in, in, vhca_id, mvdev->vhca_id);
 	MLX5_SET(save_vhca_state_in, in, mkey, mkey);
 	MLX5_SET(save_vhca_state_in, in, size, migf->total_length);
 
-	err = mlx5_cmd_exec_inout(mdev, save_vhca_state, in, out);
+	async_data = &migf->async_data;
+	async_data->out = kvzalloc(out_size, GFP_KERNEL);
+	if (!async_data->out) {
+		err = -ENOMEM;
+		goto err_out;
+	}
+
+	/* no data exists till the callback comes back */
+	migf->total_length = 0;
+	get_file(migf->filp);
+	async_data->mkey = mkey;
+	async_data->pdn = pdn;
+	err = mlx5_cmd_exec_cb(&migf->async_ctx, in, sizeof(in),
+			       async_data->out,
+			       out_size, mlx5vf_save_callback,
+			       &async_data->cb_work);
 	if (err)
 		goto err_exec;
 
-	migf->total_length =
-		MLX5_GET(save_vhca_state_out, out, actual_image_size);
-
-	mlx5_core_destroy_mkey(mdev, mkey);
-	mlx5_core_dealloc_pd(mdev, pdn);
-	dma_unmap_sgtable(mdev->device, &migf->table.sgt, DMA_FROM_DEVICE, 0);
-	mlx5_vf_put_core_dev(mdev);
-
 	return 0;
 
 err_exec:
+	fput(migf->filp);
+	kvfree(async_data->out);
+err_out:
 	mlx5_core_destroy_mkey(mdev, mkey);
 err_create_mkey:
 	dma_unmap_sgtable(mdev->device, &migf->table.sgt, DMA_FROM_DEVICE, 0);
 err_dma_map:
 	mlx5_core_dealloc_pd(mdev, pdn);
-end:
-	mlx5_vf_put_core_dev(mdev);
 	return err;
 }
 
-int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
+int mlx5vf_cmd_load_vhca_state(struct mlx5vf_pci_core_device *mvdev,
 			       struct mlx5_vf_migration_file *migf)
 {
-	struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
+	struct mlx5_core_dev *mdev;
 	u32 out[MLX5_ST_SZ_DW(save_vhca_state_out)] = {};
 	u32 in[MLX5_ST_SZ_DW(save_vhca_state_in)] = {};
 	u32 pdn, mkey;
 	int err;
 
-	if (!mdev)
+	lockdep_assert_held(&mvdev->state_mutex);
+	if (mvdev->mdev_detach)
 		return -ENOTCONN;
 
 	mutex_lock(&migf->lock);
@@ -226,6 +348,7 @@ int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
 		goto end;
 	}
 
+	mdev = mvdev->mdev;
 	err = mlx5_core_alloc_pd(mdev, &pdn);
 	if (err)
 		goto end;
@@ -241,7 +364,7 @@ int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
 	MLX5_SET(load_vhca_state_in, in, opcode,
 		 MLX5_CMD_OP_LOAD_VHCA_STATE);
 	MLX5_SET(load_vhca_state_in, in, op_mod, 0);
-	MLX5_SET(load_vhca_state_in, in, vhca_id, vhca_id);
+	MLX5_SET(load_vhca_state_in, in, vhca_id, mvdev->vhca_id);
 	MLX5_SET(load_vhca_state_in, in, mkey, mkey);
 	MLX5_SET(load_vhca_state_in, in, size, migf->total_length);
 
@@ -253,7 +376,6 @@ err_mkey:
 err_reg:
 	mlx5_core_dealloc_pd(mdev, pdn);
 end:
-	mlx5_vf_put_core_dev(mdev);
 	mutex_unlock(&migf->lock);
 	return err;
 }
diff --git a/drivers/vfio/pci/mlx5/cmd.h b/drivers/vfio/pci/mlx5/cmd.h
index 1392a11a9cc0..6c3112fdd8b1 100644
--- a/drivers/vfio/pci/mlx5/cmd.h
+++ b/drivers/vfio/pci/mlx5/cmd.h
@@ -7,12 +7,23 @@
 #define MLX5_VFIO_CMD_H
 
 #include <linux/kernel.h>
+#include <linux/vfio_pci_core.h>
 #include <linux/mlx5/driver.h>
 
+struct mlx5vf_async_data {
+	struct mlx5_async_work cb_work;
+	struct work_struct work;
+	int status;
+	u32 pdn;
+	u32 mkey;
+	void *out;
+};
+
 struct mlx5_vf_migration_file {
 	struct file *filp;
 	struct mutex lock;
-	bool disabled;
+	u8 disabled:1;
+	u8 is_err:1;
 
 	struct sg_append_table table;
 	size_t total_length;
@@ -22,15 +33,42 @@ struct mlx5_vf_migration_file {
 	struct scatterlist *last_offset_sg;
 	unsigned int sg_last_entry;
 	unsigned long last_offset;
+	struct mlx5vf_pci_core_device *mvdev;
+	wait_queue_head_t poll_wait;
+	struct mlx5_async_ctx async_ctx;
+	struct mlx5vf_async_data async_data;
+};
+
+struct mlx5vf_pci_core_device {
+	struct vfio_pci_core_device core_device;
+	int vf_id;
+	u16 vhca_id;
+	u8 migrate_cap:1;
+	u8 deferred_reset:1;
+	u8 mdev_detach:1;
+	/* protect migration state */
+	struct mutex state_mutex;
+	enum vfio_device_mig_state mig_state;
+	/* protect the reset_done flow */
+	spinlock_t reset_lock;
+	struct mlx5_vf_migration_file *resuming_migf;
+	struct mlx5_vf_migration_file *saving_migf;
+	struct workqueue_struct *cb_wq;
+	struct notifier_block nb;
+	struct mlx5_core_dev *mdev;
 };
 
-int mlx5vf_cmd_suspend_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
-int mlx5vf_cmd_resume_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
-int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
+int mlx5vf_cmd_suspend_vhca(struct mlx5vf_pci_core_device *mvdev, u16 op_mod);
+int mlx5vf_cmd_resume_vhca(struct mlx5vf_pci_core_device *mvdev, u16 op_mod);
+int mlx5vf_cmd_query_vhca_migration_state(struct mlx5vf_pci_core_device *mvdev,
 					  size_t *state_size);
-int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id);
-int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
+void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev);
+void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev);
+int mlx5vf_cmd_save_vhca_state(struct mlx5vf_pci_core_device *mvdev,
 			       struct mlx5_vf_migration_file *migf);
-int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
+int mlx5vf_cmd_load_vhca_state(struct mlx5vf_pci_core_device *mvdev,
 			       struct mlx5_vf_migration_file *migf);
+void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev);
+void mlx5vf_disable_fds(struct mlx5vf_pci_core_device *mvdev);
+void mlx5vf_mig_file_cleanup_cb(struct work_struct *_work);
 #endif /* MLX5_VFIO_CMD_H */
diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
index bbec5d288fee..0558d0649ddb 100644
--- a/drivers/vfio/pci/mlx5/main.c
+++ b/drivers/vfio/pci/mlx5/main.c
@@ -17,7 +17,6 @@
 #include <linux/uaccess.h>
 #include <linux/vfio.h>
 #include <linux/sched/mm.h>
-#include <linux/vfio_pci_core.h>
 #include <linux/anon_inodes.h>
 
 #include "cmd.h"
@@ -25,19 +24,13 @@
 /* Arbitrary to prevent userspace from consuming endless memory */
 #define MAX_MIGRATION_SIZE (512*1024*1024)
 
-struct mlx5vf_pci_core_device {
-	struct vfio_pci_core_device core_device;
-	u16 vhca_id;
-	u8 migrate_cap:1;
-	u8 deferred_reset:1;
-	/* protect migration state */
-	struct mutex state_mutex;
-	enum vfio_device_mig_state mig_state;
-	/* protect the reset_done flow */
-	spinlock_t reset_lock;
-	struct mlx5_vf_migration_file *resuming_migf;
-	struct mlx5_vf_migration_file *saving_migf;
-};
+static struct mlx5vf_pci_core_device *mlx5vf_drvdata(struct pci_dev *pdev)
+{
+	struct vfio_pci_core_device *core_device = dev_get_drvdata(&pdev->dev);
+
+	return container_of(core_device, struct mlx5vf_pci_core_device,
+			    core_device);
+}
 
 static struct page *
 mlx5vf_get_migration_page(struct mlx5_vf_migration_file *migf,
@@ -149,12 +142,22 @@ static ssize_t mlx5vf_save_read(struct file *filp, char __user *buf, size_t len,
 		return -ESPIPE;
 	pos = &filp->f_pos;
 
+	if (!(filp->f_flags & O_NONBLOCK)) {
+		if (wait_event_interruptible(migf->poll_wait,
+			     READ_ONCE(migf->total_length) || migf->is_err))
+			return -ERESTARTSYS;
+	}
+
 	mutex_lock(&migf->lock);
+	if ((filp->f_flags & O_NONBLOCK) && !READ_ONCE(migf->total_length)) {
+		done = -EAGAIN;
+		goto out_unlock;
+	}
 	if (*pos > migf->total_length) {
 		done = -EINVAL;
 		goto out_unlock;
 	}
-	if (migf->disabled) {
+	if (migf->disabled || migf->is_err) {
 		done = -ENODEV;
 		goto out_unlock;
 	}
@@ -194,9 +197,28 @@ out_unlock:
 	return done;
 }
 
+static __poll_t mlx5vf_save_poll(struct file *filp,
+				 struct poll_table_struct *wait)
+{
+	struct mlx5_vf_migration_file *migf = filp->private_data;
+	__poll_t pollflags = 0;
+
+	poll_wait(filp, &migf->poll_wait, wait);
+
+	mutex_lock(&migf->lock);
+	if (migf->disabled || migf->is_err)
+		pollflags = EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
+	else if (READ_ONCE(migf->total_length))
+		pollflags = EPOLLIN | EPOLLRDNORM;
+	mutex_unlock(&migf->lock);
+
+	return pollflags;
+}
+
 static const struct file_operations mlx5vf_save_fops = {
 	.owner = THIS_MODULE,
 	.read = mlx5vf_save_read,
+	.poll = mlx5vf_save_poll,
 	.release = mlx5vf_release_file,
 	.llseek = no_llseek,
 };
@@ -222,9 +244,11 @@ mlx5vf_pci_save_device_data(struct mlx5vf_pci_core_device *mvdev)
 
 	stream_open(migf->filp->f_inode, migf->filp);
 	mutex_init(&migf->lock);
-
-	ret = mlx5vf_cmd_query_vhca_migration_state(
-		mvdev->core_device.pdev, mvdev->vhca_id, &migf->total_length);
+	init_waitqueue_head(&migf->poll_wait);
+	mlx5_cmd_init_async_ctx(mvdev->mdev, &migf->async_ctx);
+	INIT_WORK(&migf->async_data.work, mlx5vf_mig_file_cleanup_cb);
+	ret = mlx5vf_cmd_query_vhca_migration_state(mvdev,
+						    &migf->total_length);
 	if (ret)
 		goto out_free;
 
@@ -233,8 +257,8 @@ mlx5vf_pci_save_device_data(struct mlx5vf_pci_core_device *mvdev)
 	if (ret)
 		goto out_free;
 
-	ret = mlx5vf_cmd_save_vhca_state(mvdev->core_device.pdev,
-					 mvdev->vhca_id, migf);
+	migf->mvdev = mvdev;
+	ret = mlx5vf_cmd_save_vhca_state(mvdev, migf);
 	if (ret)
 		goto out_free;
 	return migf;
@@ -339,7 +363,7 @@ mlx5vf_pci_resume_device_data(struct mlx5vf_pci_core_device *mvdev)
 	return migf;
 }
 
-static void mlx5vf_disable_fds(struct mlx5vf_pci_core_device *mvdev)
+void mlx5vf_disable_fds(struct mlx5vf_pci_core_device *mvdev)
 {
 	if (mvdev->resuming_migf) {
 		mlx5vf_disable_fd(mvdev->resuming_migf);
@@ -347,6 +371,8 @@ static void mlx5vf_disable_fds(struct mlx5vf_pci_core_device *mvdev)
 		mvdev->resuming_migf = NULL;
 	}
 	if (mvdev->saving_migf) {
+		mlx5_cmd_cleanup_async_ctx(&mvdev->saving_migf->async_ctx);
+		cancel_work_sync(&mvdev->saving_migf->async_data.work);
 		mlx5vf_disable_fd(mvdev->saving_migf);
 		fput(mvdev->saving_migf->filp);
 		mvdev->saving_migf = NULL;
@@ -361,8 +387,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
 	int ret;
 
 	if (cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_STOP) {
-		ret = mlx5vf_cmd_suspend_vhca(
-			mvdev->core_device.pdev, mvdev->vhca_id,
+		ret = mlx5vf_cmd_suspend_vhca(mvdev,
 			MLX5_SUSPEND_VHCA_IN_OP_MOD_SUSPEND_RESPONDER);
 		if (ret)
 			return ERR_PTR(ret);
@@ -370,8 +395,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
 	}
 
 	if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_RUNNING_P2P) {
-		ret = mlx5vf_cmd_resume_vhca(
-			mvdev->core_device.pdev, mvdev->vhca_id,
+		ret = mlx5vf_cmd_resume_vhca(mvdev,
 			MLX5_RESUME_VHCA_IN_OP_MOD_RESUME_RESPONDER);
 		if (ret)
 			return ERR_PTR(ret);
@@ -379,8 +403,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
 	}
 
 	if (cur == VFIO_DEVICE_STATE_RUNNING && new == VFIO_DEVICE_STATE_RUNNING_P2P) {
-		ret = mlx5vf_cmd_suspend_vhca(
-			mvdev->core_device.pdev, mvdev->vhca_id,
+		ret = mlx5vf_cmd_suspend_vhca(mvdev,
 			MLX5_SUSPEND_VHCA_IN_OP_MOD_SUSPEND_INITIATOR);
 		if (ret)
 			return ERR_PTR(ret);
@@ -388,8 +411,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
 	}
 
 	if (cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_RUNNING) {
-		ret = mlx5vf_cmd_resume_vhca(
-			mvdev->core_device.pdev, mvdev->vhca_id,
+		ret = mlx5vf_cmd_resume_vhca(mvdev,
 			MLX5_RESUME_VHCA_IN_OP_MOD_RESUME_INITIATOR);
 		if (ret)
 			return ERR_PTR(ret);
@@ -424,8 +446,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
 	}
 
 	if (cur == VFIO_DEVICE_STATE_RESUMING && new == VFIO_DEVICE_STATE_STOP) {
-		ret = mlx5vf_cmd_load_vhca_state(mvdev->core_device.pdev,
-						 mvdev->vhca_id,
+		ret = mlx5vf_cmd_load_vhca_state(mvdev,
 						 mvdev->resuming_migf);
 		if (ret)
 			return ERR_PTR(ret);
@@ -444,7 +465,7 @@ mlx5vf_pci_step_device_state_locked(struct mlx5vf_pci_core_device *mvdev,
  * This function is called in all state_mutex unlock cases to
  * handle a 'deferred_reset' if exists.
  */
-static void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
+void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
 {
 again:
 	spin_lock(&mvdev->reset_lock);
@@ -505,7 +526,7 @@ static int mlx5vf_pci_get_device_state(struct vfio_device *vdev,
 
 static void mlx5vf_pci_aer_reset_done(struct pci_dev *pdev)
 {
-	struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
+	struct mlx5vf_pci_core_device *mvdev = mlx5vf_drvdata(pdev);
 
 	if (!mvdev->migrate_cap)
 		return;
@@ -532,34 +553,16 @@ static int mlx5vf_pci_open_device(struct vfio_device *core_vdev)
 	struct mlx5vf_pci_core_device *mvdev = container_of(
 		core_vdev, struct mlx5vf_pci_core_device, core_device.vdev);
 	struct vfio_pci_core_device *vdev = &mvdev->core_device;
-	int vf_id;
 	int ret;
 
 	ret = vfio_pci_core_enable(vdev);
 	if (ret)
 		return ret;
 
-	if (!mvdev->migrate_cap) {
-		vfio_pci_core_finish_enable(vdev);
-		return 0;
-	}
-
-	vf_id = pci_iov_vf_id(vdev->pdev);
-	if (vf_id < 0) {
-		ret = vf_id;
-		goto out_disable;
-	}
-
-	ret = mlx5vf_cmd_get_vhca_id(vdev->pdev, vf_id + 1, &mvdev->vhca_id);
-	if (ret)
-		goto out_disable;
-
-	mvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
+	if (mvdev->migrate_cap)
+		mvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
 	vfio_pci_core_finish_enable(vdev);
 	return 0;
-out_disable:
-	vfio_pci_core_disable(vdev);
-	return ret;
 }
 
 static void mlx5vf_pci_close_device(struct vfio_device *core_vdev)
@@ -596,32 +599,15 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
 	if (!mvdev)
 		return -ENOMEM;
 	vfio_pci_core_init_device(&mvdev->core_device, pdev, &mlx5vf_pci_ops);
-
-	if (pdev->is_virtfn) {
-		struct mlx5_core_dev *mdev =
-			mlx5_vf_get_core_dev(pdev);
-
-		if (mdev) {
-			if (MLX5_CAP_GEN(mdev, migration)) {
-				mvdev->migrate_cap = 1;
-				mvdev->core_device.vdev.migration_flags =
-					VFIO_MIGRATION_STOP_COPY |
-					VFIO_MIGRATION_P2P;
-				mutex_init(&mvdev->state_mutex);
-				spin_lock_init(&mvdev->reset_lock);
-			}
-			mlx5_vf_put_core_dev(mdev);
-		}
-	}
-
+	mlx5vf_cmd_set_migratable(mvdev);
+	dev_set_drvdata(&pdev->dev, &mvdev->core_device);
 	ret = vfio_pci_core_register_device(&mvdev->core_device);
 	if (ret)
 		goto out_free;
-
-	dev_set_drvdata(&pdev->dev, mvdev);
 	return 0;
 
 out_free:
+	mlx5vf_cmd_remove_migratable(mvdev);
 	vfio_pci_core_uninit_device(&mvdev->core_device);
 	kfree(mvdev);
 	return ret;
@@ -629,9 +615,10 @@ out_free:
 
 static void mlx5vf_pci_remove(struct pci_dev *pdev)
 {
-	struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
+	struct mlx5vf_pci_core_device *mvdev = mlx5vf_drvdata(pdev);
 
 	vfio_pci_core_unregister_device(&mvdev->core_device);
+	mlx5vf_cmd_remove_migratable(mvdev);
 	vfio_pci_core_uninit_device(&mvdev->core_device);
 	kfree(mvdev);
 }
@@ -654,6 +641,7 @@ static struct pci_driver mlx5vf_pci_driver = {
 	.probe = mlx5vf_pci_probe,
 	.remove = mlx5vf_pci_remove,
 	.err_handler = &mlx5vf_err_handlers,
+	.driver_managed_dma = true,
 };
 
 static void __exit mlx5vf_pci_cleanup(void)
diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 58839206d1ca..4d1a97415a27 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -151,10 +151,10 @@ static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 		return -ENOMEM;
 	vfio_pci_core_init_device(vdev, pdev, &vfio_pci_ops);
 
+	dev_set_drvdata(&pdev->dev, vdev);
 	ret = vfio_pci_core_register_device(vdev);
 	if (ret)
 		goto out_free;
-	dev_set_drvdata(&pdev->dev, vdev);
 	return 0;
 
 out_free:
@@ -174,10 +174,12 @@ static void vfio_pci_remove(struct pci_dev *pdev)
 
 static int vfio_pci_sriov_configure(struct pci_dev *pdev, int nr_virtfn)
 {
+	struct vfio_pci_core_device *vdev = dev_get_drvdata(&pdev->dev);
+
 	if (!enable_sriov)
 		return -ENOENT;
 
-	return vfio_pci_core_sriov_configure(pdev, nr_virtfn);
+	return vfio_pci_core_sriov_configure(vdev, nr_virtfn);
 }
 
 static const struct pci_device_id vfio_pci_table[] = {
diff --git a/drivers/vfio/pci/vfio_pci_config.c b/drivers/vfio/pci/vfio_pci_config.c
index 6e58b4bf7a60..9343f597182d 100644
--- a/drivers/vfio/pci/vfio_pci_config.c
+++ b/drivers/vfio/pci/vfio_pci_config.c
@@ -402,11 +402,14 @@ bool __vfio_pci_memory_enabled(struct vfio_pci_core_device *vdev)
 	u16 cmd = le16_to_cpu(*(__le16 *)&vdev->vconfig[PCI_COMMAND]);
 
 	/*
+	 * The device's memory cannot be accessed if its power state is D3.
+	 *
 	 * SR-IOV VF memory enable is handled by the MSE bit in the
 	 * PF SR-IOV capability, there's therefore no need to trigger
 	 * faults based on the virtual value.
 	 */
-	return pdev->no_command_memory || (cmd & PCI_COMMAND_MEMORY);
+	return pdev->current_state < PCI_D3hot &&
+	       (pdev->no_command_memory || (cmd & PCI_COMMAND_MEMORY));
 }
 
 /*
@@ -692,6 +695,22 @@ static int __init init_pci_cap_basic_perm(struct perm_bits *perm)
 	return 0;
 }
 
+/*
+ * This helper takes all the locks required to protect access to the
+ * power-related variables and then invokes vfio_pci_set_power_state().
+ */
+static void vfio_lock_and_set_power_state(struct vfio_pci_core_device *vdev,
+					  pci_power_t state)
+{
+	if (state >= PCI_D3hot)
+		vfio_pci_zap_and_down_write_memory_lock(vdev);
+	else
+		down_write(&vdev->memory_lock);
+
+	vfio_pci_set_power_state(vdev, state);
+	up_write(&vdev->memory_lock);
+}
+
 static int vfio_pm_config_write(struct vfio_pci_core_device *vdev, int pos,
 				int count, struct perm_bits *perm,
 				int offset, __le32 val)
@@ -718,7 +737,7 @@ static int vfio_pm_config_write(struct vfio_pci_core_device *vdev, int pos,
 			break;
 		}
 
-		vfio_pci_set_power_state(vdev, state);
+		vfio_lock_and_set_power_state(vdev, state);
 	}
 
 	return count;
@@ -739,11 +758,28 @@ static int __init init_pci_cap_pm_perm(struct perm_bits *perm)
 	p_setb(perm, PCI_CAP_LIST_NEXT, (u8)ALL_VIRT, NO_WRITE);
 
 	/*
+	 * The guest can't process PME events. If a PME event is generated,
+	 * it will mostly be handled in the host and the host will clear
+	 * PME_STATUS. So virtualize the PME_Support bits.
+	 * The vconfig bits will be cleared during device capability
+	 * initialization.
+	 */
+	p_setw(perm, PCI_PM_PMC, PCI_PM_CAP_PME_MASK, NO_WRITE);
+
+	/*
 	 * Power management is defined *per function*, so we can let
 	 * the user change power state, but we trap and initiate the
 	 * change ourselves, so the state bits are read-only.
+	 *
+	 * The guest can't process PME from D3cold, so virtualize PME_Status
+	 * and PME_En bits. The vconfig bits will be cleared during device
+	 * capability initialization.
 	 */
-	p_setd(perm, PCI_PM_CTRL, NO_VIRT, ~PCI_PM_CTRL_STATE_MASK);
+	p_setd(perm, PCI_PM_CTRL,
+	       PCI_PM_CTRL_PME_ENABLE | PCI_PM_CTRL_PME_STATUS,
+	       ~(PCI_PM_CTRL_PME_ENABLE | PCI_PM_CTRL_PME_STATUS |
+		 PCI_PM_CTRL_STATE_MASK));
+
 	return 0;
 }
 
@@ -1412,6 +1448,17 @@ static int vfio_ext_cap_len(struct vfio_pci_core_device *vdev, u16 ecap, u16 epo
 	return 0;
 }
 
+static void vfio_update_pm_vconfig_bytes(struct vfio_pci_core_device *vdev,
+					 int offset)
+{
+	__le16 *pmc = (__le16 *)&vdev->vconfig[offset + PCI_PM_PMC];
+	__le16 *ctrl = (__le16 *)&vdev->vconfig[offset + PCI_PM_CTRL];
+
+	/* Clear vconfig PME_Support, PME_Status, and PME_En bits */
+	*pmc &= ~cpu_to_le16(PCI_PM_CAP_PME_MASK);
+	*ctrl &= ~cpu_to_le16(PCI_PM_CTRL_PME_ENABLE | PCI_PM_CTRL_PME_STATUS);
+}
+
 static int vfio_fill_vconfig_bytes(struct vfio_pci_core_device *vdev,
 				   int offset, int size)
 {
@@ -1535,6 +1582,9 @@ static int vfio_cap_init(struct vfio_pci_core_device *vdev)
 		if (ret)
 			return ret;
 
+		if (cap == PCI_CAP_ID_PM)
+			vfio_update_pm_vconfig_bytes(vdev, pos);
+
 		prev = &vdev->vconfig[pos + PCI_CAP_LIST_NEXT];
 		pos = next;
 		caps++;
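
Taken together, the PM changes above leave the guest with a power-management capability that advertises no PME support, with PME_Status/PME_En always reading back as zero, while PCI_PM_CTRL power-state writes stay trapped and routed through vfio_lock_and_set_power_state(). A rough userspace illustration of the visible effect, assuming a device FD and the PM capability offset (pm_cap_pos) were already found by walking the config region obtained via VFIO_DEVICE_GET_REGION_INFO on VFIO_PCI_CONFIG_REGION_INDEX; both names are placeholders:

	__u16 pmc, ctrl;

	pread(device_fd, &pmc, sizeof(pmc),
	      cfg_offset + pm_cap_pos + PCI_PM_PMC);
	pread(device_fd, &ctrl, sizeof(ctrl),
	      cfg_offset + pm_cap_pos + PCI_PM_CTRL);

	/* PME_Support is virtualized and cleared at capability init */
	assert(!(pmc & PCI_PM_CAP_PME_MASK));
	/* PME_Status and PME_En likewise read as zero */
	assert(!(ctrl & (PCI_PM_CTRL_PME_STATUS | PCI_PM_CTRL_PME_ENABLE)));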
diff --git a/drivers/vfio/pci/vfio_pci_core.c b/drivers/vfio/pci/vfio_pci_core.c
index 06b6f3594a13..a0d69ddaf90d 100644
--- a/drivers/vfio/pci/vfio_pci_core.c
+++ b/drivers/vfio/pci/vfio_pci_core.c
@@ -156,7 +156,7 @@ no_mmap:
 }
 
 struct vfio_pci_group_info;
-static bool vfio_pci_dev_set_try_reset(struct vfio_device_set *dev_set);
+static void vfio_pci_dev_set_try_reset(struct vfio_device_set *dev_set);
 static int vfio_pci_dev_set_hot_reset(struct vfio_device_set *dev_set,
 				      struct vfio_pci_group_info *groups);
 
@@ -217,6 +217,10 @@ int vfio_pci_set_power_state(struct vfio_pci_core_device *vdev, pci_power_t stat
 	bool needs_restore = false, needs_save = false;
 	int ret;
 
+	/* Prevent changing power state for PFs with VFs enabled */
+	if (pci_num_vf(pdev) && state > PCI_D0)
+		return -EBUSY;
+
 	if (vdev->needs_pm_restore) {
 		if (pdev->current_state < PCI_D3hot && state >= PCI_D3hot) {
 			pci_save_state(pdev);
@@ -255,6 +259,17 @@ int vfio_pci_set_power_state(struct vfio_pci_core_device *vdev, pci_power_t stat
 	return ret;
 }
 
+/*
+ * A dev_pm_ops needs to be provided for pci-driver runtime PM to work, so
+ * use a structure without any callbacks.
+ *
+ * The pci-driver core runtime PM routines always save the device state
+ * before going into the suspended state. If the device goes into a low power
+ * state with only the runtime PM ops, then no explicit handling is needed
+ * for devices which have NoSoftRst-.
+ */
+static const struct dev_pm_ops vfio_pci_core_pm_ops = { };
+
 int vfio_pci_core_enable(struct vfio_pci_core_device *vdev)
 {
 	struct pci_dev *pdev = vdev->pdev;
@@ -262,21 +277,23 @@ int vfio_pci_core_enable(struct vfio_pci_core_device *vdev)
 	u16 cmd;
 	u8 msix_pos;
 
-	vfio_pci_set_power_state(vdev, PCI_D0);
+	if (!disable_idle_d3) {
+		ret = pm_runtime_resume_and_get(&pdev->dev);
+		if (ret < 0)
+			return ret;
+	}
 
 	/* Don't allow our initial saved state to include busmaster */
 	pci_clear_master(pdev);
 
 	ret = pci_enable_device(pdev);
 	if (ret)
-		return ret;
+		goto out_power;
 
 	/* If reset fails because of the device lock, fail this path entirely */
 	ret = pci_try_reset_function(pdev);
-	if (ret == -EAGAIN) {
-		pci_disable_device(pdev);
-		return ret;
-	}
+	if (ret == -EAGAIN)
+		goto out_disable_device;
 
 	vdev->reset_works = !ret;
 	pci_save_state(pdev);
@@ -300,12 +317,8 @@ int vfio_pci_core_enable(struct vfio_pci_core_device *vdev)
 	}
 
 	ret = vfio_config_init(vdev);
-	if (ret) {
-		kfree(vdev->pci_saved_state);
-		vdev->pci_saved_state = NULL;
-		pci_disable_device(pdev);
-		return ret;
-	}
+	if (ret)
+		goto out_free_state;
 
 	msix_pos = pdev->msix_cap;
 	if (msix_pos) {
@@ -326,6 +339,16 @@ int vfio_pci_core_enable(struct vfio_pci_core_device *vdev)
 
 
 	return 0;
+
+out_free_state:
+	kfree(vdev->pci_saved_state);
+	vdev->pci_saved_state = NULL;
+out_disable_device:
+	pci_disable_device(pdev);
+out_power:
+	if (!disable_idle_d3)
+		pm_runtime_put(&pdev->dev);
+	return ret;
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_enable);
 
@@ -433,8 +456,11 @@ void vfio_pci_core_disable(struct vfio_pci_core_device *vdev)
 out:
 	pci_disable_device(pdev);
 
-	if (!vfio_pci_dev_set_try_reset(vdev->vdev.dev_set) && !disable_idle_d3)
-		vfio_pci_set_power_state(vdev, PCI_D3hot);
+	vfio_pci_dev_set_try_reset(vdev->vdev.dev_set);
+
+	/* Put the pm-runtime usage counter acquired during enable */
+	if (!disable_idle_d3)
+		pm_runtime_put(&pdev->dev);
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_disable);
 
@@ -556,7 +582,7 @@ static int vfio_pci_fill_devs(struct pci_dev *pdev, void *data)
 
 struct vfio_pci_group_info {
 	int count;
-	struct vfio_group **groups;
+	struct file **files;
 };
 
 static bool vfio_pci_dev_below_slot(struct pci_dev *pdev, struct pci_slot *slot)
@@ -1018,10 +1044,10 @@ reset_info_exit:
 	} else if (cmd == VFIO_DEVICE_PCI_HOT_RESET) {
 		struct vfio_pci_hot_reset hdr;
 		int32_t *group_fds;
-		struct vfio_group **groups;
+		struct file **files;
 		struct vfio_pci_group_info info;
 		bool slot = false;
-		int group_idx, count = 0, ret = 0;
+		int file_idx, count = 0, ret = 0;
 
 		minsz = offsetofend(struct vfio_pci_hot_reset, count);
 
@@ -1054,17 +1080,17 @@ reset_info_exit:
 			return -EINVAL;
 
 		group_fds = kcalloc(hdr.count, sizeof(*group_fds), GFP_KERNEL);
-		groups = kcalloc(hdr.count, sizeof(*groups), GFP_KERNEL);
-		if (!group_fds || !groups) {
+		files = kcalloc(hdr.count, sizeof(*files), GFP_KERNEL);
+		if (!group_fds || !files) {
 			kfree(group_fds);
-			kfree(groups);
+			kfree(files);
 			return -ENOMEM;
 		}
 
 		if (copy_from_user(group_fds, (void __user *)(arg + minsz),
 				   hdr.count * sizeof(*group_fds))) {
 			kfree(group_fds);
-			kfree(groups);
+			kfree(files);
 			return -EFAULT;
 		}
 
@@ -1073,22 +1099,22 @@ reset_info_exit:
 		 * user interface and store the group and iommu ID.  This
 		 * ensures the group is held across the reset.
 		 */
-		for (group_idx = 0; group_idx < hdr.count; group_idx++) {
-			struct vfio_group *group;
-			struct fd f = fdget(group_fds[group_idx]);
-			if (!f.file) {
+		for (file_idx = 0; file_idx < hdr.count; file_idx++) {
+			struct file *file = fget(group_fds[file_idx]);
+
+			if (!file) {
 				ret = -EBADF;
 				break;
 			}
 
-			group = vfio_group_get_external_user(f.file);
-			fdput(f);
-			if (IS_ERR(group)) {
-				ret = PTR_ERR(group);
+			/* Ensure the FD is a vfio group FD. */
+			if (!vfio_file_iommu_group(file)) {
+				fput(file);
+				ret = -EINVAL;
 				break;
 			}
 
-			groups[group_idx] = group;
+			files[file_idx] = file;
 		}
 
 		kfree(group_fds);
@@ -1098,15 +1124,15 @@ reset_info_exit:
 			goto hot_reset_release;
 
 		info.count = hdr.count;
-		info.groups = groups;
+		info.files = files;
 
 		ret = vfio_pci_dev_set_hot_reset(vdev->vdev.dev_set, &info);
 
 hot_reset_release:
-		for (group_idx--; group_idx >= 0; group_idx--)
-			vfio_group_put_external_user(groups[group_idx]);
+		for (file_idx--; file_idx >= 0; file_idx--)
+			fput(files[file_idx]);
 
-		kfree(groups);
+		kfree(files);
 		return ret;
 	} else if (cmd == VFIO_DEVICE_IOEVENTFD) {
 		struct vfio_device_ioeventfd ioeventfd;
@@ -1819,8 +1845,13 @@ EXPORT_SYMBOL_GPL(vfio_pci_core_uninit_device);
 int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
 {
 	struct pci_dev *pdev = vdev->pdev;
+	struct device *dev = &pdev->dev;
 	int ret;
 
+	/* Drivers must set the vfio_pci_core_device as their drvdata */
+	if (WARN_ON(vdev != dev_get_drvdata(dev)))
+		return -EINVAL;
+
 	if (pdev->hdr_type != PCI_HEADER_TYPE_NORMAL)
 		return -EINVAL;
 
@@ -1860,19 +1891,21 @@ int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
 
 	vfio_pci_probe_power_state(vdev);
 
-	if (!disable_idle_d3) {
-		/*
-		 * pci-core sets the device power state to an unknown value at
-		 * bootup and after being removed from a driver.  The only
-		 * transition it allows from this unknown state is to D0, which
-		 * typically happens when a driver calls pci_enable_device().
-		 * We're not ready to enable the device yet, but we do want to
-		 * be able to get to D3.  Therefore first do a D0 transition
-		 * before going to D3.
-		 */
-		vfio_pci_set_power_state(vdev, PCI_D0);
-		vfio_pci_set_power_state(vdev, PCI_D3hot);
-	}
+	/*
+	 * pci-core sets the device power state to an unknown value at
+	 * bootup and after being removed from a driver.  The only
+	 * transition it allows from this unknown state is to D0, which
+	 * typically happens when a driver calls pci_enable_device().
+	 * We're not ready to enable the device yet, but we do want to
+	 * be able to get to D3.  Therefore first do a D0 transition
+	 * before enabling runtime PM.
+	 */
+	vfio_pci_set_power_state(vdev, PCI_D0);
+
+	dev->driver->pm = &vfio_pci_core_pm_ops;
+	pm_runtime_allow(dev);
+	if (!disable_idle_d3)
+		pm_runtime_put(dev);
 
 	ret = vfio_register_group_dev(&vdev->vdev);
 	if (ret)
@@ -1881,7 +1914,9 @@ int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
 
 out_power:
 	if (!disable_idle_d3)
-		vfio_pci_set_power_state(vdev, PCI_D0);
+		pm_runtime_get_noresume(dev);
+
+	pm_runtime_forbid(dev);
 out_vf:
 	vfio_pci_vf_uninit(vdev);
 	return ret;
@@ -1890,9 +1925,7 @@ EXPORT_SYMBOL_GPL(vfio_pci_core_register_device);
 
 void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev)
 {
-	struct pci_dev *pdev = vdev->pdev;
-
-	vfio_pci_core_sriov_configure(pdev, 0);
+	vfio_pci_core_sriov_configure(vdev, 0);
 
 	vfio_unregister_group_dev(&vdev->vdev);
 
@@ -1900,21 +1933,16 @@ void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev)
 	vfio_pci_vga_uninit(vdev);
 
 	if (!disable_idle_d3)
-		vfio_pci_set_power_state(vdev, PCI_D0);
+		pm_runtime_get_noresume(&vdev->pdev->dev);
+
+	pm_runtime_forbid(&vdev->pdev->dev);
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_unregister_device);
 
 pci_ers_result_t vfio_pci_core_aer_err_detected(struct pci_dev *pdev,
 						pci_channel_state_t state)
 {
-	struct vfio_pci_core_device *vdev;
-	struct vfio_device *device;
-
-	device = vfio_device_get_from_dev(&pdev->dev);
-	if (device == NULL)
-		return PCI_ERS_RESULT_DISCONNECT;
-
-	vdev = container_of(device, struct vfio_pci_core_device, vdev);
+	struct vfio_pci_core_device *vdev = dev_get_drvdata(&pdev->dev);
 
 	mutex_lock(&vdev->igate);
 
@@ -1923,26 +1951,18 @@ pci_ers_result_t vfio_pci_core_aer_err_detected(struct pci_dev *pdev,
 
 	mutex_unlock(&vdev->igate);
 
-	vfio_device_put(device);
-
 	return PCI_ERS_RESULT_CAN_RECOVER;
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_aer_err_detected);
 
-int vfio_pci_core_sriov_configure(struct pci_dev *pdev, int nr_virtfn)
+int vfio_pci_core_sriov_configure(struct vfio_pci_core_device *vdev,
+				  int nr_virtfn)
 {
-	struct vfio_pci_core_device *vdev;
-	struct vfio_device *device;
+	struct pci_dev *pdev = vdev->pdev;
 	int ret = 0;
 
 	device_lock_assert(&pdev->dev);
 
-	device = vfio_device_get_from_dev(&pdev->dev);
-	if (!device)
-		return -ENODEV;
-
-	vdev = container_of(device, struct vfio_pci_core_device, vdev);
-
 	if (nr_virtfn) {
 		mutex_lock(&vfio_pci_sriov_pfs_mutex);
 		/*
@@ -1957,22 +1977,42 @@ int vfio_pci_core_sriov_configure(struct pci_dev *pdev, int nr_virtfn)
 		}
 		list_add_tail(&vdev->sriov_pfs_item, &vfio_pci_sriov_pfs);
 		mutex_unlock(&vfio_pci_sriov_pfs_mutex);
-		ret = pci_enable_sriov(pdev, nr_virtfn);
+
+		/*
+		 * The PF power state should always be higher than the VF power
+		 * state. The PF can be in low power state either with runtime
+		 * power management (when there is no user) or PCI_PM_CTRL
+		 * register write by the user. If PF is in the low power state,
+		 * then change the power state to D0 first before enabling
+		 * then change the power state to D0 before enabling SR-IOV.
+		 * Also, this function can be called at any time, and a
+		 * userspace PCI_PM_CTRL write can race against this code path,
+		 * so protect it with 'memory_lock'.
+		ret = pm_runtime_resume_and_get(&pdev->dev);
 		if (ret)
 			goto out_del;
-		ret = nr_virtfn;
-		goto out_put;
+
+		down_write(&vdev->memory_lock);
+		vfio_pci_set_power_state(vdev, PCI_D0);
+		ret = pci_enable_sriov(pdev, nr_virtfn);
+		up_write(&vdev->memory_lock);
+		if (ret) {
+			pm_runtime_put(&pdev->dev);
+			goto out_del;
+		}
+		return nr_virtfn;
 	}
 
-	pci_disable_sriov(pdev);
+	if (pci_num_vf(pdev)) {
+		pci_disable_sriov(pdev);
+		pm_runtime_put(&pdev->dev);
+	}
 
 out_del:
 	mutex_lock(&vfio_pci_sriov_pfs_mutex);
 	list_del_init(&vdev->sriov_pfs_item);
 out_unlock:
 	mutex_unlock(&vfio_pci_sriov_pfs_mutex);
-out_put:
-	vfio_device_put(device);
 	return ret;
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_sriov_configure);
@@ -1988,7 +2028,7 @@ static bool vfio_dev_in_groups(struct vfio_pci_core_device *vdev,
 	unsigned int i;
 
 	for (i = 0; i < groups->count; i++)
-		if (groups->groups[i] == vdev->vdev.group)
+		if (vfio_file_has_dev(groups->files[i], &vdev->vdev))
 			return true;
 	return false;
 }
@@ -2041,6 +2081,27 @@ vfio_pci_dev_set_resettable(struct vfio_device_set *dev_set)
 	return pdev;
 }
 
+static int vfio_pci_dev_set_pm_runtime_get(struct vfio_device_set *dev_set)
+{
+	struct vfio_pci_core_device *cur;
+	int ret;
+
+	list_for_each_entry(cur, &dev_set->device_list, vdev.dev_set_list) {
+		ret = pm_runtime_resume_and_get(&cur->pdev->dev);
+		if (ret)
+			goto unwind;
+	}
+
+	return 0;
+
+unwind:
+	list_for_each_entry_continue_reverse(cur, &dev_set->device_list,
+					     vdev.dev_set_list)
+		pm_runtime_put(&cur->pdev->dev);
+
+	return ret;
+}
+
 /*
  * We need to get memory_lock for each device, but devices can share mmap_lock,
  * therefore we need to zap and hold the vma_lock for each device, and only then
@@ -2147,43 +2208,38 @@ static bool vfio_pci_dev_set_needs_reset(struct vfio_device_set *dev_set)
  *  - At least one of the affected devices is marked dirty via
  *    needs_reset (such as by lack of FLR support)
  * Then attempt to perform that bus or slot reset.
- * Returns true if the dev_set was reset.
  */
-static bool vfio_pci_dev_set_try_reset(struct vfio_device_set *dev_set)
+static void vfio_pci_dev_set_try_reset(struct vfio_device_set *dev_set)
 {
 	struct vfio_pci_core_device *cur;
 	struct pci_dev *pdev;
-	int ret;
+	bool reset_done = false;
 
 	if (!vfio_pci_dev_set_needs_reset(dev_set))
-		return false;
+		return;
 
 	pdev = vfio_pci_dev_set_resettable(dev_set);
 	if (!pdev)
-		return false;
+		return;
 
 	/*
-	 * The pci_reset_bus() will reset all the devices in the bus.
-	 * The power state can be non-D0 for some of the devices in the bus.
-	 * For these devices, the pci_reset_bus() will internally set
-	 * the power state to D0 without vfio driver involvement.
-	 * For the devices which have NoSoftRst-, the reset function can
-	 * cause the PCI config space reset without restoring the original
-	 * state (saved locally in 'vdev->pm_save').
+	 * Some of the devices on the bus can be in the runtime suspended
+	 * state. Increment the usage count for all the devices in the dev_set
+	 * before the reset and decrement it again afterwards.
 	 */
-	list_for_each_entry(cur, &dev_set->device_list, vdev.dev_set_list)
-		vfio_pci_set_power_state(cur, PCI_D0);
+	if (!disable_idle_d3 && vfio_pci_dev_set_pm_runtime_get(dev_set))
+		return;
 
-	ret = pci_reset_bus(pdev);
-	if (ret)
-		return false;
+	if (!pci_reset_bus(pdev))
+		reset_done = true;
 
 	list_for_each_entry(cur, &dev_set->device_list, vdev.dev_set_list) {
-		cur->needs_reset = false;
+		if (reset_done)
+			cur->needs_reset = false;
+
 		if (!disable_idle_d3)
-			vfio_pci_set_power_state(cur, PCI_D3hot);
+			pm_runtime_put(&cur->pdev->dev);
 	}
-	return true;
 }
 
 void vfio_pci_core_set_params(bool is_nointxmask, bool is_disable_vga,
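
The runtime PM conversion above hinges on a strict pairing of usage-count operations, which is easiest to see written out in one place. This is only a summary of the calls already present in this file, not new driver code:

/*
 * Runtime PM reference flow per vfio-pci core device, assuming
 * disable_idle_d3 == false:
 *
 *   register:    vfio_pci_set_power_state(D0); pm_runtime_allow();
 *                pm_runtime_put();             usage 0 -> device may suspend
 *   open:        pm_runtime_resume_and_get();  usage 1 -> held in D0 while open
 *   close/error: pm_runtime_put();             usage 0 -> may suspend again
 *   bus reset:   pm_runtime_resume_and_get() on every dev_set member,
 *                pm_runtime_put() on each after pci_reset_bus()
 *   unregister:  pm_runtime_get_noresume(); pm_runtime_forbid();
 */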
diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c
index 6a7fb8d91aaa..61e71c1154be 100644
--- a/drivers/vfio/vfio.c
+++ b/drivers/vfio/vfio.c
@@ -66,18 +66,18 @@ struct vfio_group {
 	struct device 			dev;
 	struct cdev			cdev;
 	refcount_t			users;
-	atomic_t			container_users;
+	unsigned int			container_users;
 	struct iommu_group		*iommu_group;
 	struct vfio_container		*container;
 	struct list_head		device_list;
 	struct mutex			device_lock;
 	struct list_head		vfio_next;
 	struct list_head		container_next;
-	atomic_t			opened;
-	wait_queue_head_t		container_q;
 	enum vfio_group_type		type;
 	unsigned int			dev_counter;
+	struct rw_semaphore		group_rwsem;
 	struct kvm			*kvm;
+	struct file			*opened_file;
 	struct blocking_notifier_head	notifier;
 };
 
@@ -361,9 +361,9 @@ static struct vfio_group *vfio_group_alloc(struct iommu_group *iommu_group,
 	group->cdev.owner = THIS_MODULE;
 
 	refcount_set(&group->users, 1);
+	init_rwsem(&group->group_rwsem);
 	INIT_LIST_HEAD(&group->device_list);
 	mutex_init(&group->device_lock);
-	init_waitqueue_head(&group->container_q);
 	group->iommu_group = iommu_group;
 	/* put in vfio_group_release() */
 	iommu_group_ref_get(iommu_group);
@@ -429,7 +429,7 @@ static void vfio_group_put(struct vfio_group *group)
 	 * properly hold the group reference.
 	 */
 	WARN_ON(!list_empty(&group->device_list));
-	WARN_ON(atomic_read(&group->container_users));
+	WARN_ON(group->container || group->container_users);
 	WARN_ON(group->notifier.head);
 
 	list_del(&group->vfio_next);
@@ -444,31 +444,15 @@ static void vfio_group_get(struct vfio_group *group)
 	refcount_inc(&group->users);
 }
 
-static struct vfio_group *vfio_group_get_from_dev(struct device *dev)
-{
-	struct iommu_group *iommu_group;
-	struct vfio_group *group;
-
-	iommu_group = iommu_group_get(dev);
-	if (!iommu_group)
-		return NULL;
-
-	group = vfio_group_get_from_iommu(iommu_group);
-	iommu_group_put(iommu_group);
-
-	return group;
-}
-
 /*
  * Device objects - create, release, get, put, search
  */
 /* Device reference always implies a group reference */
-void vfio_device_put(struct vfio_device *device)
+static void vfio_device_put(struct vfio_device *device)
 {
 	if (refcount_dec_and_test(&device->refcount))
 		complete(&device->comp);
 }
-EXPORT_SYMBOL_GPL(vfio_device_put);
 
 static bool vfio_device_try_get(struct vfio_device *device)
 {
@@ -547,11 +531,11 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
 
 	iommu_group = iommu_group_get(dev);
 #ifdef CONFIG_VFIO_NOIOMMU
-	if (!iommu_group && noiommu && !iommu_present(dev->bus)) {
+	if (!iommu_group && noiommu) {
 		/*
 		 * With noiommu enabled, create an IOMMU group for devices that
-		 * don't already have one and don't have an iommu_ops on their
-		 * bus.  Taint the kernel because we're about to give a DMA
+		 * don't already have one, implying no IOMMU hardware/driver
+		 * exists.  Taint the kernel because we're about to give a DMA
 		 * capable device to a user without IOMMU protection.
 		 */
 		group = vfio_noiommu_group_alloc(dev, VFIO_NO_IOMMU);
@@ -640,29 +624,6 @@ int vfio_register_emulated_iommu_dev(struct vfio_device *device)
 }
 EXPORT_SYMBOL_GPL(vfio_register_emulated_iommu_dev);
 
-/*
- * Get a reference to the vfio_device for a device.  Even if the
- * caller thinks they own the device, they could be racing with a
- * release call path, so we can't trust drvdata for the shortcut.
- * Go the long way around, from the iommu_group to the vfio_group
- * to the vfio_device.
- */
-struct vfio_device *vfio_device_get_from_dev(struct device *dev)
-{
-	struct vfio_group *group;
-	struct vfio_device *device;
-
-	group = vfio_group_get_from_dev(dev);
-	if (!group)
-		return NULL;
-
-	device = vfio_group_get_device(group, dev);
-	vfio_group_put(group);
-
-	return device;
-}
-EXPORT_SYMBOL_GPL(vfio_device_get_from_dev);
-
 static struct vfio_device *vfio_device_get_from_name(struct vfio_group *group,
 						     char *buf)
 {
@@ -730,23 +691,6 @@ void vfio_unregister_group_dev(struct vfio_device *device)
 	group->dev_counter--;
 	mutex_unlock(&group->device_lock);
 
-	/*
-	 * In order to support multiple devices per group, devices can be
-	 * plucked from the group while other devices in the group are still
-	 * in use.  The container persists with this group and those remaining
-	 * devices still attached.  If the user creates an isolation violation
-	 * by binding this device to another driver while the group is still in
-	 * use, that's their fault.  However, in the case of removing the last,
-	 * or potentially the only, device in the group there can be no other
-	 * in-use devices in the group.  The user has done their due diligence
-	 * and we should lay no claims to those devices.  In order to do that,
-	 * we need to make sure the group is detached from the container.
-	 * Without this stall, we're potentially racing with a user process
-	 * that may attempt to immediately bind this device to another driver.
-	 */
-	if (list_empty(&group->device_list))
-		wait_event(group->container_q, !group->container);
-
 	if (group->type == VFIO_NO_IOMMU || group->type == VFIO_EMULATED_IOMMU)
 		iommu_group_remove_device(device->dev);
 
@@ -981,6 +925,8 @@ static void __vfio_group_unset_container(struct vfio_group *group)
 	struct vfio_container *container = group->container;
 	struct vfio_iommu_driver *driver;
 
+	lockdep_assert_held_write(&group->group_rwsem);
+
 	down_write(&container->group_lock);
 
 	driver = container->iommu_driver;
@@ -988,10 +934,11 @@ static void __vfio_group_unset_container(struct vfio_group *group)
 		driver->ops->detach_group(container->iommu_data,
 					  group->iommu_group);
 
-	iommu_group_release_dma_owner(group->iommu_group);
+	if (group->type == VFIO_IOMMU)
+		iommu_group_release_dma_owner(group->iommu_group);
 
 	group->container = NULL;
-	wake_up(&group->container_q);
+	group->container_users = 0;
 	list_del(&group->container_next);
 
 	/* Detaching the last group deprivileges a container, remove iommu */
@@ -1015,30 +962,16 @@ static void __vfio_group_unset_container(struct vfio_group *group)
  */
 static int vfio_group_unset_container(struct vfio_group *group)
 {
-	int users = atomic_cmpxchg(&group->container_users, 1, 0);
+	lockdep_assert_held_write(&group->group_rwsem);
 
-	if (!users)
+	if (!group->container)
 		return -EINVAL;
-	if (users != 1)
+	if (group->container_users != 1)
 		return -EBUSY;
-
 	__vfio_group_unset_container(group);
-
 	return 0;
 }
 
-/*
- * When removing container users, anything that removes the last user
- * implicitly removes the group from the container.  That is, if the
- * group file descriptor is closed, as well as any device file descriptors,
- * the group is free.
- */
-static void vfio_group_try_dissolve_container(struct vfio_group *group)
-{
-	if (0 == atomic_dec_if_positive(&group->container_users))
-		__vfio_group_unset_container(group);
-}
-
 static int vfio_group_set_container(struct vfio_group *group, int container_fd)
 {
 	struct fd f;
@@ -1046,7 +979,9 @@ static int vfio_group_set_container(struct vfio_group *group, int container_fd)
 	struct vfio_iommu_driver *driver;
 	int ret = 0;
 
-	if (atomic_read(&group->container_users))
+	lockdep_assert_held_write(&group->group_rwsem);
+
+	if (group->container || WARN_ON(group->container_users))
 		return -EINVAL;
 
 	if (group->type == VFIO_NO_IOMMU && !capable(CAP_SYS_RAWIO))
@@ -1074,9 +1009,11 @@ static int vfio_group_set_container(struct vfio_group *group, int container_fd)
 		goto unlock_out;
 	}
 
-	ret = iommu_group_claim_dma_owner(group->iommu_group, f.file);
-	if (ret)
-		goto unlock_out;
+	if (group->type == VFIO_IOMMU) {
+		ret = iommu_group_claim_dma_owner(group->iommu_group, f.file);
+		if (ret)
+			goto unlock_out;
+	}
 
 	driver = container->iommu_driver;
 	if (driver) {
@@ -1084,18 +1021,20 @@ static int vfio_group_set_container(struct vfio_group *group, int container_fd)
 						group->iommu_group,
 						group->type);
 		if (ret) {
-			iommu_group_release_dma_owner(group->iommu_group);
+			if (group->type == VFIO_IOMMU)
+				iommu_group_release_dma_owner(
+					group->iommu_group);
 			goto unlock_out;
 		}
 	}
 
 	group->container = container;
+	group->container_users = 1;
 	container->noiommu = (group->type == VFIO_NO_IOMMU);
 	list_add(&group->container_next, &container->group_list);
 
 	/* Get a reference on the container and mark a user within the group */
 	vfio_container_get(container);
-	atomic_inc(&group->container_users);
 
 unlock_out:
 	up_write(&container->group_lock);
@@ -1103,54 +1042,74 @@ unlock_out:
 	return ret;
 }
 
-static int vfio_group_add_container_user(struct vfio_group *group)
+static const struct file_operations vfio_device_fops;
+
+/* true if the vfio_device has open_device() called but not close_device() */
+static bool vfio_assert_device_open(struct vfio_device *device)
 {
-	if (!atomic_inc_not_zero(&group->container_users))
+	return !WARN_ON_ONCE(!READ_ONCE(device->open_count));
+}
+
+static int vfio_device_assign_container(struct vfio_device *device)
+{
+	struct vfio_group *group = device->group;
+
+	lockdep_assert_held_write(&group->group_rwsem);
+
+	if (!group->container || !group->container->iommu_driver ||
+	    WARN_ON(!group->container_users))
 		return -EINVAL;
 
-	if (group->type == VFIO_NO_IOMMU) {
-		atomic_dec(&group->container_users);
+	if (group->type == VFIO_NO_IOMMU && !capable(CAP_SYS_RAWIO))
 		return -EPERM;
-	}
-	if (!group->container->iommu_driver) {
-		atomic_dec(&group->container_users);
-		return -EINVAL;
-	}
 
+	get_file(group->opened_file);
+	group->container_users++;
 	return 0;
 }
 
-static const struct file_operations vfio_device_fops;
+static void vfio_device_unassign_container(struct vfio_device *device)
+{
+	down_write(&device->group->group_rwsem);
+	WARN_ON(device->group->container_users <= 1);
+	device->group->container_users--;
+	fput(device->group->opened_file);
+	up_write(&device->group->group_rwsem);
+}
 
-static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
+static struct file *vfio_device_open(struct vfio_device *device)
 {
-	struct vfio_device *device;
 	struct file *filep;
-	int fdno;
-	int ret = 0;
-
-	if (0 == atomic_read(&group->container_users) ||
-	    !group->container->iommu_driver)
-		return -EINVAL;
-
-	if (group->type == VFIO_NO_IOMMU && !capable(CAP_SYS_RAWIO))
-		return -EPERM;
+	int ret;
 
-	device = vfio_device_get_from_name(group, buf);
-	if (IS_ERR(device))
-		return PTR_ERR(device);
+	down_write(&device->group->group_rwsem);
+	ret = vfio_device_assign_container(device);
+	up_write(&device->group->group_rwsem);
+	if (ret)
+		return ERR_PTR(ret);
 
 	if (!try_module_get(device->dev->driver->owner)) {
 		ret = -ENODEV;
-		goto err_device_put;
+		goto err_unassign_container;
 	}
 
 	mutex_lock(&device->dev_set->lock);
 	device->open_count++;
-	if (device->open_count == 1 && device->ops->open_device) {
-		ret = device->ops->open_device(device);
-		if (ret)
-			goto err_undo_count;
+	if (device->open_count == 1) {
+		/*
+		 * Here we pass the KVM pointer with the group held under the
+		 * read lock.  If the device driver uses it, it must obtain a
+		 * reference and release it during close_device().
+		 */
+		down_read(&device->group->group_rwsem);
+		device->kvm = device->group->kvm;
+
+		if (device->ops->open_device) {
+			ret = device->ops->open_device(device);
+			if (ret)
+				goto err_undo_count;
+		}
+		up_read(&device->group->group_rwsem);
 	}
 	mutex_unlock(&device->dev_set->lock);
 
@@ -1158,15 +1117,11 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
 	 * We can't use anon_inode_getfd() because we need to modify
 	 * the f_mode flags directly to allow more than just ioctls
 	 */
-	fdno = ret = get_unused_fd_flags(O_CLOEXEC);
-	if (ret < 0)
-		goto err_close_device;
-
 	filep = anon_inode_getfile("[vfio-device]", &vfio_device_fops,
 				   device, O_RDWR);
 	if (IS_ERR(filep)) {
 		ret = PTR_ERR(filep);
-		goto err_fd;
+		goto err_close_device;
 	}
 
 	/*
@@ -1176,26 +1131,61 @@ static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
 	 */
 	filep->f_mode |= (FMODE_LSEEK | FMODE_PREAD | FMODE_PWRITE);
 
-	atomic_inc(&group->container_users);
-
-	fd_install(fdno, filep);
-
-	if (group->type == VFIO_NO_IOMMU)
+	if (device->group->type == VFIO_NO_IOMMU)
 		dev_warn(device->dev, "vfio-noiommu device opened by user "
 			 "(%s:%d)\n", current->comm, task_pid_nr(current));
-	return fdno;
+	/*
+	 * On success the device's reference is moved to the file and
+	 * put in vfio_device_fops_release().
+	 */
+	return filep;
 
-err_fd:
-	put_unused_fd(fdno);
 err_close_device:
 	mutex_lock(&device->dev_set->lock);
+	down_read(&device->group->group_rwsem);
 	if (device->open_count == 1 && device->ops->close_device)
 		device->ops->close_device(device);
 err_undo_count:
 	device->open_count--;
+	if (device->open_count == 0 && device->kvm)
+		device->kvm = NULL;
+	up_read(&device->group->group_rwsem);
 	mutex_unlock(&device->dev_set->lock);
 	module_put(device->dev->driver->owner);
-err_device_put:
+err_unassign_container:
+	vfio_device_unassign_container(device);
+	return ERR_PTR(ret);
+}
+
+static int vfio_group_get_device_fd(struct vfio_group *group, char *buf)
+{
+	struct vfio_device *device;
+	struct file *filep;
+	int fdno;
+	int ret;
+
+	device = vfio_device_get_from_name(group, buf);
+	if (IS_ERR(device))
+		return PTR_ERR(device);
+
+	fdno = get_unused_fd_flags(O_CLOEXEC);
+	if (fdno < 0) {
+		ret = fdno;
+		goto err_put_device;
+	}
+
+	filep = vfio_device_open(device);
+	if (IS_ERR(filep)) {
+		ret = PTR_ERR(filep);
+		goto err_put_fdno;
+	}
+
+	fd_install(fdno, filep);
+	return fdno;
+
+err_put_fdno:
+	put_unused_fd(fdno);
+err_put_device:
 	vfio_device_put(device);
 	return ret;
 }
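
Splitting vfio_device_open() out of vfio_group_get_device_fd() only changes how the kernel assembles the device FD; the userspace sequence is unchanged. For reference, a minimal open flow (error handling omitted; the group number and device name are placeholders):

	int container = open("/dev/vfio/vfio", O_RDWR);
	int group = open("/dev/vfio/26", O_RDWR);
	int device;

	ioctl(group, VFIO_GROUP_SET_CONTAINER, &container);
	ioctl(container, VFIO_SET_IOMMU, VFIO_TYPE1v2_IOMMU);
	/* the returned device FD now pins the group file via opened_file */
	device = ioctl(group, VFIO_GROUP_GET_DEVICE_FD, "0000:01:00.0");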
@@ -1222,11 +1212,13 @@ static long vfio_group_fops_unl_ioctl(struct file *filep,
 
 		status.flags = 0;
 
+		down_read(&group->group_rwsem);
 		if (group->container)
 			status.flags |= VFIO_GROUP_FLAGS_CONTAINER_SET |
 					VFIO_GROUP_FLAGS_VIABLE;
 		else if (!iommu_group_dma_owner_claimed(group->iommu_group))
 			status.flags |= VFIO_GROUP_FLAGS_VIABLE;
+		up_read(&group->group_rwsem);
 
 		if (copy_to_user((void __user *)arg, &status, minsz))
 			return -EFAULT;
@@ -1244,11 +1236,15 @@ static long vfio_group_fops_unl_ioctl(struct file *filep,
 		if (fd < 0)
 			return -EINVAL;
 
+		down_write(&group->group_rwsem);
 		ret = vfio_group_set_container(group, fd);
+		up_write(&group->group_rwsem);
 		break;
 	}
 	case VFIO_GROUP_UNSET_CONTAINER:
+		down_write(&group->group_rwsem);
 		ret = vfio_group_unset_container(group);
+		up_write(&group->group_rwsem);
 		break;
 	case VFIO_GROUP_GET_DEVICE_FD:
 	{
@@ -1271,38 +1267,38 @@ static int vfio_group_fops_open(struct inode *inode, struct file *filep)
 {
 	struct vfio_group *group =
 		container_of(inode->i_cdev, struct vfio_group, cdev);
-	int opened;
+	int ret;
 
-	/* users can be zero if this races with vfio_group_put() */
-	if (!refcount_inc_not_zero(&group->users))
-		return -ENODEV;
+	down_write(&group->group_rwsem);
 
-	if (group->type == VFIO_NO_IOMMU && !capable(CAP_SYS_RAWIO)) {
-		vfio_group_put(group);
-		return -EPERM;
+	/* users can be zero if this races with vfio_group_put() */
+	if (!refcount_inc_not_zero(&group->users)) {
+		ret = -ENODEV;
+		goto err_unlock;
 	}
 
-	/* Do we need multiple instances of the group open?  Seems not. */
-	opened = atomic_cmpxchg(&group->opened, 0, 1);
-	if (opened) {
-		vfio_group_put(group);
-		return -EBUSY;
+	if (group->type == VFIO_NO_IOMMU && !capable(CAP_SYS_RAWIO)) {
+		ret = -EPERM;
+		goto err_put;
 	}
 
-	/* Is something still in use from a previous open? */
-	if (group->container) {
-		atomic_dec(&group->opened);
-		vfio_group_put(group);
-		return -EBUSY;
+	/*
+	 * Do we need multiple instances of the group open?  Seems not.
+	 */
+	if (group->opened_file) {
+		ret = -EBUSY;
+		goto err_put;
 	}
-
-	/* Warn if previous user didn't cleanup and re-init to drop them */
-	if (WARN_ON(group->notifier.head))
-		BLOCKING_INIT_NOTIFIER_HEAD(&group->notifier);
-
+	group->opened_file = filep;
 	filep->private_data = group;
 
+	up_write(&group->group_rwsem);
 	return 0;
+err_put:
+	vfio_group_put(group);
+err_unlock:
+	up_write(&group->group_rwsem);
+	return ret;
 }
 
 static int vfio_group_fops_release(struct inode *inode, struct file *filep)
@@ -1311,9 +1307,18 @@ static int vfio_group_fops_release(struct inode *inode, struct file *filep)
 
 	filep->private_data = NULL;
 
-	vfio_group_try_dissolve_container(group);
-
-	atomic_dec(&group->opened);
+	down_write(&group->group_rwsem);
+	/*
+	 * Device FDs hold a group file reference, therefore the group release
+	 * is only called when there are no open devices.
+	 */
+	WARN_ON(group->notifier.head);
+	if (group->container) {
+		WARN_ON(group->container_users != 1);
+		__vfio_group_unset_container(group);
+	}
+	group->opened_file = NULL;
+	up_write(&group->group_rwsem);
 
 	vfio_group_put(group);
 
@@ -1336,13 +1341,19 @@ static int vfio_device_fops_release(struct inode *inode, struct file *filep)
 	struct vfio_device *device = filep->private_data;
 
 	mutex_lock(&device->dev_set->lock);
-	if (!--device->open_count && device->ops->close_device)
+	vfio_assert_device_open(device);
+	down_read(&device->group->group_rwsem);
+	if (device->open_count == 1 && device->ops->close_device)
 		device->ops->close_device(device);
+	up_read(&device->group->group_rwsem);
+	device->open_count--;
+	if (device->open_count == 0)
+		device->kvm = NULL;
 	mutex_unlock(&device->dev_set->lock);
 
 	module_put(device->dev->driver->owner);
 
-	vfio_group_try_dissolve_container(device->group);
+	vfio_device_unassign_container(device);
 
 	vfio_device_put(device);
 
@@ -1691,119 +1702,94 @@ static const struct file_operations vfio_device_fops = {
 	.mmap		= vfio_device_fops_mmap,
 };
 
-/*
- * External user API, exported by symbols to be linked dynamically.
- *
- * The protocol includes:
- *  1. do normal VFIO init operation:
- *	- opening a new container;
- *	- attaching group(s) to it;
- *	- setting an IOMMU driver for a container.
- * When IOMMU is set for a container, all groups in it are
- * considered ready to use by an external user.
+/**
+ * vfio_file_iommu_group - Return the struct iommu_group for the vfio group file
+ * @file: VFIO group file
  *
- * 2. User space passes a group fd to an external user.
- * The external user calls vfio_group_get_external_user()
- * to verify that:
- *	- the group is initialized;
- *	- IOMMU is set for it.
- * If both checks passed, vfio_group_get_external_user()
- * increments the container user counter to prevent
- * the VFIO group from disposal before KVM exits.
- *
- * 3. The external user calls vfio_external_user_iommu_id()
- * to know an IOMMU ID.
- *
- * 4. When the external KVM finishes, it calls
- * vfio_group_put_external_user() to release the VFIO group.
- * This call decrements the container user counter.
+ * The returned iommu_group is valid as long as a ref is held on the file.
  */
-struct vfio_group *vfio_group_get_external_user(struct file *filep)
+struct iommu_group *vfio_file_iommu_group(struct file *file)
 {
-	struct vfio_group *group = filep->private_data;
-	int ret;
-
-	if (filep->f_op != &vfio_group_fops)
-		return ERR_PTR(-EINVAL);
+	struct vfio_group *group = file->private_data;
 
-	ret = vfio_group_add_container_user(group);
-	if (ret)
-		return ERR_PTR(ret);
-
-	/*
-	 * Since the caller holds the fget on the file group->users must be >= 1
-	 */
-	vfio_group_get(group);
-
-	return group;
+	if (file->f_op != &vfio_group_fops)
+		return NULL;
+	return group->iommu_group;
 }
-EXPORT_SYMBOL_GPL(vfio_group_get_external_user);
+EXPORT_SYMBOL_GPL(vfio_file_iommu_group);
 
-/*
- * External user API, exported by symbols to be linked dynamically.
- * The external user passes in a device pointer
- * to verify that:
- *	- A VFIO group is assiciated with the device;
- *	- IOMMU is set for the group.
- * If both checks passed, vfio_group_get_external_user_from_dev()
- * increments the container user counter to prevent the VFIO group
- * from disposal before external user exits and returns the pointer
- * to the VFIO group.
- *
- * When the external user finishes using the VFIO group, it calls
- * vfio_group_put_external_user() to release the VFIO group and
- * decrement the container user counter.
+/**
+ * vfio_file_enforced_coherent - True if the DMA associated with the VFIO file
+ *        is always CPU cache coherent
+ * @file: VFIO group file
  *
- * @dev [in]	: device
- * Return error PTR or pointer to VFIO group.
+ * Enforced coherency means that the IOMMU ignores things like the PCIe no-snoop
+ * bit in DMA transactions. A return of false indicates that the user is
+ * allowed to use additional instructions such as wbinvd on x86.
  */
-
-struct vfio_group *vfio_group_get_external_user_from_dev(struct device *dev)
+bool vfio_file_enforced_coherent(struct file *file)
 {
-	struct vfio_group *group;
-	int ret;
+	struct vfio_group *group = file->private_data;
+	bool ret;
 
-	group = vfio_group_get_from_dev(dev);
-	if (!group)
-		return ERR_PTR(-ENODEV);
+	if (file->f_op != &vfio_group_fops)
+		return true;
 
-	ret = vfio_group_add_container_user(group);
-	if (ret) {
-		vfio_group_put(group);
-		return ERR_PTR(ret);
+	down_read(&group->group_rwsem);
+	if (group->container) {
+		ret = vfio_ioctl_check_extension(group->container,
+						 VFIO_DMA_CC_IOMMU);
+	} else {
+		/*
+		 * Since the coherency state is determined only once a container
+		 * is attached, the user must do so before they can prove they
+		 * have permission.
+		 */
+		ret = true;
 	}
-
-	return group;
+	up_read(&group->group_rwsem);
+	return ret;
 }
-EXPORT_SYMBOL_GPL(vfio_group_get_external_user_from_dev);
+EXPORT_SYMBOL_GPL(vfio_file_enforced_coherent);
 
-void vfio_group_put_external_user(struct vfio_group *group)
+/**
+ * vfio_file_set_kvm - Link a kvm with VFIO drivers
+ * @file: VFIO group file
+ * @kvm: KVM to link
+ *
+ * When a VFIO device is first opened, the KVM will be available in
+ * device->kvm if one was associated with the group.
+ */
+void vfio_file_set_kvm(struct file *file, struct kvm *kvm)
 {
-	vfio_group_try_dissolve_container(group);
-	vfio_group_put(group);
-}
-EXPORT_SYMBOL_GPL(vfio_group_put_external_user);
+	struct vfio_group *group = file->private_data;
 
-bool vfio_external_group_match_file(struct vfio_group *test_group,
-				    struct file *filep)
-{
-	struct vfio_group *group = filep->private_data;
+	if (file->f_op != &vfio_group_fops)
+		return;
 
-	return (filep->f_op == &vfio_group_fops) && (group == test_group);
+	down_write(&group->group_rwsem);
+	group->kvm = kvm;
+	up_write(&group->group_rwsem);
 }
-EXPORT_SYMBOL_GPL(vfio_external_group_match_file);
+EXPORT_SYMBOL_GPL(vfio_file_set_kvm);
 
-int vfio_external_user_iommu_id(struct vfio_group *group)
+/**
+ * vfio_file_has_dev - True if the VFIO file is a handle for the device
+ * @file: VFIO file to check
+ * @device: Device that must be part of the file
+ *
+ * Returns true if given file has permission to manipulate the given device.
+ */
+bool vfio_file_has_dev(struct file *file, struct vfio_device *device)
 {
-	return iommu_group_id(group->iommu_group);
-}
-EXPORT_SYMBOL_GPL(vfio_external_user_iommu_id);
+	struct vfio_group *group = file->private_data;
 
-long vfio_external_check_extension(struct vfio_group *group, unsigned long arg)
-{
-	return vfio_ioctl_check_extension(group->container, arg);
+	if (file->f_op != &vfio_group_fops)
+		return false;
+
+	return group == device->group;
 }
-EXPORT_SYMBOL_GPL(vfio_external_check_extension);
+EXPORT_SYMBOL_GPL(vfio_file_has_dev);
 
 /*
  * Sub-module support
@@ -1926,7 +1912,7 @@ EXPORT_SYMBOL(vfio_set_irqs_validate_and_prepare);
 /*
  * Pin a set of guest PFNs and return their associated host PFNs for local
  * domain only.
- * @dev [in]     : device
+ * @device [in]  : device
  * @user_pfn [in]: array of user/guest PFNs to be pinned.
  * @npage [in]   : count of elements in user_pfn array.  This count should not
  *		   be greater VFIO_PIN_PAGES_MAX_ENTRIES.
@@ -1934,33 +1920,25 @@ EXPORT_SYMBOL(vfio_set_irqs_validate_and_prepare);
  * @phys_pfn[out]: array of host PFNs
  * Return error or number of pages pinned.
  */
-int vfio_pin_pages(struct device *dev, unsigned long *user_pfn, int npage,
-		   int prot, unsigned long *phys_pfn)
+int vfio_pin_pages(struct vfio_device *device, unsigned long *user_pfn,
+		   int npage, int prot, unsigned long *phys_pfn)
 {
 	struct vfio_container *container;
-	struct vfio_group *group;
+	struct vfio_group *group = device->group;
 	struct vfio_iommu_driver *driver;
 	int ret;
 
-	if (!dev || !user_pfn || !phys_pfn || !npage)
+	if (!user_pfn || !phys_pfn || !npage ||
+	    !vfio_assert_device_open(device))
 		return -EINVAL;
 
 	if (npage > VFIO_PIN_PAGES_MAX_ENTRIES)
 		return -E2BIG;
 
-	group = vfio_group_get_from_dev(dev);
-	if (!group)
-		return -ENODEV;
-
-	if (group->dev_counter > 1) {
-		ret = -EINVAL;
-		goto err_pin_pages;
-	}
-
-	ret = vfio_group_add_container_user(group);
-	if (ret)
-		goto err_pin_pages;
+	if (group->dev_counter > 1)
+		return -EINVAL;
 
+	/* group->container cannot change while a vfio device is open */
 	container = group->container;
 	driver = container->iommu_driver;
 	if (likely(driver && driver->ops->pin_pages))
@@ -1970,45 +1948,34 @@ int vfio_pin_pages(struct device *dev, unsigned long *user_pfn, int npage,
 	else
 		ret = -ENOTTY;
 
-	vfio_group_try_dissolve_container(group);
-
-err_pin_pages:
-	vfio_group_put(group);
 	return ret;
 }
 EXPORT_SYMBOL(vfio_pin_pages);
 
 /*
  * Unpin set of host PFNs for local domain only.
- * @dev [in]     : device
+ * @device [in]  : device
  * @user_pfn [in]: array of user/guest PFNs to be unpinned. Number of user/guest
  *		   PFNs should not be greater than VFIO_PIN_PAGES_MAX_ENTRIES.
  * @npage [in]   : count of elements in user_pfn array.  This count should not
  *                 be greater than VFIO_PIN_PAGES_MAX_ENTRIES.
  * Return error or number of pages unpinned.
  */
-int vfio_unpin_pages(struct device *dev, unsigned long *user_pfn, int npage)
+int vfio_unpin_pages(struct vfio_device *device, unsigned long *user_pfn,
+		     int npage)
 {
 	struct vfio_container *container;
-	struct vfio_group *group;
 	struct vfio_iommu_driver *driver;
 	int ret;
 
-	if (!dev || !user_pfn || !npage)
+	if (!user_pfn || !npage || !vfio_assert_device_open(device))
 		return -EINVAL;
 
 	if (npage > VFIO_PIN_PAGES_MAX_ENTRIES)
 		return -E2BIG;
 
-	group = vfio_group_get_from_dev(dev);
-	if (!group)
-		return -ENODEV;
-
-	ret = vfio_group_add_container_user(group);
-	if (ret)
-		goto err_unpin_pages;
-
-	container = group->container;
+	/* group->container cannot change while a vfio device is open */
+	container = device->group->container;
 	driver = container->iommu_driver;
 	if (likely(driver && driver->ops->unpin_pages))
 		ret = driver->ops->unpin_pages(container->iommu_data, user_pfn,
@@ -2016,110 +1983,11 @@ int vfio_unpin_pages(struct device *dev, unsigned long *user_pfn, int npage)
 	else
 		ret = -ENOTTY;
 
-	vfio_group_try_dissolve_container(group);
-
-err_unpin_pages:
-	vfio_group_put(group);
 	return ret;
 }
 EXPORT_SYMBOL(vfio_unpin_pages);
 
 /*
- * Pin a set of guest IOVA PFNs and return their associated host PFNs for a
- * VFIO group.
- *
- * The caller needs to call vfio_group_get_external_user() or
- * vfio_group_get_external_user_from_dev() prior to calling this interface,
- * so as to prevent the VFIO group from disposal in the middle of the call.
- * But it can keep the reference to the VFIO group for several calls into
- * this interface.
- * After finishing using of the VFIO group, the caller needs to release the
- * VFIO group by calling vfio_group_put_external_user().
- *
- * @group [in]		: VFIO group
- * @user_iova_pfn [in]	: array of user/guest IOVA PFNs to be pinned.
- * @npage [in]		: count of elements in user_iova_pfn array.
- *			  This count should not be greater
- *			  VFIO_PIN_PAGES_MAX_ENTRIES.
- * @prot [in]		: protection flags
- * @phys_pfn [out]	: array of host PFNs
- * Return error or number of pages pinned.
- */
-int vfio_group_pin_pages(struct vfio_group *group,
-			 unsigned long *user_iova_pfn, int npage,
-			 int prot, unsigned long *phys_pfn)
-{
-	struct vfio_container *container;
-	struct vfio_iommu_driver *driver;
-	int ret;
-
-	if (!group || !user_iova_pfn || !phys_pfn || !npage)
-		return -EINVAL;
-
-	if (group->dev_counter > 1)
-		return -EINVAL;
-
-	if (npage > VFIO_PIN_PAGES_MAX_ENTRIES)
-		return -E2BIG;
-
-	container = group->container;
-	driver = container->iommu_driver;
-	if (likely(driver && driver->ops->pin_pages))
-		ret = driver->ops->pin_pages(container->iommu_data,
-					     group->iommu_group, user_iova_pfn,
-					     npage, prot, phys_pfn);
-	else
-		ret = -ENOTTY;
-
-	return ret;
-}
-EXPORT_SYMBOL(vfio_group_pin_pages);
-
-/*
- * Unpin a set of guest IOVA PFNs for a VFIO group.
- *
- * The caller needs to call vfio_group_get_external_user() or
- * vfio_group_get_external_user_from_dev() prior to calling this interface,
- * so as to prevent the VFIO group from disposal in the middle of the call.
- * But it can keep the reference to the VFIO group for several calls into
- * this interface.
- * After finishing using of the VFIO group, the caller needs to release the
- * VFIO group by calling vfio_group_put_external_user().
- *
- * @group [in]		: vfio group
- * @user_iova_pfn [in]	: array of user/guest IOVA PFNs to be unpinned.
- * @npage [in]		: count of elements in user_iova_pfn array.
- *			  This count should not be greater than
- *			  VFIO_PIN_PAGES_MAX_ENTRIES.
- * Return error or number of pages unpinned.
- */
-int vfio_group_unpin_pages(struct vfio_group *group,
-			   unsigned long *user_iova_pfn, int npage)
-{
-	struct vfio_container *container;
-	struct vfio_iommu_driver *driver;
-	int ret;
-
-	if (!group || !user_iova_pfn || !npage)
-		return -EINVAL;
-
-	if (npage > VFIO_PIN_PAGES_MAX_ENTRIES)
-		return -E2BIG;
-
-	container = group->container;
-	driver = container->iommu_driver;
-	if (likely(driver && driver->ops->unpin_pages))
-		ret = driver->ops->unpin_pages(container->iommu_data,
-					       user_iova_pfn, npage);
-	else
-		ret = -ENOTTY;
-
-	return ret;
-}
-EXPORT_SYMBOL(vfio_group_unpin_pages);
-
-
-/*
  * This interface allows the CPUs to perform some sort of virtual DMA on
  * behalf of the device.
  *
@@ -2129,32 +1997,25 @@ EXPORT_SYMBOL(vfio_group_unpin_pages);
  * As the read/write of user space memory is conducted via the CPUs and is
  * not a real device DMA, it is not necessary to pin the user space memory.
  *
- * The caller needs to call vfio_group_get_external_user() or
- * vfio_group_get_external_user_from_dev() prior to calling this interface,
- * so as to prevent the VFIO group from disposal in the middle of the call.
- * But it can keep the reference to the VFIO group for several calls into
- * this interface.
- * After finishing using of the VFIO group, the caller needs to release the
- * VFIO group by calling vfio_group_put_external_user().
- *
- * @group [in]		: VFIO group
+ * @device [in]		: VFIO device
  * @user_iova [in]	: base IOVA of a user space buffer
  * @data [in]		: pointer to kernel buffer
  * @len [in]		: kernel buffer length
  * @write		: indicate read or write
  * Return error code on failure or 0 on success.
  */
-int vfio_dma_rw(struct vfio_group *group, dma_addr_t user_iova,
-		void *data, size_t len, bool write)
+int vfio_dma_rw(struct vfio_device *device, dma_addr_t user_iova, void *data,
+		size_t len, bool write)
 {
 	struct vfio_container *container;
 	struct vfio_iommu_driver *driver;
 	int ret = 0;
 
-	if (!group || !data || len <= 0)
+	if (!data || len <= 0 || !vfio_assert_device_open(device))
 		return -EINVAL;
 
-	container = group->container;
+	/* group->container cannot change while a vfio device is open */
+	container = device->group->container;
 	driver = container->iommu_driver;
 
 	if (likely(driver && driver->ops->dma_rw))
@@ -2162,7 +2023,6 @@ int vfio_dma_rw(struct vfio_group *group, dma_addr_t user_iova,
 					  user_iova, data, len, write);
 	else
 		ret = -ENOTTY;
-
 	return ret;
 }
 EXPORT_SYMBOL(vfio_dma_rw);
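
With the pin/unpin and dma_rw entry points keyed off the vfio_device, callers such as kvmgt, vfio_ccw and vfio_ap may only invoke them between open_device and close_device, where group->container is guaranteed stable. A rough caller-side sketch with hypothetical names, not taken from any of those drivers:

static int my_mdev_copy_from_guest(struct vfio_device *vdev, dma_addr_t iova,
				   void *buf, size_t len)
{
	unsigned long gfn = iova >> PAGE_SHIFT;
	unsigned long pfn;
	int ret;

	/* pin one guest page; legal only while the device FD is open */
	ret = vfio_pin_pages(vdev, &gfn, 1, IOMMU_READ | IOMMU_WRITE, &pfn);
	if (ret != 1)
		return ret < 0 ? ret : -EINVAL;

	/* or simply let the IOMMU backend copy on our behalf */
	ret = vfio_dma_rw(vdev, iova, buf, len, false);

	vfio_unpin_pages(vdev, &gfn, 1);
	return ret;
}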
@@ -2175,9 +2035,7 @@ static int vfio_register_iommu_notifier(struct vfio_group *group,
 	struct vfio_iommu_driver *driver;
 	int ret;
 
-	ret = vfio_group_add_container_user(group);
-	if (ret)
-		return -EINVAL;
+	lockdep_assert_held_read(&group->group_rwsem);
 
 	container = group->container;
 	driver = container->iommu_driver;
@@ -2187,8 +2045,6 @@ static int vfio_register_iommu_notifier(struct vfio_group *group,
 	else
 		ret = -ENOTTY;
 
-	vfio_group_try_dissolve_container(group);
-
 	return ret;
 }
 
@@ -2199,9 +2055,7 @@ static int vfio_unregister_iommu_notifier(struct vfio_group *group,
 	struct vfio_iommu_driver *driver;
 	int ret;
 
-	ret = vfio_group_add_container_user(group);
-	if (ret)
-		return -EINVAL;
+	lockdep_assert_held_read(&group->group_rwsem);
 
 	container = group->container;
 	driver = container->iommu_driver;
@@ -2211,147 +2065,52 @@ static int vfio_unregister_iommu_notifier(struct vfio_group *group,
 	else
 		ret = -ENOTTY;
 
-	vfio_group_try_dissolve_container(group);
-
-	return ret;
-}
-
-void vfio_group_set_kvm(struct vfio_group *group, struct kvm *kvm)
-{
-	group->kvm = kvm;
-	blocking_notifier_call_chain(&group->notifier,
-				VFIO_GROUP_NOTIFY_SET_KVM, kvm);
-}
-EXPORT_SYMBOL_GPL(vfio_group_set_kvm);
-
-static int vfio_register_group_notifier(struct vfio_group *group,
-					unsigned long *events,
-					struct notifier_block *nb)
-{
-	int ret;
-	bool set_kvm = false;
-
-	if (*events & VFIO_GROUP_NOTIFY_SET_KVM)
-		set_kvm = true;
-
-	/* clear known events */
-	*events &= ~VFIO_GROUP_NOTIFY_SET_KVM;
-
-	/* refuse to continue if still events remaining */
-	if (*events)
-		return -EINVAL;
-
-	ret = vfio_group_add_container_user(group);
-	if (ret)
-		return -EINVAL;
-
-	ret = blocking_notifier_chain_register(&group->notifier, nb);
-
-	/*
-	 * The attaching of kvm and vfio_group might already happen, so
-	 * here we replay once upon registration.
-	 */
-	if (!ret && set_kvm && group->kvm)
-		blocking_notifier_call_chain(&group->notifier,
-					VFIO_GROUP_NOTIFY_SET_KVM, group->kvm);
-
-	vfio_group_try_dissolve_container(group);
-
 	return ret;
 }
 
-static int vfio_unregister_group_notifier(struct vfio_group *group,
-					 struct notifier_block *nb)
+int vfio_register_notifier(struct vfio_device *device,
+			   enum vfio_notify_type type, unsigned long *events,
+			   struct notifier_block *nb)
 {
+	struct vfio_group *group = device->group;
 	int ret;
 
-	ret = vfio_group_add_container_user(group);
-	if (ret)
-		return -EINVAL;
-
-	ret = blocking_notifier_chain_unregister(&group->notifier, nb);
-
-	vfio_group_try_dissolve_container(group);
-
-	return ret;
-}
-
-int vfio_register_notifier(struct device *dev, enum vfio_notify_type type,
-			   unsigned long *events, struct notifier_block *nb)
-{
-	struct vfio_group *group;
-	int ret;
-
-	if (!dev || !nb || !events || (*events == 0))
+	if (!nb || !events || (*events == 0) ||
+	    !vfio_assert_device_open(device))
 		return -EINVAL;
 
-	group = vfio_group_get_from_dev(dev);
-	if (!group)
-		return -ENODEV;
-
 	switch (type) {
 	case VFIO_IOMMU_NOTIFY:
 		ret = vfio_register_iommu_notifier(group, events, nb);
 		break;
-	case VFIO_GROUP_NOTIFY:
-		ret = vfio_register_group_notifier(group, events, nb);
-		break;
 	default:
 		ret = -EINVAL;
 	}
-
-	vfio_group_put(group);
 	return ret;
 }
 EXPORT_SYMBOL(vfio_register_notifier);
 
-int vfio_unregister_notifier(struct device *dev, enum vfio_notify_type type,
+int vfio_unregister_notifier(struct vfio_device *device,
+			     enum vfio_notify_type type,
 			     struct notifier_block *nb)
 {
-	struct vfio_group *group;
+	struct vfio_group *group = device->group;
 	int ret;
 
-	if (!dev || !nb)
+	if (!nb || !vfio_assert_device_open(device))
 		return -EINVAL;
 
-	group = vfio_group_get_from_dev(dev);
-	if (!group)
-		return -ENODEV;
-
 	switch (type) {
 	case VFIO_IOMMU_NOTIFY:
 		ret = vfio_unregister_iommu_notifier(group, nb);
 		break;
-	case VFIO_GROUP_NOTIFY:
-		ret = vfio_unregister_group_notifier(group, nb);
-		break;
 	default:
 		ret = -EINVAL;
 	}
-
-	vfio_group_put(group);
 	return ret;
 }
 EXPORT_SYMBOL(vfio_unregister_notifier);
 
-struct iommu_domain *vfio_group_iommu_domain(struct vfio_group *group)
-{
-	struct vfio_container *container;
-	struct vfio_iommu_driver *driver;
-
-	if (!group)
-		return ERR_PTR(-EINVAL);
-
-	container = group->container;
-	driver = container->iommu_driver;
-	if (likely(driver && driver->ops->group_iommu_domain))
-		return driver->ops->group_iommu_domain(container->iommu_data,
-						       group->iommu_group);
-
-	return ERR_PTR(-ENOTTY);
-}
-EXPORT_SYMBOL_GPL(vfio_group_iommu_domain);
-
 /*
  * Module/class support
  */
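
On the consumer side, kvm/vfio.c drops its use of struct vfio_group entirely and works through the group file with the helpers added above. A simplified sketch of the resulting pattern (not the actual kvm code):

	struct file *file = fget(group_fd);

	if (!file)
		return -EBADF;

	/* reject anything that is not a VFIO group file */
	if (!vfio_file_iommu_group(file)) {
		fput(file);
		return -EINVAL;
	}

	/* make the KVM pointer visible to drivers in their open_device() */
	vfio_file_set_kvm(file, kvm);

	/* without enforced coherency the guest may legitimately need wbinvd */
	if (!vfio_file_enforced_coherent(file))
		kvm_arch_register_noncoherent_dma(kvm);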
diff --git a/include/linux/mlx5/driver.h b/include/linux/mlx5/driver.h
index b064bc278f52..5040cd774c5a 100644
--- a/include/linux/mlx5/driver.h
+++ b/include/linux/mlx5/driver.h
@@ -447,6 +447,11 @@ struct mlx5_qp_table {
 	struct radix_tree_root	tree;
 };
 
+enum {
+	MLX5_PF_NOTIFY_DISABLE_VF,
+	MLX5_PF_NOTIFY_ENABLE_VF,
+};
+
 struct mlx5_vf_context {
 	int	enabled;
 	u64	port_guid;
@@ -457,6 +462,7 @@ struct mlx5_vf_context {
 	u8	port_guid_valid:1;
 	u8	node_guid_valid:1;
 	enum port_state_policy	policy;
+	struct blocking_notifier_head notifier;
 };
 
 struct mlx5_core_sriov {
@@ -1162,6 +1168,12 @@ int mlx5_dm_sw_icm_dealloc(struct mlx5_core_dev *dev, enum mlx5_sw_icm_type type
 struct mlx5_core_dev *mlx5_vf_get_core_dev(struct pci_dev *pdev);
 void mlx5_vf_put_core_dev(struct mlx5_core_dev *mdev);
 
+int mlx5_sriov_blocking_notifier_register(struct mlx5_core_dev *mdev,
+					  int vf_id,
+					  struct notifier_block *nb);
+void mlx5_sriov_blocking_notifier_unregister(struct mlx5_core_dev *mdev,
+					     int vf_id,
+					     struct notifier_block *nb);
 #ifdef CONFIG_MLX5_CORE_IPOIB
 struct net_device *mlx5_rdma_netdev_alloc(struct mlx5_core_dev *mdev,
 					  struct ib_device *ibdev,
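
The per-VF blocking notifier gives the mlx5 vfio-pci variant driver a way to hear about VF disable/enable on its parent PF (for instance across sriov_numvfs changes) so it can tear down or re-arm its migration command state. Roughly how a consumer hooks in, as a hedged sketch rather than the actual mlx5vf code:

static int myvf_pf_notifier(struct notifier_block *nb, unsigned long event,
			    void *data)
{
	/* container_of(nb, ...) back to the driver structure owning it */
	switch (event) {
	case MLX5_PF_NOTIFY_DISABLE_VF:
		/* the VF is going away: invalidate cached VHCA/command state */
		break;
	case MLX5_PF_NOTIFY_ENABLE_VF:
		/* the VF exists again: state can be rebuilt on the next open */
		break;
	}
	return NOTIFY_OK;
}

	/* vf_id is the zero-based index from pci_iov_vf_id(); mdev comes from
	 * mlx5_vf_get_core_dev(), as in the probe path removed above */
	nb.notifier_call = myvf_pf_notifier;
	mlx5_sriov_blocking_notifier_register(mdev, vf_id, &nb);
	/* and on teardown */
	mlx5_sriov_blocking_notifier_unregister(mdev, vf_id, &nb);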
diff --git a/include/linux/vfio.h b/include/linux/vfio.h
index 66dda06ec42d..aa888cc51757 100644
--- a/include/linux/vfio.h
+++ b/include/linux/vfio.h
@@ -15,6 +15,8 @@
 #include <linux/poll.h>
 #include <uapi/linux/vfio.h>
 
+struct kvm;
+
 /*
  * VFIO devices can be placed in a set, this allows all devices to share this
  * structure and the VFIO core will provide a lock that is held around
@@ -34,6 +36,8 @@ struct vfio_device {
 	struct vfio_device_set *dev_set;
 	struct list_head dev_set_list;
 	unsigned int migration_flags;
+	/* Driver must reference the kvm during open_device or never touch it */
+	struct kvm *kvm;
 
 	/* Members below here are private, not for driver use */
 	refcount_t refcount;
@@ -125,8 +129,6 @@ void vfio_uninit_group_dev(struct vfio_device *device);
 int vfio_register_group_dev(struct vfio_device *device);
 int vfio_register_emulated_iommu_dev(struct vfio_device *device);
 void vfio_unregister_group_dev(struct vfio_device *device);
-extern struct vfio_device *vfio_device_get_from_dev(struct device *dev);
-extern void vfio_device_put(struct vfio_device *device);
 
 int vfio_assign_device_set(struct vfio_device *device, void *set_id);
 
@@ -138,56 +140,36 @@ int vfio_mig_get_next_state(struct vfio_device *device,
 /*
  * External user API
  */
-extern struct vfio_group *vfio_group_get_external_user(struct file *filep);
-extern void vfio_group_put_external_user(struct vfio_group *group);
-extern struct vfio_group *vfio_group_get_external_user_from_dev(struct device
-								*dev);
-extern bool vfio_external_group_match_file(struct vfio_group *group,
-					   struct file *filep);
-extern int vfio_external_user_iommu_id(struct vfio_group *group);
-extern long vfio_external_check_extension(struct vfio_group *group,
-					  unsigned long arg);
+extern struct iommu_group *vfio_file_iommu_group(struct file *file);
+extern bool vfio_file_enforced_coherent(struct file *file);
+extern void vfio_file_set_kvm(struct file *file, struct kvm *kvm);
+extern bool vfio_file_has_dev(struct file *file, struct vfio_device *device);
 
 #define VFIO_PIN_PAGES_MAX_ENTRIES	(PAGE_SIZE/sizeof(unsigned long))
 
-extern int vfio_pin_pages(struct device *dev, unsigned long *user_pfn,
+extern int vfio_pin_pages(struct vfio_device *device, unsigned long *user_pfn,
 			  int npage, int prot, unsigned long *phys_pfn);
-extern int vfio_unpin_pages(struct device *dev, unsigned long *user_pfn,
+extern int vfio_unpin_pages(struct vfio_device *device, unsigned long *user_pfn,
 			    int npage);
-
-extern int vfio_group_pin_pages(struct vfio_group *group,
-				unsigned long *user_iova_pfn, int npage,
-				int prot, unsigned long *phys_pfn);
-extern int vfio_group_unpin_pages(struct vfio_group *group,
-				  unsigned long *user_iova_pfn, int npage);
-
-extern int vfio_dma_rw(struct vfio_group *group, dma_addr_t user_iova,
+extern int vfio_dma_rw(struct vfio_device *device, dma_addr_t user_iova,
 		       void *data, size_t len, bool write);
 
-extern struct iommu_domain *vfio_group_iommu_domain(struct vfio_group *group);
-
 /* each type has independent events */
 enum vfio_notify_type {
 	VFIO_IOMMU_NOTIFY = 0,
-	VFIO_GROUP_NOTIFY = 1,
 };
 
 /* events for VFIO_IOMMU_NOTIFY */
 #define VFIO_IOMMU_NOTIFY_DMA_UNMAP	BIT(0)
 
-/* events for VFIO_GROUP_NOTIFY */
-#define VFIO_GROUP_NOTIFY_SET_KVM	BIT(0)
-
-extern int vfio_register_notifier(struct device *dev,
+extern int vfio_register_notifier(struct vfio_device *device,
 				  enum vfio_notify_type type,
 				  unsigned long *required_events,
 				  struct notifier_block *nb);
-extern int vfio_unregister_notifier(struct device *dev,
+extern int vfio_unregister_notifier(struct vfio_device *device,
 				    enum vfio_notify_type type,
 				    struct notifier_block *nb);
 
-struct kvm;
-extern void vfio_group_set_kvm(struct vfio_group *group, struct kvm *kvm);
 
 /*
  * Sub-module helpers
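
vfio_pin_pages(), vfio_unpin_pages() and vfio_dma_rw() now take the caller's vfio_device instead of a struct device, and the group-based variants are removed. A minimal sketch, not from this series, that exercises both calls from a hypothetical emulation path; my_copy_from_guest() and its pin-then-read flow are illustrative only (pinning is not a prerequisite for vfio_dma_rw()).

#include <linux/iommu.h>
#include <linux/mm.h>
#include <linux/vfio.h>

static int my_copy_from_guest(struct vfio_device *device, dma_addr_t iova,
			      void *buf, size_t len)
{
	unsigned long user_pfn = iova >> PAGE_SHIFT;
	unsigned long phys_pfn;
	int ret;

	/* Pin one page of the guest IOVA range, keyed by the vfio_device. */
	ret = vfio_pin_pages(device, &user_pfn, 1, IOMMU_READ | IOMMU_WRITE,
			     &phys_pfn);
	if (ret != 1)
		return ret < 0 ? ret : -EFAULT;

	/* Read "len" bytes of guest memory at "iova" into "buf". */
	ret = vfio_dma_rw(device, iova, buf, len, false);

	vfio_unpin_pages(device, &user_pfn, 1);
	return ret;
}
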
diff --git a/include/linux/vfio_pci_core.h b/include/linux/vfio_pci_core.h
index 48f2dd3c568c..23c176d4b073 100644
--- a/include/linux/vfio_pci_core.h
+++ b/include/linux/vfio_pci_core.h
@@ -227,8 +227,9 @@ void vfio_pci_core_init_device(struct vfio_pci_core_device *vdev,
 int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev);
 void vfio_pci_core_uninit_device(struct vfio_pci_core_device *vdev);
 void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev);
-int vfio_pci_core_sriov_configure(struct pci_dev *pdev, int nr_virtfn);
 extern const struct pci_error_handlers vfio_pci_core_err_handlers;
+int vfio_pci_core_sriov_configure(struct vfio_pci_core_device *vdev,
+				  int nr_virtfn);
 long vfio_pci_core_ioctl(struct vfio_device *core_vdev, unsigned int cmd,
 		unsigned long arg);
 int vfio_pci_core_ioctl_feature(struct vfio_device *device, u32 flags,
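
With drvdata now holding the core device, a variant driver's .sriov_configure callback becomes a thin wrapper around the changed helper. A sketch under that assumption; the wrapper name is hypothetical.

#include <linux/pci.h>
#include <linux/vfio_pci_core.h>

static int my_vfio_pci_sriov_configure(struct pci_dev *pdev, int nr_virtfn)
{
	/* drvdata is set to the vfio_pci_core_device at registration time. */
	struct vfio_pci_core_device *vdev = dev_get_drvdata(&pdev->dev);

	return vfio_pci_core_sriov_configure(vdev, nr_virtfn);
}
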
diff --git a/include/uapi/linux/vfio.h b/include/uapi/linux/vfio.h
index fea86061b44e..733a1cddde30 100644
--- a/include/uapi/linux/vfio.h
+++ b/include/uapi/linux/vfio.h
@@ -643,7 +643,7 @@ enum {
 };
 
 /**
- * VFIO_DEVICE_GET_PCI_HOT_RESET_INFO - _IORW(VFIO_TYPE, VFIO_BASE + 12,
+ * VFIO_DEVICE_GET_PCI_HOT_RESET_INFO - _IOWR(VFIO_TYPE, VFIO_BASE + 12,
  *					      struct vfio_pci_hot_reset_info)
  *
  * Return: 0 on success, -errno on failure:
@@ -770,7 +770,7 @@ struct vfio_device_ioeventfd {
 #define VFIO_DEVICE_IOEVENTFD		_IO(VFIO_TYPE, VFIO_BASE + 16)
 
 /**
- * VFIO_DEVICE_FEATURE - _IORW(VFIO_TYPE, VFIO_BASE + 17,
+ * VFIO_DEVICE_FEATURE - _IOWR(VFIO_TYPE, VFIO_BASE + 17,
  *			       struct vfio_device_feature)
  *
  * Get, set, or probe feature data of the device.  The feature is selected
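
Both comment fixes are cosmetic: _IORW is not an ioctl macro, and _IOWR is the intended spelling for a request whose argument is written by userspace and read back through the same structure. A small userspace sketch, not from this series, probing a feature through VFIO_DEVICE_FEATURE; the migration feature is used purely as an example.

#include <errno.h>
#include <sys/ioctl.h>
#include <linux/vfio.h>

/* "device_fd" is an already-open VFIO device file descriptor. */
static int probe_migration(int device_fd)
{
	struct vfio_device_feature feat = {
		.argsz = sizeof(feat),
		.flags = VFIO_DEVICE_FEATURE_PROBE | VFIO_DEVICE_FEATURE_GET |
			 VFIO_DEVICE_FEATURE_MIGRATION,
	};

	if (ioctl(device_fd, VFIO_DEVICE_FEATURE, &feat))
		return -errno;
	return 0;
}
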
diff --git a/virt/kvm/vfio.c b/virt/kvm/vfio.c
index 8fcbc50221c2..ce1b01d02c51 100644
--- a/virt/kvm/vfio.c
+++ b/virt/kvm/vfio.c
@@ -23,7 +23,7 @@
 
 struct kvm_vfio_group {
 	struct list_head node;
-	struct vfio_group *vfio_group;
+	struct file *file;
 };
 
 struct kvm_vfio {
@@ -32,118 +32,61 @@ struct kvm_vfio {
 	bool noncoherent;
 };
 
-static struct vfio_group *kvm_vfio_group_get_external_user(struct file *filep)
+static void kvm_vfio_file_set_kvm(struct file *file, struct kvm *kvm)
 {
-	struct vfio_group *vfio_group;
-	struct vfio_group *(*fn)(struct file *);
+	void (*fn)(struct file *file, struct kvm *kvm);
 
-	fn = symbol_get(vfio_group_get_external_user);
-	if (!fn)
-		return ERR_PTR(-EINVAL);
-
-	vfio_group = fn(filep);
-
-	symbol_put(vfio_group_get_external_user);
-
-	return vfio_group;
-}
-
-static bool kvm_vfio_external_group_match_file(struct vfio_group *group,
-					       struct file *filep)
-{
-	bool ret, (*fn)(struct vfio_group *, struct file *);
-
-	fn = symbol_get(vfio_external_group_match_file);
-	if (!fn)
-		return false;
-
-	ret = fn(group, filep);
-
-	symbol_put(vfio_external_group_match_file);
-
-	return ret;
-}
-
-static void kvm_vfio_group_put_external_user(struct vfio_group *vfio_group)
-{
-	void (*fn)(struct vfio_group *);
-
-	fn = symbol_get(vfio_group_put_external_user);
-	if (!fn)
-		return;
-
-	fn(vfio_group);
-
-	symbol_put(vfio_group_put_external_user);
-}
-
-static void kvm_vfio_group_set_kvm(struct vfio_group *group, struct kvm *kvm)
-{
-	void (*fn)(struct vfio_group *, struct kvm *);
-
-	fn = symbol_get(vfio_group_set_kvm);
+	fn = symbol_get(vfio_file_set_kvm);
 	if (!fn)
 		return;
 
-	fn(group, kvm);
+	fn(file, kvm);
 
-	symbol_put(vfio_group_set_kvm);
+	symbol_put(vfio_file_set_kvm);
 }
 
-static bool kvm_vfio_group_is_coherent(struct vfio_group *vfio_group)
+static bool kvm_vfio_file_enforced_coherent(struct file *file)
 {
-	long (*fn)(struct vfio_group *, unsigned long);
-	long ret;
+	bool (*fn)(struct file *file);
+	bool ret;
 
-	fn = symbol_get(vfio_external_check_extension);
+	fn = symbol_get(vfio_file_enforced_coherent);
 	if (!fn)
 		return false;
 
-	ret = fn(vfio_group, VFIO_DMA_CC_IOMMU);
+	ret = fn(file);
 
-	symbol_put(vfio_external_check_extension);
+	symbol_put(vfio_file_enforced_coherent);
 
-	return ret > 0;
+	return ret;
 }
 
-#ifdef CONFIG_SPAPR_TCE_IOMMU
-static int kvm_vfio_external_user_iommu_id(struct vfio_group *vfio_group)
+static struct iommu_group *kvm_vfio_file_iommu_group(struct file *file)
 {
-	int (*fn)(struct vfio_group *);
-	int ret = -EINVAL;
+	struct iommu_group *(*fn)(struct file *file);
+	struct iommu_group *ret;
 
-	fn = symbol_get(vfio_external_user_iommu_id);
+	fn = symbol_get(vfio_file_iommu_group);
 	if (!fn)
-		return ret;
+		return NULL;
 
-	ret = fn(vfio_group);
+	ret = fn(file);
 
-	symbol_put(vfio_external_user_iommu_id);
+	symbol_put(vfio_file_iommu_group);
 
 	return ret;
 }
 
-static struct iommu_group *kvm_vfio_group_get_iommu_group(
-		struct vfio_group *group)
-{
-	int group_id = kvm_vfio_external_user_iommu_id(group);
-
-	if (group_id < 0)
-		return NULL;
-
-	return iommu_group_get_by_id(group_id);
-}
-
+#ifdef CONFIG_SPAPR_TCE_IOMMU
 static void kvm_spapr_tce_release_vfio_group(struct kvm *kvm,
-		struct vfio_group *vfio_group)
+					     struct kvm_vfio_group *kvg)
 {
-	struct iommu_group *grp = kvm_vfio_group_get_iommu_group(vfio_group);
+	struct iommu_group *grp = kvm_vfio_file_iommu_group(kvg->file);
 
 	if (WARN_ON_ONCE(!grp))
 		return;
 
 	kvm_spapr_tce_release_iommu_group(kvm, grp);
-	iommu_group_put(grp);
 }
 #endif
 
@@ -163,7 +106,7 @@ static void kvm_vfio_update_coherency(struct kvm_device *dev)
 	mutex_lock(&kv->lock);
 
 	list_for_each_entry(kvg, &kv->group_list, node) {
-		if (!kvm_vfio_group_is_coherent(kvg->vfio_group)) {
+		if (!kvm_vfio_file_enforced_coherent(kvg->file)) {
 			noncoherent = true;
 			break;
 		}
@@ -181,149 +124,162 @@ static void kvm_vfio_update_coherency(struct kvm_device *dev)
 	mutex_unlock(&kv->lock);
 }
 
-static int kvm_vfio_set_group(struct kvm_device *dev, long attr, u64 arg)
+static int kvm_vfio_group_add(struct kvm_device *dev, unsigned int fd)
 {
 	struct kvm_vfio *kv = dev->private;
-	struct vfio_group *vfio_group;
 	struct kvm_vfio_group *kvg;
-	int32_t __user *argp = (int32_t __user *)(unsigned long)arg;
-	struct fd f;
-	int32_t fd;
+	struct file *filp;
 	int ret;
 
-	switch (attr) {
-	case KVM_DEV_VFIO_GROUP_ADD:
-		if (get_user(fd, argp))
-			return -EFAULT;
-
-		f = fdget(fd);
-		if (!f.file)
-			return -EBADF;
-
-		vfio_group = kvm_vfio_group_get_external_user(f.file);
-		fdput(f);
+	filp = fget(fd);
+	if (!filp)
+		return -EBADF;
 
-		if (IS_ERR(vfio_group))
-			return PTR_ERR(vfio_group);
-
-		mutex_lock(&kv->lock);
+	/* Ensure the FD is a vfio group FD. */

+	if (!kvm_vfio_file_iommu_group(filp)) {
+		ret = -EINVAL;
+		goto err_fput;
+	}
 
-		list_for_each_entry(kvg, &kv->group_list, node) {
-			if (kvg->vfio_group == vfio_group) {
-				mutex_unlock(&kv->lock);
-				kvm_vfio_group_put_external_user(vfio_group);
-				return -EEXIST;
-			}
-		}
+	mutex_lock(&kv->lock);
 
-		kvg = kzalloc(sizeof(*kvg), GFP_KERNEL_ACCOUNT);
-		if (!kvg) {
-			mutex_unlock(&kv->lock);
-			kvm_vfio_group_put_external_user(vfio_group);
-			return -ENOMEM;
+	list_for_each_entry(kvg, &kv->group_list, node) {
+		if (kvg->file == filp) {
+			ret = -EEXIST;
+			goto err_unlock;
 		}
+	}
 
-		list_add_tail(&kvg->node, &kv->group_list);
-		kvg->vfio_group = vfio_group;
+	kvg = kzalloc(sizeof(*kvg), GFP_KERNEL_ACCOUNT);
+	if (!kvg) {
+		ret = -ENOMEM;
+		goto err_unlock;
+	}
 
-		kvm_arch_start_assignment(dev->kvm);
+	kvg->file = filp;
+	list_add_tail(&kvg->node, &kv->group_list);
 
-		mutex_unlock(&kv->lock);
+	kvm_arch_start_assignment(dev->kvm);
 
-		kvm_vfio_group_set_kvm(vfio_group, dev->kvm);
+	mutex_unlock(&kv->lock);
 
-		kvm_vfio_update_coherency(dev);
+	kvm_vfio_file_set_kvm(kvg->file, dev->kvm);
+	kvm_vfio_update_coherency(dev);
 
-		return 0;
+	return 0;
+err_unlock:
+	mutex_unlock(&kv->lock);
+err_fput:
+	fput(filp);
+	return ret;
+}
 
-	case KVM_DEV_VFIO_GROUP_DEL:
-		if (get_user(fd, argp))
-			return -EFAULT;
+static int kvm_vfio_group_del(struct kvm_device *dev, unsigned int fd)
+{
+	struct kvm_vfio *kv = dev->private;
+	struct kvm_vfio_group *kvg;
+	struct fd f;
+	int ret;
 
-		f = fdget(fd);
-		if (!f.file)
-			return -EBADF;
+	f = fdget(fd);
+	if (!f.file)
+		return -EBADF;
 
-		ret = -ENOENT;
+	ret = -ENOENT;
 
-		mutex_lock(&kv->lock);
+	mutex_lock(&kv->lock);
 
-		list_for_each_entry(kvg, &kv->group_list, node) {
-			if (!kvm_vfio_external_group_match_file(kvg->vfio_group,
-								f.file))
-				continue;
+	list_for_each_entry(kvg, &kv->group_list, node) {
+		if (kvg->file != f.file)
+			continue;
 
-			list_del(&kvg->node);
-			kvm_arch_end_assignment(dev->kvm);
+		list_del(&kvg->node);
+		kvm_arch_end_assignment(dev->kvm);
 #ifdef CONFIG_SPAPR_TCE_IOMMU
-			kvm_spapr_tce_release_vfio_group(dev->kvm,
-							 kvg->vfio_group);
+		kvm_spapr_tce_release_vfio_group(dev->kvm, kvg);
 #endif
-			kvm_vfio_group_set_kvm(kvg->vfio_group, NULL);
-			kvm_vfio_group_put_external_user(kvg->vfio_group);
-			kfree(kvg);
-			ret = 0;
-			break;
-		}
+		kvm_vfio_file_set_kvm(kvg->file, NULL);
+		fput(kvg->file);
+		kfree(kvg);
+		ret = 0;
+		break;
+	}
 
-		mutex_unlock(&kv->lock);
+	mutex_unlock(&kv->lock);
 
-		fdput(f);
+	fdput(f);
 
-		kvm_vfio_update_coherency(dev);
+	kvm_vfio_update_coherency(dev);
 
-		return ret;
+	return ret;
+}
 
 #ifdef CONFIG_SPAPR_TCE_IOMMU
-	case KVM_DEV_VFIO_GROUP_SET_SPAPR_TCE: {
-		struct kvm_vfio_spapr_tce param;
-		struct kvm_vfio *kv = dev->private;
-		struct vfio_group *vfio_group;
-		struct kvm_vfio_group *kvg;
-		struct fd f;
-		struct iommu_group *grp;
+static int kvm_vfio_group_set_spapr_tce(struct kvm_device *dev,
+					void __user *arg)
+{
+	struct kvm_vfio_spapr_tce param;
+	struct kvm_vfio *kv = dev->private;
+	struct kvm_vfio_group *kvg;
+	struct fd f;
+	int ret;
 
-		if (copy_from_user(&param, (void __user *)arg,
-				sizeof(struct kvm_vfio_spapr_tce)))
-			return -EFAULT;
+	if (copy_from_user(&param, arg, sizeof(struct kvm_vfio_spapr_tce)))
+		return -EFAULT;
 
-		f = fdget(param.groupfd);
-		if (!f.file)
-			return -EBADF;
+	f = fdget(param.groupfd);
+	if (!f.file)
+		return -EBADF;
 
-		vfio_group = kvm_vfio_group_get_external_user(f.file);
-		fdput(f);
+	ret = -ENOENT;
 
-		if (IS_ERR(vfio_group))
-			return PTR_ERR(vfio_group);
+	mutex_lock(&kv->lock);
 
-		grp = kvm_vfio_group_get_iommu_group(vfio_group);
+	list_for_each_entry(kvg, &kv->group_list, node) {
+		struct iommu_group *grp;
+
+		if (kvg->file != f.file)
+			continue;
+
+		grp = kvm_vfio_file_iommu_group(kvg->file);
 		if (WARN_ON_ONCE(!grp)) {
-			kvm_vfio_group_put_external_user(vfio_group);
-			return -EIO;
+			ret = -EIO;
+			goto err_fdput;
 		}
 
-		ret = -ENOENT;
-
-		mutex_lock(&kv->lock);
+		ret = kvm_spapr_tce_attach_iommu_group(dev->kvm, param.tablefd,
+						       grp);
+		break;
+	}
 
-		list_for_each_entry(kvg, &kv->group_list, node) {
-			if (kvg->vfio_group != vfio_group)
-				continue;
+err_fdput:
+	mutex_unlock(&kv->lock);
+	fdput(f);
+	return ret;
+}
+#endif
 
-			ret = kvm_spapr_tce_attach_iommu_group(dev->kvm,
-					param.tablefd, grp);
-			break;
-		}
+static int kvm_vfio_set_group(struct kvm_device *dev, long attr,
+			      void __user *arg)
+{
+	int32_t __user *argp = arg;
+	int32_t fd;
 
-		mutex_unlock(&kv->lock);
+	switch (attr) {
+	case KVM_DEV_VFIO_GROUP_ADD:
+		if (get_user(fd, argp))
+			return -EFAULT;
+		return kvm_vfio_group_add(dev, fd);
 
-		iommu_group_put(grp);
-		kvm_vfio_group_put_external_user(vfio_group);
+	case KVM_DEV_VFIO_GROUP_DEL:
+		if (get_user(fd, argp))
+			return -EFAULT;
+		return kvm_vfio_group_del(dev, fd);
 
-		return ret;
-	}
-#endif /* CONFIG_SPAPR_TCE_IOMMU */
+#ifdef CONFIG_SPAPR_TCE_IOMMU
+	case KVM_DEV_VFIO_GROUP_SET_SPAPR_TCE:
+		return kvm_vfio_group_set_spapr_tce(dev, arg);
+#endif
 	}
 
 	return -ENXIO;
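
The refactored dispatcher keeps the existing uAPI; only the kernel-internal bookkeeping changed from vfio_group references to holding the group's struct file. A userspace sketch, not from this series, of the KVM side that ends up in kvm_vfio_group_add(); creating the device and adding the group are collapsed into one helper purely for illustration.

#include <errno.h>
#include <stdint.h>
#include <sys/ioctl.h>
#include <linux/kvm.h>

/* "vm_fd" is a KVM VM fd, "group_fd" an open /dev/vfio/<group> fd. */
static int kvm_vfio_add_group(int vm_fd, int group_fd)
{
	struct kvm_create_device cd = { .type = KVM_DEV_TYPE_VFIO };
	int32_t fd = group_fd;
	struct kvm_device_attr attr = {
		.group = KVM_DEV_VFIO_GROUP,
		.attr = KVM_DEV_VFIO_GROUP_ADD,
		.addr = (uint64_t)(uintptr_t)&fd,
	};

	if (ioctl(vm_fd, KVM_CREATE_DEVICE, &cd))
		return -errno;
	/* The kernel now takes a reference on the group's struct file. */
	if (ioctl(cd.fd, KVM_SET_DEVICE_ATTR, &attr))
		return -errno;
	return 0;
}
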
@@ -334,7 +290,8 @@ static int kvm_vfio_set_attr(struct kvm_device *dev,
 {
 	switch (attr->group) {
 	case KVM_DEV_VFIO_GROUP:
-		return kvm_vfio_set_group(dev, attr->attr, attr->addr);
+		return kvm_vfio_set_group(dev, attr->attr,
+					  u64_to_user_ptr(attr->addr));
 	}
 
 	return -ENXIO;
@@ -367,10 +324,10 @@ static void kvm_vfio_destroy(struct kvm_device *dev)
 
 	list_for_each_entry_safe(kvg, tmp, &kv->group_list, node) {
 #ifdef CONFIG_SPAPR_TCE_IOMMU
-		kvm_spapr_tce_release_vfio_group(dev->kvm, kvg->vfio_group);
+		kvm_spapr_tce_release_vfio_group(dev->kvm, kvg);
 #endif
-		kvm_vfio_group_set_kvm(kvg->vfio_group, NULL);
-		kvm_vfio_group_put_external_user(kvg->vfio_group);
+		kvm_vfio_file_set_kvm(kvg->file, NULL);
+		fput(kvg->file);
 		list_del(&kvg->node);
 		kfree(kvg);
 		kvm_arch_end_assignment(dev->kvm);