diff --git a/include/linux/tcp_diag.h b/include/linux/tcp_diag.h index 190494ebcfb8..910c34ba19c0 100644 --- a/include/linux/tcp_diag.h +++ b/include/linux/tcp_diag.h @@ -5,6 +5,8 @@ #define TCPDIAG_GETSOCK 18 #define DCCPDIAG_GETSOCK 19 +#define INET_DIAG_GETSOCK_MAX 24 + /* Socket identity */ struct tcpdiag_sockid { @@ -125,4 +127,21 @@ struct tcpvegas_info { __u32 tcpv_minrtt; }; +#ifdef __KERNEL__ +struct sock; +struct inet_hashinfo; + +struct inet_diag_handler { + struct inet_hashinfo *idiag_hashinfo; + void (*idiag_get_info)(struct sock *sk, + struct tcpdiagmsg *r, + void *info); + __u16 idiag_info_size; + __u16 idiag_type; +}; + +extern int inet_diag_register(const struct inet_diag_handler *handler); +extern void inet_diag_unregister(const struct inet_diag_handler *handler); +#endif /* __KERNEL__ */ + #endif /* _TCP_DIAG_H_ */ diff --git a/net/dccp/Kconfig b/net/dccp/Kconfig index 90460bc629b3..ff5b5459b97a 100644 --- a/net/dccp/Kconfig +++ b/net/dccp/Kconfig @@ -19,6 +19,11 @@ config IP_DCCP If in doubt, say N. +config IP_DCCP_DIAG + depends on IP_DCCP && IP_TCPDIAG + def_tristate y if (IP_DCCP = y && IP_TCPDIAG = y) + def_tristate m + source "net/dccp/ccids/Kconfig" endmenu diff --git a/net/dccp/Makefile b/net/dccp/Makefile index 25a50bdbf1bb..5741fffc436f 100644 --- a/net/dccp/Makefile +++ b/net/dccp/Makefile @@ -3,4 +3,8 @@ obj-$(CONFIG_IP_DCCP) += dccp.o dccp-y := ccid.o input.o ipv4.o minisocks.o options.o output.o proto.o \ timer.o packet_history.o +obj-$(CONFIG_IP_DCCP_DIAG) += dccp_diag.o + obj-y += ccids/ + +dccp_diag-y := diag.o diff --git a/net/dccp/diag.c b/net/dccp/diag.c new file mode 100644 index 000000000000..4d9037c56ddc --- /dev/null +++ b/net/dccp/diag.c @@ -0,0 +1,47 @@ +/* + * net/dccp/diag.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + */ + +#include + +#include +#include + +#include "dccp.h" + +static void dccp_diag_get_info(struct sock *sk, struct tcpdiagmsg *r, + void *_info) +{ + r->tcpdiag_rqueue = r->tcpdiag_wqueue = 0; +} + +static struct inet_diag_handler dccp_diag_handler = { + .idiag_hashinfo = &dccp_hashinfo, + .idiag_get_info = dccp_diag_get_info, + .idiag_type = DCCPDIAG_GETSOCK, + .idiag_info_size = 0, +}; + +static int __init dccp_diag_init(void) +{ + return inet_diag_register(&dccp_diag_handler); +} + +static void __exit dccp_diag_fini(void) +{ + inet_diag_unregister(&dccp_diag_handler); +} + +module_init(dccp_diag_init); +module_exit(dccp_diag_fini); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Arnaldo Carvalho de Melo "); +MODULE_DESCRIPTION("DCCP inet_diag handler"); diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index 960c02faf440..1e6db2a896b9 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -423,9 +423,6 @@ config IP_TCPDIAG If unsure, say Y. -config IP_TCPDIAG_DCCP - def_bool (IP_TCPDIAG=y && IP_DCCP=y) || (IP_TCPDIAG=m && IP_DCCP) - config TCP_CONG_ADVANCED bool "TCP: advanced congestion control" ---help--- diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c index b812191b2f5c..b13b71cb9ced 100644 --- a/net/ipv4/tcp_diag.c +++ b/net/ipv4/tcp_diag.c @@ -34,6 +34,8 @@ #include +static const struct inet_diag_handler **inet_diag_table; + struct tcpdiag_entry { u32 *saddr; @@ -61,18 +63,24 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk, const struct inet_connection_sock *icsk = inet_csk(sk); struct tcpdiagmsg *r; struct nlmsghdr *nlh; - struct tcp_info *info = NULL; + void *info = NULL; struct tcpdiag_meminfo *minfo = NULL; unsigned char *b = skb->tail; + const struct inet_diag_handler *handler; + + handler = inet_diag_table[unlh->nlmsg_type]; + BUG_ON(handler == NULL); nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r)); nlh->nlmsg_flags = nlmsg_flags; + r = NLMSG_DATA(nlh); if (sk->sk_state != TCP_TIME_WAIT) { if (ext & (1<<(TCPDIAG_MEMINFO-1))) minfo = TCPDIAG_PUT(skb, TCPDIAG_MEMINFO, sizeof(*minfo)); if (ext & (1<<(TCPDIAG_INFO-1))) - info = TCPDIAG_PUT(skb, TCPDIAG_INFO, sizeof(*info)); + info = TCPDIAG_PUT(skb, TCPDIAG_INFO, + handler->idiag_info_size); if ((ext & (1 << (TCPDIAG_CONG - 1))) && icsk->icsk_ca_ops) { size_t len = strlen(icsk->icsk_ca_ops->name); @@ -155,19 +163,6 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk, r->tcpdiag_expires = 0; } #undef EXPIRES_IN_MS - /* - * Ahem... for now we'll have some knowledge about TCP -acme - * But this is just one of two small exceptions, both in this - * function, so lets close our eyes for some 15 lines or so... 8) - * -acme - */ - if (sk->sk_protocol == IPPROTO_TCP) { - const struct tcp_sock *tp = tcp_sk(sk); - - r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq; - r->tcpdiag_wqueue = tp->write_seq - tp->snd_una; - } else - r->tcpdiag_rqueue = r->tcpdiag_wqueue = 0; r->tcpdiag_uid = sock_i_uid(sk); r->tcpdiag_inode = sock_i_ino(sk); @@ -179,13 +174,7 @@ static int tcpdiag_fill(struct sk_buff *skb, struct sock *sk, minfo->tcpdiag_tmem = atomic_read(&sk->sk_wmem_alloc); } - /* Ahem... for now we'll have some knowledge about TCP -acme */ - if (info) { - if (sk->sk_protocol == IPPROTO_TCP) - tcp_get_info(sk, info); - else - memset(info, 0, sizeof(*info)); - } + handler->idiag_get_info(sk, r, info); if (sk->sk_state < TCP_TIME_WAIT && icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info) @@ -206,11 +195,13 @@ static int tcpdiag_get_exact(struct sk_buff *in_skb, const struct nlmsghdr *nlh) struct sock *sk; struct tcpdiagreq *req = NLMSG_DATA(nlh); struct sk_buff *rep; - struct inet_hashinfo *hashinfo = &tcp_hashinfo; -#ifdef CONFIG_IP_TCPDIAG_DCCP - if (nlh->nlmsg_type == DCCPDIAG_GETSOCK) - hashinfo = &dccp_hashinfo; -#endif + struct inet_hashinfo *hashinfo; + const struct inet_diag_handler *handler; + + handler = inet_diag_table[nlh->nlmsg_type]; + BUG_ON(handler == NULL); + hashinfo = handler->idiag_hashinfo; + if (req->tcpdiag_family == AF_INET) { sk = inet_lookup(hashinfo, req->id.tcpdiag_dst[0], req->id.tcpdiag_dport, req->id.tcpdiag_src[0], @@ -241,9 +232,10 @@ static int tcpdiag_get_exact(struct sk_buff *in_skb, const struct nlmsghdr *nlh) goto out; err = -ENOMEM; - rep = alloc_skb(NLMSG_SPACE(sizeof(struct tcpdiagmsg)+ - sizeof(struct tcpdiag_meminfo)+ - sizeof(struct tcp_info)+64), GFP_KERNEL); + rep = alloc_skb(NLMSG_SPACE((sizeof(struct tcpdiagmsg) + + sizeof(struct tcpdiag_meminfo) + + handler->idiag_info_size + 64)), + GFP_KERNEL); if (!rep) goto out; @@ -603,15 +595,16 @@ static int tcpdiag_dump(struct sk_buff *skb, struct netlink_callback *cb) int i, num; int s_i, s_num; struct tcpdiagreq *r = NLMSG_DATA(cb->nlh); + const struct inet_diag_handler *handler; struct inet_hashinfo *hashinfo; + handler = inet_diag_table[cb->nlh->nlmsg_type]; + BUG_ON(handler == NULL); + hashinfo = handler->idiag_hashinfo; + s_i = cb->args[1]; s_num = num = cb->args[2]; - hashinfo = &tcp_hashinfo; -#ifdef CONFIG_IP_TCPDIAG_DCCP - if (cb->nlh->nlmsg_type == DCCPDIAG_GETSOCK) - hashinfo = &dccp_hashinfo; -#endif + if (cb->args[0] == 0) { if (!(r->tcpdiag_states&(TCPF_LISTEN|TCPF_SYN_RECV))) goto skip_listen_ht; @@ -745,13 +738,12 @@ tcpdiag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh) if (!(nlh->nlmsg_flags&NLM_F_REQUEST)) return 0; - if (nlh->nlmsg_type != TCPDIAG_GETSOCK -#ifdef CONFIG_IP_TCPDIAG_DCCP - && nlh->nlmsg_type != DCCPDIAG_GETSOCK -#endif - ) + if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX) goto err_inval; + if (inet_diag_table[nlh->nlmsg_type] == NULL) + return -ENOENT; + if (NLMSG_LENGTH(sizeof(struct tcpdiagreq)) > skb->len) goto err_inval; @@ -803,18 +795,95 @@ static void tcpdiag_rcv(struct sock *sk, int len) } } +static void tcp_diag_get_info(struct sock *sk, struct tcpdiagmsg *r, + void *_info) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_info *info = _info; + + r->tcpdiag_rqueue = tp->rcv_nxt - tp->copied_seq; + r->tcpdiag_wqueue = tp->write_seq - tp->snd_una; + if (info != NULL) + tcp_get_info(sk, info); +} + +static struct inet_diag_handler tcp_diag_handler = { + .idiag_hashinfo = &tcp_hashinfo, + .idiag_get_info = tcp_diag_get_info, + .idiag_type = TCPDIAG_GETSOCK, + .idiag_info_size = sizeof(struct tcp_info), +}; + +static DEFINE_SPINLOCK(inet_diag_register_lock); + +int inet_diag_register(const struct inet_diag_handler *h) +{ + const __u16 type = h->idiag_type; + int err = -EINVAL; + + if (type >= INET_DIAG_GETSOCK_MAX) + goto out; + + spin_lock(&inet_diag_register_lock); + err = -EEXIST; + if (inet_diag_table[type] == NULL) { + inet_diag_table[type] = h; + err = 0; + } + spin_unlock(&inet_diag_register_lock); +out: + return err; +} +EXPORT_SYMBOL_GPL(inet_diag_register); + +void inet_diag_unregister(const struct inet_diag_handler *h) +{ + const __u16 type = h->idiag_type; + + if (type >= INET_DIAG_GETSOCK_MAX) + return; + + spin_lock(&inet_diag_register_lock); + inet_diag_table[type] = NULL; + spin_unlock(&inet_diag_register_lock); + + synchronize_rcu(); +} +EXPORT_SYMBOL_GPL(inet_diag_unregister); + static int __init tcpdiag_init(void) { + const int inet_diag_table_size = (INET_DIAG_GETSOCK_MAX * + sizeof(struct inet_diag_handler *)); + int err = -ENOMEM; + + inet_diag_table = kmalloc(inet_diag_table_size, GFP_KERNEL); + if (!inet_diag_table) + goto out; + + memset(inet_diag_table, 0, inet_diag_table_size); + tcpnl = netlink_kernel_create(NETLINK_TCPDIAG, tcpdiag_rcv, THIS_MODULE); if (tcpnl == NULL) - return -ENOMEM; - return 0; + goto out_free_table; + + err = inet_diag_register(&tcp_diag_handler); + if (err) + goto out_sock_release; +out: + return err; +out_sock_release: + sock_release(tcpnl->sk_socket); +out_free_table: + kfree(inet_diag_table); + goto out; } static void __exit tcpdiag_exit(void) { sock_release(tcpnl->sk_socket); + kfree(inet_diag_table); } module_init(tcpdiag_init);