diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 6431b994b3f6..b23812d2bb49 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -12946,6 +12946,13 @@ static bool check_ids(u32 old_id, u32 cur_id, struct bpf_id_pair *idmap) { unsigned int i; + /* either both IDs should be set or both should be zero */ + if (!!old_id != !!cur_id) + return false; + + if (old_id == 0) /* cur_id == 0 as well */ + return true; + for (i = 0; i < BPF_ID_MAP_SIZE; i++) { if (!idmap[i].old) { /* Reached an empty slot; haven't seen this id before */ @@ -13058,9 +13065,12 @@ next: } static bool regs_exact(const struct bpf_reg_state *rold, - const struct bpf_reg_state *rcur) + const struct bpf_reg_state *rcur, + struct bpf_id_pair *idmap) { - return memcmp(rold, rcur, offsetof(struct bpf_reg_state, parent)) == 0; + return memcmp(rold, rcur, offsetof(struct bpf_reg_state, id)) == 0 && + check_ids(rold->id, rcur->id, idmap) && + check_ids(rold->ref_obj_id, rcur->ref_obj_id, idmap); } /* Returns true if (rold safe implies rcur safe) */ @@ -13102,7 +13112,7 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold, switch (base_type(rold->type)) { case SCALAR_VALUE: - if (regs_exact(rold, rcur)) + if (regs_exact(rold, rcur, idmap)) return true; if (env->explore_alu_limits) return false; @@ -13136,7 +13146,7 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold, if (rold->off != rcur->off) return false; /* id relations must be preserved */ - if (rold->id && !check_ids(rold->id, rcur->id, idmap)) + if (!check_ids(rold->id, rcur->id, idmap)) return false; /* new val must satisfy old val knowledge */ return range_within(rold, rcur) && @@ -13145,10 +13155,9 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold, /* two stack pointers are equal only if they're pointing to * the same stack frame, since fp-8 in foo != fp-8 in bar */ - return regs_exact(rold, rcur) && rold->frameno == rcur->frameno; + return regs_exact(rold, rcur, idmap) && rold->frameno == rcur->frameno; default: - /* Only valid matches are exact, which memcmp() */ - return regs_exact(rold, rcur); + return regs_exact(rold, rcur, idmap); } }