summary refs log tree commit diff
path: root/kernel/bpf/sockmap.c
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/bpf/sockmap.c')
-rw-r--r--kernel/bpf/sockmap.c148
1 files changed, 88 insertions, 60 deletions
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 098eca568c2b..beab9ec9b023 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -48,14 +48,18 @@
 #define SOCK_CREATE_FLAG_MASK \
 	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
 
-struct bpf_stab {
-	struct bpf_map map;
-	struct sock **sock_map;
+struct bpf_sock_progs {
 	struct bpf_prog *bpf_tx_msg;
 	struct bpf_prog *bpf_parse;
 	struct bpf_prog *bpf_verdict;
 };
 
+struct bpf_stab {
+	struct bpf_map map;
+	struct sock **sock_map;
+	struct bpf_sock_progs progs;
+};
+
 enum smap_psock_state {
 	SMAP_TX_RUNNING,
 };
@@ -461,7 +465,7 @@ static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
 {
 	return ((_rc == SK_PASS) ?
-	       (md->map ? __SK_REDIRECT : __SK_PASS) :
+	       (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
 	       __SK_DROP);
 }
 
@@ -1092,7 +1096,7 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
 	 * when we orphan the skb so that we don't have the possibility
 	 * to reference a stale map.
 	 */
-	TCP_SKB_CB(skb)->bpf.map = NULL;
+	TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
 	skb->sk = psock->sock;
 	bpf_compute_data_pointers(skb);
 	preempt_disable();
@@ -1102,7 +1106,7 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
 
 	/* Moving return codes from UAPI namespace into internal namespace */
 	return rc == SK_PASS ?
-		(TCP_SKB_CB(skb)->bpf.map ? __SK_REDIRECT : __SK_PASS) :
+		(TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
 		__SK_DROP;
 }
 
@@ -1372,7 +1376,6 @@ static int smap_init_sock(struct smap_psock *psock,
 }
 
 static void smap_init_progs(struct smap_psock *psock,
-			    struct bpf_stab *stab,
 			    struct bpf_prog *verdict,
 			    struct bpf_prog *parse)
 {
@@ -1450,14 +1453,13 @@ static void smap_gc_work(struct work_struct *w)
 	kfree(psock);
 }
 
-static struct smap_psock *smap_init_psock(struct sock *sock,
-					  struct bpf_stab *stab)
+static struct smap_psock *smap_init_psock(struct sock *sock, int node)
 {
 	struct smap_psock *psock;
 
 	psock = kzalloc_node(sizeof(struct smap_psock),
 			     GFP_ATOMIC | __GFP_NOWARN,
-			     stab->map.numa_node);
+			     node);
 	if (!psock)
 		return ERR_PTR(-ENOMEM);
 
@@ -1662,40 +1664,26 @@ out:
  *  - sock_map must use READ_ONCE and (cmp)xchg operations
  *  - BPF verdict/parse programs must use READ_ONCE and xchg operations
  */
-static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
-				    struct bpf_map *map,
-				    void *key, u64 flags)
+
+static int __sock_map_ctx_update_elem(struct bpf_map *map,
+				      struct bpf_sock_progs *progs,
+				      struct sock *sock,
+				      struct sock **map_link,
+				      void *key)
 {
-	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
-	struct smap_psock_map_entry *e = NULL;
 	struct bpf_prog *verdict, *parse, *tx_msg;
-	struct sock *osock, *sock;
+	struct smap_psock_map_entry *e = NULL;
 	struct smap_psock *psock;
-	u32 i = *(u32 *)key;
 	bool new = false;
 	int err;
 
-	if (unlikely(flags > BPF_EXIST))
-		return -EINVAL;
-
-	if (unlikely(i >= stab->map.max_entries))
-		return -E2BIG;
-
-	sock = READ_ONCE(stab->sock_map[i]);
-	if (flags == BPF_EXIST && !sock)
-		return -ENOENT;
-	else if (flags == BPF_NOEXIST && sock)
-		return -EEXIST;
-
-	sock = skops->sk;
-
 	/* 1. If sock map has BPF programs those will be inherited by the
 	 * sock being added. If the sock is already attached to BPF programs
 	 * this results in an error.
 	 */
-	verdict = READ_ONCE(stab->bpf_verdict);
-	parse = READ_ONCE(stab->bpf_parse);
-	tx_msg = READ_ONCE(stab->bpf_tx_msg);
+	verdict = READ_ONCE(progs->bpf_verdict);
+	parse = READ_ONCE(progs->bpf_parse);
+	tx_msg = READ_ONCE(progs->bpf_tx_msg);
 
 	if (parse && verdict) {
 		/* bpf prog refcnt may be zero if a concurrent attach operation
@@ -1703,11 +1691,11 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 		 * we increment the refcnt. If this is the case abort with an
 		 * error.
 		 */
-		verdict = bpf_prog_inc_not_zero(stab->bpf_verdict);
+		verdict = bpf_prog_inc_not_zero(progs->bpf_verdict);
 		if (IS_ERR(verdict))
 			return PTR_ERR(verdict);
 
-		parse = bpf_prog_inc_not_zero(stab->bpf_parse);
+		parse = bpf_prog_inc_not_zero(progs->bpf_parse);
 		if (IS_ERR(parse)) {
 			bpf_prog_put(verdict);
 			return PTR_ERR(parse);
@@ -1715,7 +1703,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 	}
 
 	if (tx_msg) {
-		tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);
+		tx_msg = bpf_prog_inc_not_zero(progs->bpf_tx_msg);
 		if (IS_ERR(tx_msg)) {
 			if (verdict)
 				bpf_prog_put(verdict);
@@ -1748,7 +1736,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 			goto out_progs;
 		}
 	} else {
-		psock = smap_init_psock(sock, stab);
+		psock = smap_init_psock(sock, map->numa_node);
 		if (IS_ERR(psock)) {
 			err = PTR_ERR(psock);
 			goto out_progs;
@@ -1763,7 +1751,6 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 		err = -ENOMEM;
 		goto out_progs;
 	}
-	e->entry = &stab->sock_map[i];
 
 	/* 3. At this point we have a reference to a valid psock that is
 	 * running. Attach any BPF programs needed.
@@ -1780,7 +1767,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 		err = smap_init_sock(psock, sock);
 		if (err)
 			goto out_free;
-		smap_init_progs(psock, stab, verdict, parse);
+		smap_init_progs(psock, verdict, parse);
 		smap_start_sock(psock, sock);
 	}
 
@@ -1789,19 +1776,12 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 	 * it with. Because we can only have a single set of programs if
 	 * old_sock has a strp we can stop it.
 	 */
-	list_add_tail(&e->list, &psock->maps);
-	write_unlock_bh(&sock->sk_callback_lock);
-
-	osock = xchg(&stab->sock_map[i], sock);
-	if (osock) {
-		struct smap_psock *opsock = smap_psock_sk(osock);
-
-		write_lock_bh(&osock->sk_callback_lock);
-		smap_list_remove(opsock, &stab->sock_map[i]);
-		smap_release_sock(opsock, osock);
-		write_unlock_bh(&osock->sk_callback_lock);
+	if (map_link) {
+		e->entry = map_link;
+		list_add_tail(&e->list, &psock->maps);
 	}
-	return 0;
+	write_unlock_bh(&sock->sk_callback_lock);
+	return err;
 out_free:
 	smap_release_sock(psock, sock);
 out_progs:
@@ -1816,23 +1796,69 @@ out_progs:
 	return err;
 }
 
-int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
+static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
+				    struct bpf_map *map,
+				    void *key, u64 flags)
 {
 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+	struct bpf_sock_progs *progs = &stab->progs;
+	struct sock *osock, *sock;
+	u32 i = *(u32 *)key;
+	int err;
+
+	if (unlikely(flags > BPF_EXIST))
+		return -EINVAL;
+
+	if (unlikely(i >= stab->map.max_entries))
+		return -E2BIG;
+
+	sock = READ_ONCE(stab->sock_map[i]);
+	if (flags == BPF_EXIST && !sock)
+		return -ENOENT;
+	else if (flags == BPF_NOEXIST && sock)
+		return -EEXIST;
+
+	sock = skops->sk;
+	err = __sock_map_ctx_update_elem(map, progs, sock, &stab->sock_map[i],
+					 key);
+	if (err)
+		goto out;
+
+	osock = xchg(&stab->sock_map[i], sock);
+	if (osock) {
+		struct smap_psock *opsock = smap_psock_sk(osock);
+
+		write_lock_bh(&osock->sk_callback_lock);
+		smap_list_remove(opsock, &stab->sock_map[i]);
+		smap_release_sock(opsock, osock);
+		write_unlock_bh(&osock->sk_callback_lock);
+	}
+out:
+	return 0;
+}
+
+int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
+{
+	struct bpf_sock_progs *progs;
 	struct bpf_prog *orig;
 
-	if (unlikely(map->map_type != BPF_MAP_TYPE_SOCKMAP))
+	if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
+		struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+
+		progs = &stab->progs;
+	} else {
 		return -EINVAL;
+	}
 
 	switch (type) {
 	case BPF_SK_MSG_VERDICT:
-		orig = xchg(&stab->bpf_tx_msg, prog);
+		orig = xchg(&progs->bpf_tx_msg, prog);
 		break;
 	case BPF_SK_SKB_STREAM_PARSER:
-		orig = xchg(&stab->bpf_parse, prog);
+		orig = xchg(&progs->bpf_parse, prog);
 		break;
 	case BPF_SK_SKB_STREAM_VERDICT:
-		orig = xchg(&stab->bpf_verdict, prog);
+		orig = xchg(&progs->bpf_verdict, prog);
 		break;
 	default:
 		return -EOPNOTSUPP;
@@ -1881,16 +1907,18 @@ static int sock_map_update_elem(struct bpf_map *map,
 static void sock_map_release(struct bpf_map *map)
 {
 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+	struct bpf_sock_progs *progs;
 	struct bpf_prog *orig;
 
-	orig = xchg(&stab->bpf_parse, NULL);
+	progs = &stab->progs;
+	orig = xchg(&progs->bpf_parse, NULL);
 	if (orig)
 		bpf_prog_put(orig);
-	orig = xchg(&stab->bpf_verdict, NULL);
+	orig = xchg(&progs->bpf_verdict, NULL);
 	if (orig)
 		bpf_prog_put(orig);
 
-	orig = xchg(&stab->bpf_tx_msg, NULL);
+	orig = xchg(&progs->bpf_tx_msg, NULL);
 	if (orig)
 		bpf_prog_put(orig);
 }