bpf: Add bpf_dctcp example

This patch adds a bpf_dctcp example.  It currently does not do
no-ECN fallback but the same could be done through the cgrp2-bpf.

Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/bpf/20200109003517.3856825-1-kafai@fb.com
This commit is contained in:
Martin KaFai Lau 2020-01-08 16:35:17 -08:00 committed by Alexei Starovoitov
parent 590a008882
commit 09903869f6
3 changed files with 625 additions and 0 deletions

View File

@ -0,0 +1,228 @@
/* SPDX-License-Identifier: GPL-2.0 */
#ifndef __BPF_TCP_HELPERS_H
#define __BPF_TCP_HELPERS_H
#include <stdbool.h>
#include <linux/types.h>
#include <bpf_helpers.h>
#include <bpf_core_read.h>
#include "bpf_trace_helpers.h"
/* "struct_ops/" is only a convention. not a requirement. */
#define BPF_TCP_OPS_0(fname, ret_type, ...) BPF_TRACE_x(0, "struct_ops/"#fname, fname, ret_type, __VA_ARGS__)
#define BPF_TCP_OPS_1(fname, ret_type, ...) BPF_TRACE_x(1, "struct_ops/"#fname, fname, ret_type, __VA_ARGS__)
#define BPF_TCP_OPS_2(fname, ret_type, ...) BPF_TRACE_x(2, "struct_ops/"#fname, fname, ret_type, __VA_ARGS__)
#define BPF_TCP_OPS_3(fname, ret_type, ...) BPF_TRACE_x(3, "struct_ops/"#fname, fname, ret_type, __VA_ARGS__)
#define BPF_TCP_OPS_4(fname, ret_type, ...) BPF_TRACE_x(4, "struct_ops/"#fname, fname, ret_type, __VA_ARGS__)
#define BPF_TCP_OPS_5(fname, ret_type, ...) BPF_TRACE_x(5, "struct_ops/"#fname, fname, ret_type, __VA_ARGS__)
struct sock_common {
unsigned char skc_state;
} __attribute__((preserve_access_index));
struct sock {
struct sock_common __sk_common;
} __attribute__((preserve_access_index));
struct inet_sock {
struct sock sk;
} __attribute__((preserve_access_index));
struct inet_connection_sock {
struct inet_sock icsk_inet;
__u8 icsk_ca_state:6,
icsk_ca_setsockopt:1,
icsk_ca_dst_locked:1;
struct {
__u8 pending;
} icsk_ack;
__u64 icsk_ca_priv[104 / sizeof(__u64)];
} __attribute__((preserve_access_index));
struct tcp_sock {
struct inet_connection_sock inet_conn;
__u32 rcv_nxt;
__u32 snd_nxt;
__u32 snd_una;
__u8 ecn_flags;
__u32 delivered;
__u32 delivered_ce;
__u32 snd_cwnd;
__u32 snd_cwnd_cnt;
__u32 snd_cwnd_clamp;
__u32 snd_ssthresh;
__u8 syn_data:1, /* SYN includes data */
syn_fastopen:1, /* SYN includes Fast Open option */
syn_fastopen_exp:1,/* SYN includes Fast Open exp. option */
syn_fastopen_ch:1, /* Active TFO re-enabling probe */
syn_data_acked:1,/* data in SYN is acked by SYN-ACK */
save_syn:1, /* Save headers of SYN packet */
is_cwnd_limited:1,/* forward progress limited by snd_cwnd? */
syn_smc:1; /* SYN includes SMC */
__u32 max_packets_out;
__u32 lsndtime;
__u32 prior_cwnd;
} __attribute__((preserve_access_index));
static __always_inline struct inet_connection_sock *inet_csk(const struct sock *sk)
{
return (struct inet_connection_sock *)sk;
}
static __always_inline void *inet_csk_ca(const struct sock *sk)
{
return (void *)inet_csk(sk)->icsk_ca_priv;
}
static __always_inline struct tcp_sock *tcp_sk(const struct sock *sk)
{
return (struct tcp_sock *)sk;
}
static __always_inline bool before(__u32 seq1, __u32 seq2)
{
return (__s32)(seq1-seq2) < 0;
}
#define after(seq2, seq1) before(seq1, seq2)
#define TCP_ECN_OK 1
#define TCP_ECN_QUEUE_CWR 2
#define TCP_ECN_DEMAND_CWR 4
#define TCP_ECN_SEEN 8
enum inet_csk_ack_state_t {
ICSK_ACK_SCHED = 1,
ICSK_ACK_TIMER = 2,
ICSK_ACK_PUSHED = 4,
ICSK_ACK_PUSHED2 = 8,
ICSK_ACK_NOW = 16 /* Send the next ACK immediately (once) */
};
enum tcp_ca_event {
CA_EVENT_TX_START = 0,
CA_EVENT_CWND_RESTART = 1,
CA_EVENT_COMPLETE_CWR = 2,
CA_EVENT_LOSS = 3,
CA_EVENT_ECN_NO_CE = 4,
CA_EVENT_ECN_IS_CE = 5,
};
enum tcp_ca_state {
TCP_CA_Open = 0,
TCP_CA_Disorder = 1,
TCP_CA_CWR = 2,
TCP_CA_Recovery = 3,
TCP_CA_Loss = 4
};
struct ack_sample {
__u32 pkts_acked;
__s32 rtt_us;
__u32 in_flight;
} __attribute__((preserve_access_index));
struct rate_sample {
__u64 prior_mstamp; /* starting timestamp for interval */
__u32 prior_delivered; /* tp->delivered at "prior_mstamp" */
__s32 delivered; /* number of packets delivered over interval */
long interval_us; /* time for tp->delivered to incr "delivered" */
__u32 snd_interval_us; /* snd interval for delivered packets */
__u32 rcv_interval_us; /* rcv interval for delivered packets */
long rtt_us; /* RTT of last (S)ACKed packet (or -1) */
int losses; /* number of packets marked lost upon ACK */
__u32 acked_sacked; /* number of packets newly (S)ACKed upon ACK */
__u32 prior_in_flight; /* in flight before this ACK */
bool is_app_limited; /* is sample from packet with bubble in pipe? */
bool is_retrans; /* is sample from retransmission? */
bool is_ack_delayed; /* is this (likely) a delayed ACK? */
} __attribute__((preserve_access_index));
#define TCP_CA_NAME_MAX 16
#define TCP_CONG_NEEDS_ECN 0x2
struct tcp_congestion_ops {
char name[TCP_CA_NAME_MAX];
__u32 flags;
/* initialize private data (optional) */
void (*init)(struct sock *sk);
/* cleanup private data (optional) */
void (*release)(struct sock *sk);
/* return slow start threshold (required) */
__u32 (*ssthresh)(struct sock *sk);
/* do new cwnd calculation (required) */
void (*cong_avoid)(struct sock *sk, __u32 ack, __u32 acked);
/* call before changing ca_state (optional) */
void (*set_state)(struct sock *sk, __u8 new_state);
/* call when cwnd event occurs (optional) */
void (*cwnd_event)(struct sock *sk, enum tcp_ca_event ev);
/* call when ack arrives (optional) */
void (*in_ack_event)(struct sock *sk, __u32 flags);
/* new value of cwnd after loss (required) */
__u32 (*undo_cwnd)(struct sock *sk);
/* hook for packet ack accounting (optional) */
void (*pkts_acked)(struct sock *sk, const struct ack_sample *sample);
/* override sysctl_tcp_min_tso_segs */
__u32 (*min_tso_segs)(struct sock *sk);
/* returns the multiplier used in tcp_sndbuf_expand (optional) */
__u32 (*sndbuf_expand)(struct sock *sk);
/* call when packets are delivered to update cwnd and pacing rate,
* after all the ca_state processing. (optional)
*/
void (*cong_control)(struct sock *sk, const struct rate_sample *rs);
};
#define min(a, b) ((a) < (b) ? (a) : (b))
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min_not_zero(x, y) ({ \
typeof(x) __x = (x); \
typeof(y) __y = (y); \
__x == 0 ? __y : ((__y == 0) ? __x : min(__x, __y)); })
static __always_inline __u32 tcp_slow_start(struct tcp_sock *tp, __u32 acked)
{
__u32 cwnd = min(tp->snd_cwnd + acked, tp->snd_ssthresh);
acked -= cwnd - tp->snd_cwnd;
tp->snd_cwnd = min(cwnd, tp->snd_cwnd_clamp);
return acked;
}
static __always_inline bool tcp_in_slow_start(const struct tcp_sock *tp)
{
return tp->snd_cwnd < tp->snd_ssthresh;
}
static __always_inline bool tcp_is_cwnd_limited(const struct sock *sk)
{
const struct tcp_sock *tp = tcp_sk(sk);
/* If in slow start, ensure cwnd grows to twice what was ACKed. */
if (tcp_in_slow_start(tp))
return tp->snd_cwnd < 2 * tp->max_packets_out;
return !!BPF_CORE_READ_BITFIELD(tp, is_cwnd_limited);
}
static __always_inline void tcp_cong_avoid_ai(struct tcp_sock *tp, __u32 w, __u32 acked)
{
/* If credits accumulated at a higher w, apply them gently now. */
if (tp->snd_cwnd_cnt >= w) {
tp->snd_cwnd_cnt = 0;
tp->snd_cwnd++;
}
tp->snd_cwnd_cnt += acked;
if (tp->snd_cwnd_cnt >= w) {
__u32 delta = tp->snd_cwnd_cnt / w;
tp->snd_cwnd_cnt -= delta * w;
tp->snd_cwnd += delta;
}
tp->snd_cwnd = min(tp->snd_cwnd, tp->snd_cwnd_clamp);
}
#endif

