diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c index 5877abd9b693..6cfe7799dc8c 100644 --- a/drivers/iommu/iommu.c +++ b/drivers/iommu/iommu.c @@ -174,6 +174,36 @@ static void dev_iommu_free(struct device *dev) dev->iommu = NULL; } +static int __iommu_probe_device(struct device *dev) +{ + const struct iommu_ops *ops = dev->bus->iommu_ops; + struct iommu_device *iommu_dev; + struct iommu_group *group; + int ret; + + iommu_dev = ops->probe_device(dev); + if (IS_ERR(iommu_dev)) + return PTR_ERR(iommu_dev); + + dev->iommu->iommu_dev = iommu_dev; + + group = iommu_group_get_for_dev(dev); + if (!IS_ERR(group)) { + ret = PTR_ERR(group); + goto out_release; + } + iommu_group_put(group); + + iommu_device_link(iommu_dev, dev); + + return 0; + +out_release: + ops->release_device(dev); + + return ret; +} + int iommu_probe_device(struct device *dev) { const struct iommu_ops *ops = dev->bus->iommu_ops; @@ -191,10 +221,17 @@ int iommu_probe_device(struct device *dev) goto err_free_dev_param; } - ret = ops->add_device(dev); + if (ops->probe_device) + ret = __iommu_probe_device(dev); + else + ret = ops->add_device(dev); + if (ret) goto err_module_put; + if (ops->probe_finalize) + ops->probe_finalize(dev); + return 0; err_module_put: @@ -204,17 +241,31 @@ err_free_dev_param: return ret; } +static void __iommu_release_device(struct device *dev) +{ + const struct iommu_ops *ops = dev->bus->iommu_ops; + + iommu_device_unlink(dev->iommu->iommu_dev, dev); + + iommu_group_remove_device(dev); + + ops->release_device(dev); +} + void iommu_release_device(struct device *dev) { const struct iommu_ops *ops = dev->bus->iommu_ops; - if (dev->iommu_group) + if (!dev->iommu) + return; + + if (ops->release_device) + __iommu_release_device(dev); + else if (dev->iommu_group) ops->remove_device(dev); - if (dev->iommu) { - module_put(ops->owner); - dev_iommu_free(dev); - } + module_put(ops->owner); + dev_iommu_free(dev); } static struct iommu_domain *__iommu_domain_alloc(struct bus_type *bus, diff --git a/include/linux/iommu.h b/include/linux/iommu.h index 1f027b07e499..30170d191e5e 100644 --- a/include/linux/iommu.h +++ b/include/linux/iommu.h @@ -225,6 +225,10 @@ struct iommu_iotlb_gather { * @iova_to_phys: translate iova to physical address * @add_device: add device to iommu grouping * @remove_device: remove device from iommu grouping + * @probe_device: Add device to iommu driver handling + * @release_device: Remove device from iommu driver handling + * @probe_finalize: Do final setup work after the device is added to an IOMMU + * group and attached to the groups domain * @device_group: find iommu group for a particular device * @domain_get_attr: Query domain attributes * @domain_set_attr: Change domain attributes @@ -275,6 +279,9 @@ struct iommu_ops { phys_addr_t (*iova_to_phys)(struct iommu_domain *domain, dma_addr_t iova); int (*add_device)(struct device *dev); void (*remove_device)(struct device *dev); + struct iommu_device *(*probe_device)(struct device *dev); + void (*release_device)(struct device *dev); + void (*probe_finalize)(struct device *dev); struct iommu_group *(*device_group)(struct device *dev); int (*domain_get_attr)(struct iommu_domain *domain, enum iommu_attr attr, void *data); @@ -375,6 +382,7 @@ struct iommu_fault_param { * * @fault_param: IOMMU detected device fault reporting data * @fwspec: IOMMU fwspec data + * @iommu_dev: IOMMU device this device is linked to * @priv: IOMMU Driver private data * * TODO: migrate other per device data pointers under iommu_dev_data, e.g. @@ -384,6 +392,7 @@ struct dev_iommu { struct mutex lock; struct iommu_fault_param *fault_param; struct iommu_fwspec *fwspec; + struct iommu_device *iommu_dev; void *priv; };