diff --git a/io_uring/kbuf.c b/io_uring/kbuf.c index 3002dc827195..3adc08f90e41 100644 --- a/io_uring/kbuf.c +++ b/io_uring/kbuf.c @@ -463,14 +463,32 @@ err: return IOU_OK; } -int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg) +static int io_pin_pbuf_ring(struct io_uring_buf_reg *reg, + struct io_buffer_list *bl) { struct io_uring_buf_ring *br; - struct io_uring_buf_reg reg; - struct io_buffer_list *bl, *free_bl = NULL; struct page **pages; int nr_pages; + pages = io_pin_pages(reg->ring_addr, + flex_array_size(br, bufs, reg->ring_entries), + &nr_pages); + if (IS_ERR(pages)) + return PTR_ERR(pages); + + br = page_address(pages[0]); + bl->buf_pages = pages; + bl->buf_nr_pages = nr_pages; + bl->buf_ring = br; + return 0; +} + +int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg) +{ + struct io_uring_buf_reg reg; + struct io_buffer_list *bl, *free_bl = NULL; + int ret; + if (copy_from_user(®, arg, sizeof(reg))) return -EFAULT; @@ -504,20 +522,15 @@ int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg) return -ENOMEM; } - pages = io_pin_pages(reg.ring_addr, - flex_array_size(br, bufs, reg.ring_entries), - &nr_pages); - if (IS_ERR(pages)) { + ret = io_pin_pbuf_ring(®, bl); + if (ret) { kfree(free_bl); - return PTR_ERR(pages); + return ret; } - br = page_address(pages[0]); - bl->buf_pages = pages; - bl->buf_nr_pages = nr_pages; bl->nr_entries = reg.ring_entries; - bl->buf_ring = br; bl->mask = reg.ring_entries - 1; + io_buffer_add_list(ctx, bl, reg.bgid); return 0; }