summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--drivers/net/wireguard/allowedips.c132
-rw-r--r--drivers/net/wireguard/allowedips.h9
2 files changed, 57 insertions, 84 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index 3725e9cd85f4..2785cfd3a221 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -66,60 +66,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
 	}
 }
 
-static void walk_remove_by_peer(struct allowedips_node __rcu **top,
-				struct wg_peer *peer, struct mutex *lock)
-{
-#define REF(p) rcu_access_pointer(p)
-#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
-#define PUSH(p) ({                                                             \
-		WARN_ON(IS_ENABLED(DEBUG) && len >= 128);                      \
-		stack[len++] = p;                                              \
-	})
-
-	struct allowedips_node __rcu **stack[128], **nptr;
-	struct allowedips_node *node, *prev;
-	unsigned int len;
-
-	if (unlikely(!peer || !REF(*top)))
-		return;
-
-	for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
-		nptr = stack[len - 1];
-		node = DEREF(nptr);
-		if (!node) {
-			--len;
-			continue;
-		}
-		if (!prev || REF(prev->bit[0]) == node ||
-		    REF(prev->bit[1]) == node) {
-			if (REF(node->bit[0]))
-				PUSH(&node->bit[0]);
-			else if (REF(node->bit[1]))
-				PUSH(&node->bit[1]);
-		} else if (REF(node->bit[0]) == prev) {
-			if (REF(node->bit[1]))
-				PUSH(&node->bit[1]);
-		} else {
-			if (rcu_dereference_protected(node->peer,
-				lockdep_is_held(lock)) == peer) {
-				RCU_INIT_POINTER(node->peer, NULL);
-				list_del_init(&node->peer_list);
-				if (!node->bit[0] || !node->bit[1]) {
-					rcu_assign_pointer(*nptr, DEREF(
-					       &node->bit[!REF(node->bit[0])]));
-					kfree_rcu(node, rcu);
-					node = DEREF(nptr);
-				}
-			}
-			--len;
-		}
-	}
-
-#undef REF
-#undef DEREF
-#undef PUSH
-}
-
 static unsigned int fls128(u64 a, u64 b)
 {
 	return a ? fls64(a) + 64U : fls64(b);
@@ -224,6 +170,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
 		RCU_INIT_POINTER(node->peer, peer);
 		list_add_tail(&node->peer_list, &peer->allowedips_list);
 		copy_and_assign_cidr(node, key, cidr, bits);
+		rcu_assign_pointer(node->parent_bit, trie);
 		rcu_assign_pointer(*trie, node);
 		return 0;
 	}
@@ -243,9 +190,9 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
 	if (!node) {
 		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
 	} else {
-		down = rcu_dereference_protected(CHOOSE_NODE(node, key),
-						 lockdep_is_held(lock));
+		down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock));
 		if (!down) {
+			rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key));
 			rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
 			return 0;
 		}
@@ -254,29 +201,37 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
 	parent = node;
 
 	if (newnode->cidr == cidr) {
+		rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits));
 		rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
-		if (!parent)
+		if (!parent) {
+			rcu_assign_pointer(newnode->parent_bit, trie);
 			rcu_assign_pointer(*trie, newnode);
-		else
-			rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
-					   newnode);
-	} else {
-		node = kzalloc(sizeof(*node), GFP_KERNEL);
-		if (unlikely(!node)) {
-			list_del(&newnode->peer_list);
-			kfree(newnode);
-			return -ENOMEM;
+		} else {
+			rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits));
+			rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode);
 		}
-		INIT_LIST_HEAD(&node->peer_list);
-		copy_and_assign_cidr(node, newnode->bits, cidr, bits);
-
-		rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
-		rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
-		if (!parent)
-			rcu_assign_pointer(*trie, node);
-		else
-			rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
-					   node);
+		return 0;
+	}
+
+	node = kzalloc(sizeof(*node), GFP_KERNEL);
+	if (unlikely(!node)) {
+		list_del(&newnode->peer_list);
+		kfree(newnode);
+		return -ENOMEM;
+	}
+	INIT_LIST_HEAD(&node->peer_list);
+	copy_and_assign_cidr(node, newnode->bits, cidr, bits);
+
+	rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits));
+	rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
+	rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits));
+	rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
+	if (!parent) {
+		rcu_assign_pointer(node->parent_bit, trie);
+		rcu_assign_pointer(*trie, node);
+	} else {
+		rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits));
+		rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node);
 	}
 	return 0;
 }
@@ -335,9 +290,30 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
 void wg_allowedips_remove_by_peer(struct allowedips *table,
 				  struct wg_peer *peer, struct mutex *lock)
 {
+	struct allowedips_node *node, *child, *tmp;
+
+	if (list_empty(&peer->allowedips_list))
+		return;
 	++table->seq;
-	walk_remove_by_peer(&table->root4, peer, lock);
-	walk_remove_by_peer(&table->root6, peer, lock);
+	list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
+		list_del_init(&node->peer_list);
+		RCU_INIT_POINTER(node->peer, NULL);
+		if (node->bit[0] && node->bit[1])
+			continue;
+		child = rcu_dereference_protected(
+				node->bit[!rcu_access_pointer(node->bit[0])],
+				lockdep_is_held(lock));
+		if (child)
+			child->parent_bit = node->parent_bit;
+		*rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child;
+		kfree_rcu(node, rcu);
+
+		/* TODO: Note that we currently don't walk up and down in order to
+		 * free any potential filler nodes. This means that this function
+		 * doesn't free up as much as it could, which could be revisited
+		 * at some point.
+		 */
+	}
 }
 
 int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h
index e5c83cafcef4..f08f552e6852 100644
--- a/drivers/net/wireguard/allowedips.h
+++ b/drivers/net/wireguard/allowedips.h
@@ -15,14 +15,11 @@ struct wg_peer;
 struct allowedips_node {
 	struct wg_peer __rcu *peer;
 	struct allowedips_node __rcu *bit[2];
-	/* While it may seem scandalous that we waste space for v4,
-	 * we're alloc'ing to the nearest power of 2 anyway, so this
-	 * doesn't actually make a difference.
-	 */
-	u8 bits[16] __aligned(__alignof(u64));
 	u8 cidr, bit_at_a, bit_at_b, bitlen;
+	u8 bits[16] __aligned(__alignof(u64));
 
-	/* Keep rarely used list at bottom to be beyond cache line. */
+	/* Keep rarely used members at bottom to be beyond cache line. */
+	struct allowedips_node *__rcu *parent_bit; /* XXX: this puts us at 68->128 bytes instead of 60->64 bytes!! */
 	union {
 		struct list_head peer_list;
 		struct rcu_head rcu;