View File

@ -0,0 +1,187 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook */
#include <linux/err.h>
#include <test_progs.h>
#include "bpf_dctcp.skel.h"
#define min(a, b) ((a) < (b) ? (a) : (b))
static const unsigned int total_bytes = 10 * 1024 * 1024;
static const struct timeval timeo_sec = { .tv_sec = 10 };
static const size_t timeo_optlen = sizeof(timeo_sec);
static int stop, duration;
static int settimeo(int fd)
{
int err;
err = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec,
timeo_optlen);
if (CHECK(err == -1, "setsockopt(fd, SO_RCVTIMEO)", "errno:%d\n",
errno))
return -1;
err = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeo_sec,
timeo_optlen);
if (CHECK(err == -1, "setsockopt(fd, SO_SNDTIMEO)", "errno:%d\n",
errno))
return -1;
return 0;
}
static int settcpca(int fd, const char *tcp_ca)
{
int err;
err = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca, strlen(tcp_ca));
if (CHECK(err == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n",
errno))
return -1;
return 0;
}
static void *server(void *arg)
{
int lfd = (int)(long)arg, err = 0, fd;
ssize_t nr_sent = 0, bytes = 0;
char batch[1500];
fd = accept(lfd, NULL, NULL);
while (fd == -1) {
if (errno == EINTR)
continue;
err = -errno;
goto done;
}
if (settimeo(fd)) {
err = -errno;
goto done;
}
while (bytes < total_bytes && !READ_ONCE(stop)) {
nr_sent = send(fd, &batch,
min(total_bytes - bytes, sizeof(batch)), 0);
if (nr_sent == -1 && errno == EINTR)
continue;
if (nr_sent == -1) {
err = -errno;
break;
}
bytes += nr_sent;
}
CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n",
bytes, total_bytes, nr_sent, errno);
done:
if (fd != -1)
close(fd);
if (err) {
WRITE_ONCE(stop, 1);
return ERR_PTR(err);
}
return NULL;
}
static void do_test(const char *tcp_ca)
{
struct sockaddr_in6 sa6 = {};
ssize_t nr_recv = 0, bytes = 0;
int lfd = -1, fd = -1;
pthread_t srv_thread;
socklen_t addrlen = sizeof(sa6);
void *thread_ret;
char batch[1500];
int err;
WRITE_ONCE(stop, 0);
lfd = socket(AF_INET6, SOCK_STREAM, 0);
if (CHECK(lfd == -1, "socket", "errno:%d\n", errno))
return;
fd = socket(AF_INET6, SOCK_STREAM, 0);
if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) {
close(lfd);
return;
}
if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) ||
settimeo(lfd) || settimeo(fd))
goto done;
/* bind, listen and start server thread to accept */
sa6.sin6_family = AF_INET6;
sa6.sin6_addr = in6addr_loopback;
err = bind(lfd, (struct sockaddr *)&sa6, addrlen);
if (CHECK(err == -1, "bind", "errno:%d\n", errno))
goto done;
err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen);
if (CHECK(err == -1, "getsockname", "errno:%d\n", errno))
goto done;
err = listen(lfd, 1);
if (CHECK(err == -1, "listen", "errno:%d\n", errno))
goto done;
err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd);
if (CHECK(err != 0, "pthread_create", "err:%d\n", err))
goto done;
/* connect to server */
err = connect(fd, (struct sockaddr *)&sa6, addrlen);
if (CHECK(err == -1, "connect", "errno:%d\n", errno))
goto wait_thread;
/* recv total_bytes */
while (bytes < total_bytes && !READ_ONCE(stop)) {
nr_recv = recv(fd, &batch,
min(total_bytes - bytes, sizeof(batch)), 0);
if (nr_recv == -1 && errno == EINTR)
continue;
if (nr_recv == -1)
break;
bytes += nr_recv;
}
CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n",
bytes, total_bytes, nr_recv, errno);
wait_thread:
WRITE_ONCE(stop, 1);
pthread_join(srv_thread, &thread_ret);
CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld",
PTR_ERR(thread_ret));
done:
close(lfd);
close(fd);
}
static void test_dctcp(void)
{
struct bpf_dctcp *dctcp_skel;
struct bpf_link *link;
dctcp_skel = bpf_dctcp__open_and_load();
if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n"))
return;
link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp);
if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n",
PTR_ERR(link))) {
bpf_dctcp__destroy(dctcp_skel);
return;
}
do_test("bpf_dctcp");
bpf_link__destroy(link);
bpf_dctcp__destroy(dctcp_skel);
}
void test_bpf_tcp_ca(void)
{
if (test__start_subtest("dctcp"))
test_dctcp();
}

