summary refs log tree commit diff
path: root/net/ipv4/tcp_ipv4.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/tcp_ipv4.c')
-rw-r--r--net/ipv4/tcp_ipv4.c45
1 files changed, 32 insertions, 13 deletions
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 2e62e0d6373a..5b8ce65dfc06 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1037,6 +1037,20 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req)
 DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
 EXPORT_SYMBOL(tcp_md5_needed);
 
+static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new)
+{
+	if (!old)
+		return true;
+
+	/* l3index always overrides non-l3index */
+	if (old->l3index && new->l3index == 0)
+		return false;
+	if (old->l3index == 0 && new->l3index)
+		return true;
+
+	return old->prefixlen < new->prefixlen;
+}
+
 /* Find the Key structure for an address.  */
 struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
 					   const union tcp_md5_addr *addr,
@@ -1059,7 +1073,7 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
 				 lockdep_sock_is_held(sk)) {
 		if (key->family != family)
 			continue;
-		if (key->l3index && key->l3index != l3index)
+		if (key->flags & TCP_MD5SIG_FLAG_IFINDEX && key->l3index != l3index)
 			continue;
 		if (family == AF_INET) {
 			mask = inet_make_mask(key->prefixlen);
@@ -1074,8 +1088,7 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
 			match = false;
 		}
 
-		if (match && (!best_match ||
-			      key->prefixlen > best_match->prefixlen))
+		if (match && better_md5_match(best_match, key))
 			best_match = key;
 	}
 	return best_match;
@@ -1085,7 +1098,7 @@ EXPORT_SYMBOL(__tcp_md5_do_lookup);
 static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
 						      const union tcp_md5_addr *addr,
 						      int family, u8 prefixlen,
-						      int l3index)
+						      int l3index, u8 flags)
 {
 	const struct tcp_sock *tp = tcp_sk(sk);
 	struct tcp_md5sig_key *key;
@@ -1105,7 +1118,9 @@ static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
 				 lockdep_sock_is_held(sk)) {
 		if (key->family != family)
 			continue;
-		if (key->l3index && key->l3index != l3index)
+		if ((key->flags & TCP_MD5SIG_FLAG_IFINDEX) != (flags & TCP_MD5SIG_FLAG_IFINDEX))
+			continue;
+		if (key->l3index != l3index)
 			continue;
 		if (!memcmp(&key->addr, addr, size) &&
 		    key->prefixlen == prefixlen)
@@ -1129,7 +1144,7 @@ EXPORT_SYMBOL(tcp_v4_md5_lookup);
 
 /* This can be called on a newly created socket, from other files */
 int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
-		   int family, u8 prefixlen, int l3index,
+		   int family, u8 prefixlen, int l3index, u8 flags,
 		   const u8 *newkey, u8 newkeylen, gfp_t gfp)
 {
 	/* Add Key to the list */
@@ -1137,7 +1152,7 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct tcp_md5sig_info *md5sig;
 
-	key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index);
+	key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags);
 	if (key) {
 		/* Pre-existing entry - just update that one.
 		 * Note that the key might be used concurrently.
@@ -1182,6 +1197,7 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 	key->family = family;
 	key->prefixlen = prefixlen;
 	key->l3index = l3index;
+	key->flags = flags;
 	memcpy(&key->addr, addr,
 	       (family == AF_INET6) ? sizeof(struct in6_addr) :
 				      sizeof(struct in_addr));
@@ -1191,11 +1207,11 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 EXPORT_SYMBOL(tcp_md5_do_add);
 
 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
-		   u8 prefixlen, int l3index)
+		   u8 prefixlen, int l3index, u8 flags)
 {
 	struct tcp_md5sig_key *key;
 
-	key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index);
+	key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags);
 	if (!key)
 		return -ENOENT;
 	hlist_del_rcu(&key->node);
@@ -1229,6 +1245,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
 	const union tcp_md5_addr *addr;
 	u8 prefixlen = 32;
 	int l3index = 0;
+	u8 flags;
 
 	if (optlen < sizeof(cmd))
 		return -EINVAL;
@@ -1239,6 +1256,8 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
 	if (sin->sin_family != AF_INET)
 		return -EINVAL;
 
+	flags = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX;
+
 	if (optname == TCP_MD5SIG_EXT &&
 	    cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) {
 		prefixlen = cmd.tcpm_prefixlen;
@@ -1246,7 +1265,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
 			return -EINVAL;
 	}
 
-	if (optname == TCP_MD5SIG_EXT &&
+	if (optname == TCP_MD5SIG_EXT && cmd.tcpm_ifindex &&
 	    cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX) {
 		struct net_device *dev;
 
@@ -1267,12 +1286,12 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
 	addr = (union tcp_md5_addr *)&sin->sin_addr.s_addr;
 
 	if (!cmd.tcpm_keylen)
-		return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index);
+		return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index, flags);
 
 	if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
 		return -EINVAL;
 
-	return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index,
+	return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
 			      cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
 }
 
@@ -1596,7 +1615,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
 		 * memory, then we end up not copying the key
 		 * across. Shucks.
 		 */
-		tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index,
+		tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
 			       key->key, key->keylen, GFP_ATOMIC);
 		sk_nocaps_add(newsk, NETIF_F_GSO_MASK);
 	}