diff --git a/include/net/sctp/sctp.h b/include/net/sctp/sctp.h index 31acc3f4f132..f0dcaebebddb 100644 --- a/include/net/sctp/sctp.h +++ b/include/net/sctp/sctp.h @@ -164,7 +164,7 @@ void sctp_backlog_migrate(struct sctp_association *assoc, struct sock *oldsk, struct sock *newsk); int sctp_transport_hashtable_init(void); void sctp_transport_hashtable_destroy(void); -void sctp_hash_transport(struct sctp_transport *t); +int sctp_hash_transport(struct sctp_transport *t); void sctp_unhash_transport(struct sctp_transport *t); struct sctp_transport *sctp_addrs_lookup_transport( struct net *net, diff --git a/include/net/sctp/structs.h b/include/net/sctp/structs.h index bd4a3ded7c87..92daabdc007d 100644 --- a/include/net/sctp/structs.h +++ b/include/net/sctp/structs.h @@ -124,7 +124,7 @@ extern struct sctp_globals { /* This is the sctp port control hash. */ struct sctp_bind_hashbucket *port_hashtable; /* This is the hash of all transports. */ - struct rhashtable transport_hashtable; + struct rhltable transport_hashtable; /* Sizes of above hashtables. */ int ep_hashsize; @@ -761,7 +761,7 @@ static inline int sctp_packet_empty(struct sctp_packet *packet) struct sctp_transport { /* A list of transports. */ struct list_head transports; - struct rhash_head node; + struct rhlist_head node; /* Reference counting. */ atomic_t refcnt; diff --git a/net/sctp/associola.c b/net/sctp/associola.c index f10d3397f917..68428e1f7181 100644 --- a/net/sctp/associola.c +++ b/net/sctp/associola.c @@ -700,11 +700,15 @@ struct sctp_transport *sctp_assoc_add_peer(struct sctp_association *asoc, /* Set the peer's active state. */ peer->state = peer_state; + /* Add this peer into the transport hashtable */ + if (sctp_hash_transport(peer)) { + sctp_transport_free(peer); + return NULL; + } + /* Attach the remote transport to our asoc. */ list_add_tail_rcu(&peer->transports, &asoc->peer.transport_addr_list); asoc->peer.transport_count++; - /* Add this peer into the transport hashtable */ - sctp_hash_transport(peer); /* If we do not yet have a primary path, set one. */ if (!asoc->peer.primary_path) { diff --git a/net/sctp/input.c b/net/sctp/input.c index a01a56ec8b8c..458e506ef84b 100644 --- a/net/sctp/input.c +++ b/net/sctp/input.c @@ -790,10 +790,9 @@ hit: /* rhashtable for transport */ struct sctp_hash_cmp_arg { - const struct sctp_endpoint *ep; - const union sctp_addr *laddr; - const union sctp_addr *paddr; - const struct net *net; + const union sctp_addr *paddr; + const struct net *net; + u16 lport; }; static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg, @@ -801,7 +800,6 @@ static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg, { struct sctp_transport *t = (struct sctp_transport *)ptr; const struct sctp_hash_cmp_arg *x = arg->key; - struct sctp_association *asoc; int err = 1; if (!sctp_cmp_addr_exact(&t->ipaddr, x->paddr)) @@ -809,19 +807,10 @@ static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg, if (!sctp_transport_hold(t)) return err; - asoc = t->asoc; - if (!net_eq(sock_net(asoc->base.sk), x->net)) + if (!net_eq(sock_net(t->asoc->base.sk), x->net)) + goto out; + if (x->lport != htons(t->asoc->base.bind_addr.port)) goto out; - if (x->ep) { - if (x->ep != asoc->ep) - goto out; - } else { - if (x->laddr->v4.sin_port != htons(asoc->base.bind_addr.port)) - goto out; - if (!sctp_bind_addr_match(&asoc->base.bind_addr, - x->laddr, sctp_sk(asoc->base.sk))) - goto out; - } err = 0; out: @@ -851,11 +840,9 @@ static inline u32 sctp_hash_key(const void *data, u32 len, u32 seed) const struct sctp_hash_cmp_arg *x = data; const union sctp_addr *paddr = x->paddr; const struct net *net = x->net; - u16 lport; + u16 lport = x->lport; u32 addr; - lport = x->ep ? htons(x->ep->base.bind_addr.port) : - x->laddr->v4.sin_port; if (paddr->sa.sa_family == AF_INET6) addr = jhash(&paddr->v6.sin6_addr, 16, seed); else @@ -875,29 +862,32 @@ static const struct rhashtable_params sctp_hash_params = { int sctp_transport_hashtable_init(void) { - return rhashtable_init(&sctp_transport_hashtable, &sctp_hash_params); + return rhltable_init(&sctp_transport_hashtable, &sctp_hash_params); } void sctp_transport_hashtable_destroy(void) { - rhashtable_destroy(&sctp_transport_hashtable); + rhltable_destroy(&sctp_transport_hashtable); } -void sctp_hash_transport(struct sctp_transport *t) +int sctp_hash_transport(struct sctp_transport *t) { struct sctp_hash_cmp_arg arg; + int err; if (t->asoc->temp) - return; + return 0; - arg.ep = t->asoc->ep; - arg.paddr = &t->ipaddr; arg.net = sock_net(t->asoc->base.sk); + arg.paddr = &t->ipaddr; + arg.lport = htons(t->asoc->base.bind_addr.port); -reinsert: - if (rhashtable_lookup_insert_key(&sctp_transport_hashtable, &arg, - &t->node, sctp_hash_params) == -EBUSY) - goto reinsert; + err = rhltable_insert_key(&sctp_transport_hashtable, &arg, + &t->node, sctp_hash_params); + if (err) + pr_err_once("insert transport fail, errno %d\n", err); + + return err; } void sctp_unhash_transport(struct sctp_transport *t) @@ -905,39 +895,62 @@ void sctp_unhash_transport(struct sctp_transport *t) if (t->asoc->temp) return; - rhashtable_remove_fast(&sctp_transport_hashtable, &t->node, - sctp_hash_params); + rhltable_remove(&sctp_transport_hashtable, &t->node, + sctp_hash_params); } +/* return a transport with holding it */ struct sctp_transport *sctp_addrs_lookup_transport( struct net *net, const union sctp_addr *laddr, const union sctp_addr *paddr) { + struct rhlist_head *tmp, *list; + struct sctp_transport *t; struct sctp_hash_cmp_arg arg = { - .ep = NULL, - .laddr = laddr, .paddr = paddr, .net = net, + .lport = laddr->v4.sin_port, }; - return rhashtable_lookup_fast(&sctp_transport_hashtable, &arg, - sctp_hash_params); + list = rhltable_lookup(&sctp_transport_hashtable, &arg, + sctp_hash_params); + + rhl_for_each_entry_rcu(t, tmp, list, node) { + if (!sctp_transport_hold(t)) + continue; + + if (sctp_bind_addr_match(&t->asoc->base.bind_addr, + laddr, sctp_sk(t->asoc->base.sk))) + return t; + sctp_transport_put(t); + } + + return NULL; } +/* return a transport without holding it, as it's only used under sock lock */ struct sctp_transport *sctp_epaddr_lookup_transport( const struct sctp_endpoint *ep, const union sctp_addr *paddr) { struct net *net = sock_net(ep->base.sk); + struct rhlist_head *tmp, *list; + struct sctp_transport *t; struct sctp_hash_cmp_arg arg = { - .ep = ep, .paddr = paddr, .net = net, + .lport = htons(ep->base.bind_addr.port), }; - return rhashtable_lookup_fast(&sctp_transport_hashtable, &arg, - sctp_hash_params); + list = rhltable_lookup(&sctp_transport_hashtable, &arg, + sctp_hash_params); + + rhl_for_each_entry_rcu(t, tmp, list, node) + if (ep == t->asoc->ep) + return t; + + return NULL; } /* Look up an association. */ @@ -951,7 +964,7 @@ static struct sctp_association *__sctp_lookup_association( struct sctp_association *asoc = NULL; t = sctp_addrs_lookup_transport(net, local, peer); - if (!t || !sctp_transport_hold(t)) + if (!t) goto out; asoc = t->asoc; diff --git a/net/sctp/socket.c b/net/sctp/socket.c index f23ad913dc7a..d5f4b4a8369b 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -4392,10 +4392,7 @@ int sctp_transport_walk_start(struct rhashtable_iter *iter) { int err; - err = rhashtable_walk_init(&sctp_transport_hashtable, iter, - GFP_KERNEL); - if (err) - return err; + rhltable_walk_enter(&sctp_transport_hashtable, iter); err = rhashtable_walk_start(iter); if (err && err != -EAGAIN) { @@ -4479,7 +4476,7 @@ int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *), rcu_read_lock(); transport = sctp_addrs_lookup_transport(net, laddr, paddr); - if (!transport || !sctp_transport_hold(transport)) + if (!transport) goto out; rcu_read_unlock();