diff --git a/fs/namespace.c b/fs/namespace.c index ef1fd6829814..437f60e96d40 100644 --- a/fs/namespace.c +++ b/fs/namespace.c @@ -5042,13 +5042,12 @@ static struct mount *listmnt_next(struct mount *curr) return node_to_mount(rb_next(&curr->mnt_node)); } -static ssize_t do_listmount(struct mount *first, struct path *orig, u64 mnt_id, - u64 __user *buf, size_t bufsize, - const struct path *root) +static ssize_t do_listmount(struct mount *first, struct path *orig, + u64 mnt_parent_id, u64 __user *mnt_ids, + size_t nr_mnt_ids, const struct path *root) { struct mount *r; - ssize_t ctr; - int err; + ssize_t ret; /* * Don't trigger audit denials. We just want to determine what @@ -5058,50 +5057,57 @@ static ssize_t do_listmount(struct mount *first, struct path *orig, u64 mnt_id, !ns_capable_noaudit(&init_user_ns, CAP_SYS_ADMIN)) return -EPERM; - err = security_sb_statfs(orig->dentry); - if (err) - return err; + ret = security_sb_statfs(orig->dentry); + if (ret) + return ret; - for (ctr = 0, r = first; r && ctr < bufsize; r = listmnt_next(r)) { - if (r->mnt_id_unique == mnt_id) + for (ret = 0, r = first; r && nr_mnt_ids; r = listmnt_next(r)) { + if (r->mnt_id_unique == mnt_parent_id) continue; if (!is_path_reachable(r, r->mnt.mnt_root, orig)) continue; - ctr = array_index_nospec(ctr, bufsize); - if (put_user(r->mnt_id_unique, buf + ctr)) + if (put_user(r->mnt_id_unique, mnt_ids)) return -EFAULT; - if (check_add_overflow(ctr, 1, &ctr)) - return -ERANGE; + mnt_ids++; + nr_mnt_ids--; + ret++; } - return ctr; + return ret; } -SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req, - u64 __user *, buf, size_t, bufsize, unsigned int, flags) +SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req, u64 __user *, + mnt_ids, size_t, nr_mnt_ids, unsigned int, flags) { struct mnt_namespace *ns = current->nsproxy->mnt_ns; struct mnt_id_req kreq; struct mount *first; struct path root, orig; - u64 mnt_id, last_mnt_id; + u64 mnt_parent_id, last_mnt_id; + const size_t maxcount = (size_t)-1 >> 3; ssize_t ret; if (flags) return -EINVAL; + if (unlikely(nr_mnt_ids > maxcount)) + return -EFAULT; + + if (!access_ok(mnt_ids, nr_mnt_ids * sizeof(*mnt_ids))) + return -EFAULT; + ret = copy_mnt_id_req(req, &kreq); if (ret) return ret; - mnt_id = kreq.mnt_id; + mnt_parent_id = kreq.mnt_id; last_mnt_id = kreq.param; down_read(&namespace_sem); get_fs_root(current->fs, &root); - if (mnt_id == LSMT_ROOT) { + if (mnt_parent_id == LSMT_ROOT) { orig = root; } else { ret = -ENOENT; - orig.mnt = lookup_mnt_in_ns(mnt_id, ns); + orig.mnt = lookup_mnt_in_ns(mnt_parent_id, ns); if (!orig.mnt) goto err; orig.dentry = orig.mnt->mnt_root; @@ -5111,7 +5117,7 @@ SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req, else first = mnt_find_id_at(ns, last_mnt_id + 1); - ret = do_listmount(first, &orig, mnt_id, buf, bufsize, &root); + ret = do_listmount(first, &orig, mnt_parent_id, mnt_ids, nr_mnt_ids, &root); err: path_put(&root); up_read(&namespace_sem); diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h index 5c0dbef55792..cdba4d0c6d4a 100644 --- a/include/linux/syscalls.h +++ b/include/linux/syscalls.h @@ -414,7 +414,7 @@ asmlinkage long sys_statmount(const struct mnt_id_req __user *req, struct statmount __user *buf, size_t bufsize, unsigned int flags); asmlinkage long sys_listmount(const struct mnt_id_req __user *req, - u64 __user *buf, size_t bufsize, + u64 __user *mnt_ids, size_t nr_mnt_ids, unsigned int flags); asmlinkage long sys_truncate(const char __user *path, long length); asmlinkage long sys_ftruncate(unsigned int fd, unsigned long length);