View File

@ -0,0 +1,210 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook */
/* WARNING: This implemenation is not necessarily the same
* as the tcp_dctcp.c. The purpose is mainly for testing
* the kernel BPF logic.
*/
#include <linux/bpf.h>
#include <linux/types.h>
#include "bpf_tcp_helpers.h"
char _license[] SEC("license") = "GPL";
#define DCTCP_MAX_ALPHA 1024U
struct dctcp {
__u32 old_delivered;
__u32 old_delivered_ce;
__u32 prior_rcv_nxt;
__u32 dctcp_alpha;
__u32 next_seq;
__u32 ce_state;
__u32 loss_cwnd;
};
static unsigned int dctcp_shift_g = 4; /* g = 1/2^4 */
static unsigned int dctcp_alpha_on_init = DCTCP_MAX_ALPHA;
static __always_inline void dctcp_reset(const struct tcp_sock *tp,
struct dctcp *ca)
{
ca->next_seq = tp->snd_nxt;
ca->old_delivered = tp->delivered;
ca->old_delivered_ce = tp->delivered_ce;
}
BPF_TCP_OPS_1(dctcp_init, void, struct sock *, sk)
{
const struct tcp_sock *tp = tcp_sk(sk);
struct dctcp *ca = inet_csk_ca(sk);
ca->prior_rcv_nxt = tp->rcv_nxt;
ca->dctcp_alpha = min(dctcp_alpha_on_init, DCTCP_MAX_ALPHA);
ca->loss_cwnd = 0;
ca->ce_state = 0;
dctcp_reset(tp, ca);
}
BPF_TCP_OPS_1(dctcp_ssthresh, __u32, struct sock *, sk)
{
struct dctcp *ca = inet_csk_ca(sk);
struct tcp_sock *tp = tcp_sk(sk);
ca->loss_cwnd = tp->snd_cwnd;
return max(tp->snd_cwnd - ((tp->snd_cwnd * ca->dctcp_alpha) >> 11U), 2U);
}
BPF_TCP_OPS_2(dctcp_update_alpha, void,
struct sock *, sk, __u32, flags)
{
const struct tcp_sock *tp = tcp_sk(sk);
struct dctcp *ca = inet_csk_ca(sk);
/* Expired RTT */
if (!before(tp->snd_una, ca->next_seq)) {
__u32 delivered_ce = tp->delivered_ce - ca->old_delivered_ce;
__u32 alpha = ca->dctcp_alpha;
/* alpha = (1 - g) * alpha + g * F */
alpha -= min_not_zero(alpha, alpha >> dctcp_shift_g);
if (delivered_ce) {
__u32 delivered = tp->delivered - ca->old_delivered;
/* If dctcp_shift_g == 1, a 32bit value would overflow
* after 8 M packets.
*/
delivered_ce <<= (10 - dctcp_shift_g);
delivered_ce /= max(1U, delivered);
alpha = min(alpha + delivered_ce, DCTCP_MAX_ALPHA);
}
ca->dctcp_alpha = alpha;
dctcp_reset(tp, ca);
}
}
static __always_inline void dctcp_react_to_loss(struct sock *sk)
{
struct dctcp *ca = inet_csk_ca(sk);
struct tcp_sock *tp = tcp_sk(sk);
ca->loss_cwnd = tp->snd_cwnd;
tp->snd_ssthresh = max(tp->snd_cwnd >> 1U, 2U);
}
BPF_TCP_OPS_2(dctcp_state, void, struct sock *, sk, __u8, new_state)
{
if (new_state == TCP_CA_Recovery &&
new_state != BPF_CORE_READ_BITFIELD(inet_csk(sk), icsk_ca_state))
dctcp_react_to_loss(sk);
/* We handle RTO in dctcp_cwnd_event to ensure that we perform only
* one loss-adjustment per RTT.
*/
}
static __always_inline void dctcp_ece_ack_cwr(struct sock *sk, __u32 ce_state)
{
struct tcp_sock *tp = tcp_sk(sk);
if (ce_state == 1)
tp->ecn_flags |= TCP_ECN_DEMAND_CWR;
else
tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR;
}
/* Minimal DCTP CE state machine:
*
* S: 0 <- last pkt was non-CE
* 1 <- last pkt was CE
*/
static __always_inline
void dctcp_ece_ack_update(struct sock *sk, enum tcp_ca_event evt,
__u32 *prior_rcv_nxt, __u32 *ce_state)
{
__u32 new_ce_state = (evt == CA_EVENT_ECN_IS_CE) ? 1 : 0;
if (*ce_state != new_ce_state) {
/* CE state has changed, force an immediate ACK to
* reflect the new CE state. If an ACK was delayed,
* send that first to reflect the prior CE state.
*/
if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) {
dctcp_ece_ack_cwr(sk, *ce_state);
bpf_tcp_send_ack(sk, *prior_rcv_nxt);
}
inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW;
}
*prior_rcv_nxt = tcp_sk(sk)->rcv_nxt;
*ce_state = new_ce_state;
dctcp_ece_ack_cwr(sk, new_ce_state);
}
BPF_TCP_OPS_2(dctcp_cwnd_event, void,
struct sock *, sk, enum tcp_ca_event, ev)
{
struct dctcp *ca = inet_csk_ca(sk);
switch (ev) {
case CA_EVENT_ECN_IS_CE:
case CA_EVENT_ECN_NO_CE:
dctcp_ece_ack_update(sk, ev, &ca->prior_rcv_nxt, &ca->ce_state);
break;
case CA_EVENT_LOSS:
dctcp_react_to_loss(sk);
break;
default:
/* Don't care for the rest. */
break;
}
}
BPF_TCP_OPS_1(dctcp_cwnd_undo, __u32, struct sock *, sk)
{
const struct dctcp *ca = inet_csk_ca(sk);
return max(tcp_sk(sk)->snd_cwnd, ca->loss_cwnd);
}
BPF_TCP_OPS_3(tcp_reno_cong_avoid, void,
struct sock *, sk, __u32, ack, __u32, acked)
{
struct tcp_sock *tp = tcp_sk(sk);
if (!tcp_is_cwnd_limited(sk))
return;
/* In "safe" area, increase. */
if (tcp_in_slow_start(tp)) {
acked = tcp_slow_start(tp, acked);
if (!acked)
return;
}
/* In dangerous area, increase slowly. */
tcp_cong_avoid_ai(tp, tp->snd_cwnd, acked);
}
SEC(".struct_ops")
struct tcp_congestion_ops dctcp_nouse = {
.init = (void *)dctcp_init,
.set_state = (void *)dctcp_state,
.flags = TCP_CONG_NEEDS_ECN,
.name = "bpf_dctcp_nouse",
};
SEC(".struct_ops")
struct tcp_congestion_ops dctcp = {
.init = (void *)dctcp_init,
.in_ack_event = (void *)dctcp_update_alpha,
.cwnd_event = (void *)dctcp_cwnd_event,
.ssthresh = (void *)dctcp_ssthresh,
.cong_avoid = (void *)tcp_reno_cong_avoid,
.undo_cwnd = (void *)dctcp_cwnd_undo,
.set_state = (void *)dctcp_state,
.flags = TCP_CONG_NEEDS_ECN,
.name = "bpf_dctcp",
};