diff --git a/drivers/iommu/amd_iommu_v2.c b/drivers/iommu/amd_iommu_v2.c index 90d734bbf467..9c0d6e290097 100644 --- a/drivers/iommu/amd_iommu_v2.c +++ b/drivers/iommu/amd_iommu_v2.c @@ -513,45 +513,67 @@ static void finish_pri_tag(struct device_state *dev_state, spin_unlock_irqrestore(&pasid_state->lock, flags); } +static void handle_fault_error(struct fault *fault) +{ + int status; + + if (!fault->dev_state->inv_ppr_cb) { + set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); + return; + } + + status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev, + fault->pasid, + fault->address, + fault->flags); + switch (status) { + case AMD_IOMMU_INV_PRI_RSP_SUCCESS: + set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS); + break; + case AMD_IOMMU_INV_PRI_RSP_INVALID: + set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); + break; + case AMD_IOMMU_INV_PRI_RSP_FAIL: + set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE); + break; + default: + BUG(); + } +} + static void do_fault(struct work_struct *work) { struct fault *fault = container_of(work, struct fault, work); - int npages, write; - struct page *page; + struct mm_struct *mm; + struct vm_area_struct *vma; + u64 address; + int ret, write; write = !!(fault->flags & PPR_FAULT_WRITE); - down_read(&fault->state->mm->mmap_sem); - npages = get_user_pages(NULL, fault->state->mm, - fault->address, 1, write, 0, &page, NULL); - up_read(&fault->state->mm->mmap_sem); + mm = fault->state->mm; + address = fault->address; - if (npages == 1) { - put_page(page); - } else if (fault->dev_state->inv_ppr_cb) { - int status; - - status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev, - fault->pasid, - fault->address, - fault->flags); - switch (status) { - case AMD_IOMMU_INV_PRI_RSP_SUCCESS: - set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS); - break; - case AMD_IOMMU_INV_PRI_RSP_INVALID: - set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); - break; - case AMD_IOMMU_INV_PRI_RSP_FAIL: - set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE); - break; - default: - BUG(); - } - } else { - set_pri_tag_status(fault->state, fault->tag, PPR_INVALID); + down_read(&mm->mmap_sem); + vma = find_extend_vma(mm, address); + if (!vma || address < vma->vm_start) { + /* failed to get a vma in the right range */ + up_read(&mm->mmap_sem); + handle_fault_error(fault); + goto out; } + ret = handle_mm_fault(mm, vma, address, write); + if (ret & VM_FAULT_ERROR) { + /* failed to service fault */ + up_read(&mm->mmap_sem); + handle_fault_error(fault); + goto out; + } + + up_read(&mm->mmap_sem); + +out: finish_pri_tag(fault->dev_state, fault->state, fault->tag); put_pasid_state(fault->state);