diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c index b19e8c0f48fa..98054497d343 100644 --- a/drivers/iommu/amd/iommu.c +++ b/drivers/iommu/amd/iommu.c @@ -2203,6 +2203,9 @@ static struct iommu_device *amd_iommu_probe_device(struct device *dev) iommu_completion_wait(iommu); + if (dev_is_pci(dev)) + pci_prepare_ats(to_pci_dev(dev), PAGE_SHIFT); + return iommu_dev; } diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c index a31460f9f3d4..9bc50bded5af 100644 --- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c +++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c @@ -3295,6 +3295,12 @@ static struct iommu_device *arm_smmu_probe_device(struct device *dev) smmu->features & ARM_SMMU_FEAT_STALL_FORCE) master->stall_enabled = true; + if (dev_is_pci(dev)) { + unsigned int stu = __ffs(smmu->pgsize_bitmap); + + pci_prepare_ats(to_pci_dev(dev), stu); + } + return &smmu->iommu; err_free_master: diff --git a/drivers/iommu/intel/iommu.c b/drivers/iommu/intel/iommu.c index 9ff8b83c19a3..ad81db026ab2 100644 --- a/drivers/iommu/intel/iommu.c +++ b/drivers/iommu/intel/iommu.c @@ -4091,6 +4091,7 @@ static struct iommu_device *intel_iommu_probe_device(struct device *dev) dev_iommu_priv_set(dev, info); if (pdev && pci_ats_supported(pdev)) { + pci_prepare_ats(pdev, VTD_PAGE_SHIFT); ret = device_rbtree_insert(iommu, info); if (ret) goto free; diff --git a/drivers/pci/ats.c b/drivers/pci/ats.c index c570892b2090..87fa03540b8a 100644 --- a/drivers/pci/ats.c +++ b/drivers/pci/ats.c @@ -47,6 +47,39 @@ bool pci_ats_supported(struct pci_dev *dev) } EXPORT_SYMBOL_GPL(pci_ats_supported); +/** + * pci_prepare_ats - Setup the PS for ATS + * @dev: the PCI device + * @ps: the IOMMU page shift + * + * This must be done by the IOMMU driver on the PF before any VFs are created to + * ensure that the VF can have ATS enabled. + * + * Returns 0 on success, or negative on failure. + */ +int pci_prepare_ats(struct pci_dev *dev, int ps) +{ + u16 ctrl; + + if (!pci_ats_supported(dev)) + return -EINVAL; + + if (WARN_ON(dev->ats_enabled)) + return -EBUSY; + + if (ps < PCI_ATS_MIN_STU) + return -EINVAL; + + if (dev->is_virtfn) + return 0; + + dev->ats_stu = ps; + ctrl = PCI_ATS_CTRL_STU(dev->ats_stu - PCI_ATS_MIN_STU); + pci_write_config_word(dev, dev->ats_cap + PCI_ATS_CTRL, ctrl); + return 0; +} +EXPORT_SYMBOL_GPL(pci_prepare_ats); + /** * pci_enable_ats - enable the ATS capability * @dev: the PCI device diff --git a/include/linux/pci-ats.h b/include/linux/pci-ats.h index df54cd5b15db..0e8b74e63767 100644 --- a/include/linux/pci-ats.h +++ b/include/linux/pci-ats.h @@ -8,6 +8,7 @@ /* Address Translation Service */ bool pci_ats_supported(struct pci_dev *dev); int pci_enable_ats(struct pci_dev *dev, int ps); +int pci_prepare_ats(struct pci_dev *dev, int ps); void pci_disable_ats(struct pci_dev *dev); int pci_ats_queue_depth(struct pci_dev *dev); int pci_ats_page_aligned(struct pci_dev *dev); @@ -16,6 +17,8 @@ static inline bool pci_ats_supported(struct pci_dev *d) { return false; } static inline int pci_enable_ats(struct pci_dev *d, int ps) { return -ENODEV; } +static inline int pci_prepare_ats(struct pci_dev *dev, int ps) +{ return -ENODEV; } static inline void pci_disable_ats(struct pci_dev *d) { } static inline int pci_ats_queue_depth(struct pci_dev *d) { return -ENODEV; }