fs: rework listmount() implementation

Linus pointed out that there's error handling and naming issues in the
that we should rewrite:

* Perform the access checks for the buffer before actually doing any
  work instead of doing it during the iteration.
* Rename the arguments to listmount() and do_listmount() to clarify what
  the arguments are used for.
* Get rid of the pointless ctr variable and overflow checking.
* Get rid of the pointless speculation check.

Link: https://lore.kernel.org/r/CAHk-=wjh6Cypo8WC-McXgSzCaou3UXccxB+7PVeSuGR8AjCphg@mail.gmail.com
Suggested-by: Linus Torvalds <torvalds@linux-foundation.org>
Signed-off-by: Christian Brauner <brauner@kernel.org>
This commit is contained in:
Christian Brauner 2024-01-12 09:09:14 +01:00
parent 7ea26f9460
commit ba5afb9a84
No known key found for this signature in database
GPG Key ID: 91C61BC06578DCA2
2 changed files with 29 additions and 23 deletions

View File

@ -5042,13 +5042,12 @@ static struct mount *listmnt_next(struct mount *curr)
return node_to_mount(rb_next(&curr->mnt_node)); return node_to_mount(rb_next(&curr->mnt_node));
} }
static ssize_t do_listmount(struct mount *first, struct path *orig, u64 mnt_id, static ssize_t do_listmount(struct mount *first, struct path *orig,
u64 __user *buf, size_t bufsize, u64 mnt_parent_id, u64 __user *mnt_ids,
const struct path *root) size_t nr_mnt_ids, const struct path *root)
{ {
struct mount *r; struct mount *r;
ssize_t ctr; ssize_t ret;
int err;
/* /*
* Don't trigger audit denials. We just want to determine what * Don't trigger audit denials. We just want to determine what
@ -5058,50 +5057,57 @@ static ssize_t do_listmount(struct mount *first, struct path *orig, u64 mnt_id,
!ns_capable_noaudit(&init_user_ns, CAP_SYS_ADMIN)) !ns_capable_noaudit(&init_user_ns, CAP_SYS_ADMIN))
return -EPERM; return -EPERM;
err = security_sb_statfs(orig->dentry); ret = security_sb_statfs(orig->dentry);
if (err) if (ret)
return err; return ret;
for (ctr = 0, r = first; r && ctr < bufsize; r = listmnt_next(r)) { for (ret = 0, r = first; r && nr_mnt_ids; r = listmnt_next(r)) {
if (r->mnt_id_unique == mnt_id) if (r->mnt_id_unique == mnt_parent_id)
continue; continue;
if (!is_path_reachable(r, r->mnt.mnt_root, orig)) if (!is_path_reachable(r, r->mnt.mnt_root, orig))
continue; continue;
ctr = array_index_nospec(ctr, bufsize); if (put_user(r->mnt_id_unique, mnt_ids))
if (put_user(r->mnt_id_unique, buf + ctr))
return -EFAULT; return -EFAULT;
if (check_add_overflow(ctr, 1, &ctr)) mnt_ids++;
return -ERANGE; nr_mnt_ids--;
ret++;
} }
return ctr; return ret;
} }
SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req, SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req, u64 __user *,
u64 __user *, buf, size_t, bufsize, unsigned int, flags) mnt_ids, size_t, nr_mnt_ids, unsigned int, flags)
{ {
struct mnt_namespace *ns = current->nsproxy->mnt_ns; struct mnt_namespace *ns = current->nsproxy->mnt_ns;
struct mnt_id_req kreq; struct mnt_id_req kreq;
struct mount *first; struct mount *first;
struct path root, orig; struct path root, orig;
u64 mnt_id, last_mnt_id; u64 mnt_parent_id, last_mnt_id;
const size_t maxcount = (size_t)-1 >> 3;
ssize_t ret; ssize_t ret;
if (flags) if (flags)
return -EINVAL; return -EINVAL;
if (unlikely(nr_mnt_ids > maxcount))
return -EFAULT;
if (!access_ok(mnt_ids, nr_mnt_ids * sizeof(*mnt_ids)))
return -EFAULT;
ret = copy_mnt_id_req(req, &kreq); ret = copy_mnt_id_req(req, &kreq);
if (ret) if (ret)
return ret; return ret;
mnt_id = kreq.mnt_id; mnt_parent_id = kreq.mnt_id;
last_mnt_id = kreq.param; last_mnt_id = kreq.param;
down_read(&namespace_sem); down_read(&namespace_sem);
get_fs_root(current->fs, &root); get_fs_root(current->fs, &root);
if (mnt_id == LSMT_ROOT) { if (mnt_parent_id == LSMT_ROOT) {
orig = root; orig = root;
} else { } else {
ret = -ENOENT; ret = -ENOENT;
orig.mnt = lookup_mnt_in_ns(mnt_id, ns); orig.mnt = lookup_mnt_in_ns(mnt_parent_id, ns);
if (!orig.mnt) if (!orig.mnt)
goto err; goto err;
orig.dentry = orig.mnt->mnt_root; orig.dentry = orig.mnt->mnt_root;
@ -5111,7 +5117,7 @@ SYSCALL_DEFINE4(listmount, const struct mnt_id_req __user *, req,
else else
first = mnt_find_id_at(ns, last_mnt_id + 1); first = mnt_find_id_at(ns, last_mnt_id + 1);
ret = do_listmount(first, &orig, mnt_id, buf, bufsize, &root); ret = do_listmount(first, &orig, mnt_parent_id, mnt_ids, nr_mnt_ids, &root);
err: err:
path_put(&root); path_put(&root);
up_read(&namespace_sem); up_read(&namespace_sem);

View File

@ -414,7 +414,7 @@ asmlinkage long sys_statmount(const struct mnt_id_req __user *req,
struct statmount __user *buf, size_t bufsize, struct statmount __user *buf, size_t bufsize,
unsigned int flags); unsigned int flags);
asmlinkage long sys_listmount(const struct mnt_id_req __user *req, asmlinkage long sys_listmount(const struct mnt_id_req __user *req,
u64 __user *buf, size_t bufsize, u64 __user *mnt_ids, size_t nr_mnt_ids,
unsigned int flags); unsigned int flags);
asmlinkage long sys_truncate(const char __user *path, long length); asmlinkage long sys_truncate(const char __user *path, long length);
asmlinkage long sys_ftruncate(unsigned int fd, unsigned long length); asmlinkage long sys_ftruncate(unsigned int fd, unsigned long length);