diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c index e9ba9551832c..696b2c4613a7 100644 --- a/net/mptcp/pm.c +++ b/net/mptcp/pm.c @@ -172,9 +172,28 @@ void mptcp_pm_subflow_established(struct mptcp_sock *msk) spin_unlock_bh(&pm->lock); } -void mptcp_pm_subflow_closed(struct mptcp_sock *msk, u8 id) +void mptcp_pm_subflow_check_next(struct mptcp_sock *msk, const struct sock *ssk, + const struct mptcp_subflow_context *subflow) { - pr_debug("msk=%p", msk); + struct mptcp_pm_data *pm = &msk->pm; + bool update_subflows; + + update_subflows = (ssk->sk_state == TCP_CLOSE) && + (subflow->request_join || subflow->mp_join); + if (!READ_ONCE(pm->work_pending) && !update_subflows) + return; + + spin_lock_bh(&pm->lock); + if (update_subflows) + pm->subflows--; + + /* Even if this subflow is not really established, tell the PM to try + * to pick the next ones, if possible. + */ + if (mptcp_pm_nl_check_work_pending(msk)) + mptcp_pm_schedule_work(msk, MPTCP_PM_SUBFLOW_ESTABLISHED); + + spin_unlock_bh(&pm->lock); } void mptcp_pm_add_addr_received(struct mptcp_sock *msk, diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index ad3dc9c6c531..5efb63ab1fa3 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -251,14 +251,17 @@ unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk) } EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max); -static void check_work_pending(struct mptcp_sock *msk) +bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk) { struct pm_nl_pernet *pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); if (msk->pm.subflows == mptcp_pm_get_subflows_max(msk) || (find_next_and_bit(pernet->id_bitmap, msk->pm.id_avail_bitmap, - MPTCP_PM_MAX_ADDR_ID + 1, 0) == MPTCP_PM_MAX_ADDR_ID + 1)) + MPTCP_PM_MAX_ADDR_ID + 1, 0) == MPTCP_PM_MAX_ADDR_ID + 1)) { WRITE_ONCE(msk->pm.work_pending, false); + return false; + } + return true; } struct mptcp_pm_add_entry * @@ -427,6 +430,7 @@ static bool lookup_address_in_vec(struct mptcp_addr_info *addrs, unsigned int nr static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullmesh, struct mptcp_addr_info *addrs) { + bool deny_id0 = READ_ONCE(msk->pm.remote_deny_join_id0); struct sock *sk = (struct sock *)msk, *ssk; struct mptcp_subflow_context *subflow; struct mptcp_addr_info remote = { 0 }; @@ -434,22 +438,28 @@ static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullm int i = 0; subflows_max = mptcp_pm_get_subflows_max(msk); + remote_address((struct sock_common *)sk, &remote); /* Non-fullmesh endpoint, fill in the single entry * corresponding to the primary MPC subflow remote address */ if (!fullmesh) { - remote_address((struct sock_common *)sk, &remote); + if (deny_id0) + return 0; + msk->pm.subflows++; addrs[i++] = remote; } else { mptcp_for_each_subflow(msk, subflow) { ssk = mptcp_subflow_tcp_sock(subflow); - remote_address((struct sock_common *)ssk, &remote); - if (!lookup_address_in_vec(addrs, i, &remote) && + remote_address((struct sock_common *)ssk, &addrs[i]); + if (deny_id0 && addresses_equal(&addrs[i], &remote, false)) + continue; + + if (!lookup_address_in_vec(addrs, i, &addrs[i]) && msk->pm.subflows < subflows_max) { msk->pm.subflows++; - addrs[i++] = remote; + i++; } } } @@ -503,12 +513,12 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) /* do lazy endpoint usage accounting for the MPC subflows */ if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) { - struct mptcp_addr_info local; + struct mptcp_addr_info mpc_addr; int mpc_id; - local_address((struct sock_common *)msk->first, &local); - mpc_id = lookup_id_by_addr(pernet, &local); - if (mpc_id < 0) + local_address((struct sock_common *)msk->first, &mpc_addr); + mpc_id = lookup_id_by_addr(pernet, &mpc_addr); + if (mpc_id >= 0) __clear_bit(mpc_id, msk->pm.id_avail_bitmap); msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED); @@ -534,26 +544,28 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) } /* check if should create a new subflow */ - if (msk->pm.local_addr_used < local_addr_max && - msk->pm.subflows < subflows_max && - !READ_ONCE(msk->pm.remote_deny_join_id0)) { - local = select_local_address(pernet, msk); - if (local) { - bool fullmesh = !!(local->flags & MPTCP_PM_ADDR_FLAG_FULLMESH); - struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX]; - int i, nr; + while (msk->pm.local_addr_used < local_addr_max && + msk->pm.subflows < subflows_max) { + struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX]; + bool fullmesh; + int i, nr; - msk->pm.local_addr_used++; - nr = fill_remote_addresses_vec(msk, fullmesh, addrs); - if (nr) - __clear_bit(local->addr.id, msk->pm.id_avail_bitmap); - spin_unlock_bh(&msk->pm.lock); - for (i = 0; i < nr; i++) - __mptcp_subflow_connect(sk, &local->addr, &addrs[i]); - spin_lock_bh(&msk->pm.lock); - } + local = select_local_address(pernet, msk); + if (!local) + break; + + fullmesh = !!(local->flags & MPTCP_PM_ADDR_FLAG_FULLMESH); + + msk->pm.local_addr_used++; + nr = fill_remote_addresses_vec(msk, fullmesh, addrs); + if (nr) + __clear_bit(local->addr.id, msk->pm.id_avail_bitmap); + spin_unlock_bh(&msk->pm.lock); + for (i = 0; i < nr; i++) + __mptcp_subflow_connect(sk, &local->addr, &addrs[i]); + spin_lock_bh(&msk->pm.lock); } - check_work_pending(msk); + mptcp_pm_nl_check_work_pending(msk); } static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk) @@ -760,11 +772,12 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk, i, rm_list->ids[i], subflow->local_id, subflow->remote_id); spin_unlock_bh(&msk->pm.lock); mptcp_subflow_shutdown(sk, ssk, how); + + /* the following takes care of updating the subflows counter */ mptcp_close_ssk(sk, ssk, subflow); spin_lock_bh(&msk->pm.lock); removed = true; - msk->pm.subflows--; __MPTCP_INC_STATS(sock_net(sk), rm_type); } __set_bit(rm_list->ids[1], msk->pm.id_avail_bitmap); diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 5c956a8dc714..3e8cfaed00b5 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -2332,6 +2332,12 @@ void mptcp_close_ssk(struct sock *sk, struct sock *ssk, { if (sk->sk_state == TCP_ESTABLISHED) mptcp_event(MPTCP_EVENT_SUB_CLOSED, mptcp_sk(sk), ssk, GFP_KERNEL); + + /* subflow aborted before reaching the fully_established status + * attempt the creation of the next subflow + */ + mptcp_pm_subflow_check_next(mptcp_sk(sk), ssk, subflow); + __mptcp_close_ssk(sk, ssk, subflow, MPTCP_CF_PUSH); } diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index 2a6f0960ba27..a8eb32e29215 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -743,7 +743,9 @@ void mptcp_pm_fully_established(struct mptcp_sock *msk, const struct sock *ssk, bool mptcp_pm_allow_new_subflow(struct mptcp_sock *msk); void mptcp_pm_connection_closed(struct mptcp_sock *msk); void mptcp_pm_subflow_established(struct mptcp_sock *msk); -void mptcp_pm_subflow_closed(struct mptcp_sock *msk, u8 id); +bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk); +void mptcp_pm_subflow_check_next(struct mptcp_sock *msk, const struct sock *ssk, + const struct mptcp_subflow_context *subflow); void mptcp_pm_add_addr_received(struct mptcp_sock *msk, const struct mptcp_addr_info *addr); void mptcp_pm_add_addr_echoed(struct mptcp_sock *msk,