diff --git a/drivers/iommu/amd_iommu_v2.c b/drivers/iommu/amd_iommu_v2.c index 3465faf1809e..f7b875bb70d4 100644 --- a/drivers/iommu/amd_iommu_v2.c +++ b/drivers/iommu/amd_iommu_v2.c @@ -132,11 +132,19 @@ static struct device_state *get_device_state(u16 devid) static void free_device_state(struct device_state *dev_state) { + struct iommu_group *group; + /* * First detach device from domain - No more PRI requests will arrive * from that device after it is unbound from the IOMMUv2 domain. */ - iommu_detach_device(dev_state->domain, &dev_state->pdev->dev); + group = iommu_group_get(&dev_state->pdev->dev); + if (WARN_ON(!group)) + return; + + iommu_detach_group(dev_state->domain, group); + + iommu_group_put(group); /* Everything is down now, free the IOMMUv2 domain */ iommu_domain_free(dev_state->domain); @@ -731,6 +739,7 @@ EXPORT_SYMBOL(amd_iommu_unbind_pasid); int amd_iommu_init_device(struct pci_dev *pdev, int pasids) { struct device_state *dev_state; + struct iommu_group *group; unsigned long flags; int ret, tmp; u16 devid; @@ -776,10 +785,16 @@ int amd_iommu_init_device(struct pci_dev *pdev, int pasids) if (ret) goto out_free_domain; - ret = iommu_attach_device(dev_state->domain, &pdev->dev); - if (ret != 0) + group = iommu_group_get(&pdev->dev); + if (!group) goto out_free_domain; + ret = iommu_attach_group(dev_state->domain, group); + if (ret != 0) + goto out_drop_group; + + iommu_group_put(group); + spin_lock_irqsave(&state_lock, flags); if (__get_device_state(devid) != NULL) { @@ -794,6 +809,9 @@ int amd_iommu_init_device(struct pci_dev *pdev, int pasids) return 0; +out_drop_group: + iommu_group_put(group); + out_free_domain: iommu_domain_free(dev_state->domain);