diff --git a/drivers/misc/habanalabs/common/habanalabs.h b/drivers/misc/habanalabs/common/habanalabs.h index 018d9d67e8e6..13c18f3d9a9b 100644 --- a/drivers/misc/habanalabs/common/habanalabs.h +++ b/drivers/misc/habanalabs/common/habanalabs.h @@ -1651,7 +1651,7 @@ struct hl_ioctl_desc { * * Return: true if the area is inside the valid range, false otherwise. */ -static inline bool hl_mem_area_inside_range(u64 address, u32 size, +static inline bool hl_mem_area_inside_range(u64 address, u64 size, u64 range_start_address, u64 range_end_address) { u64 end_address = address + size; diff --git a/drivers/misc/habanalabs/gaudi/gaudi_coresight.c b/drivers/misc/habanalabs/gaudi/gaudi_coresight.c index 5673ee49819e..881531d4d9da 100644 --- a/drivers/misc/habanalabs/gaudi/gaudi_coresight.c +++ b/drivers/misc/habanalabs/gaudi/gaudi_coresight.c @@ -527,7 +527,7 @@ static int gaudi_config_etf(struct hl_device *hdev, } static bool gaudi_etr_validate_address(struct hl_device *hdev, u64 addr, - u32 size, bool *is_host) + u64 size, bool *is_host) { struct asic_fixed_properties *prop = &hdev->asic_prop; struct gaudi_device *gaudi = hdev->asic_specific; @@ -539,6 +539,12 @@ static bool gaudi_etr_validate_address(struct hl_device *hdev, u64 addr, return false; } + if (addr > (addr + size)) { + dev_err(hdev->dev, + "ETR buffer size %llu overflow\n", size); + return false; + } + /* PMMU and HPMMU addresses are equal, check only one of them */ if ((gaudi->hw_cap_initialized & HW_CAP_MMU) && hl_mem_area_inside_range(addr, size, diff --git a/drivers/misc/habanalabs/goya/goya_coresight.c b/drivers/misc/habanalabs/goya/goya_coresight.c index b03912483de0..4027a6a334d7 100644 --- a/drivers/misc/habanalabs/goya/goya_coresight.c +++ b/drivers/misc/habanalabs/goya/goya_coresight.c @@ -362,11 +362,17 @@ static int goya_config_etf(struct hl_device *hdev, } static int goya_etr_validate_address(struct hl_device *hdev, u64 addr, - u32 size) + u64 size) { struct asic_fixed_properties *prop = &hdev->asic_prop; u64 range_start, range_end; + if (addr > (addr + size)) { + dev_err(hdev->dev, + "ETR buffer size %llu overflow\n", size); + return false; + } + if (hdev->mmu_enable) { range_start = prop->dmmu.start_addr; range_end = prop->dmmu.end_addr;