summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--include/net/inet_connection_sock.h3
-rw-r--r--include/net/inet_hashtables.h80
-rw-r--r--include/net/sock.h14
-rw-r--r--net/dccp/ipv4.c25
-rw-r--r--net/dccp/ipv6.c18
-rw-r--r--net/dccp/proto.c34
-rw-r--r--net/ipv4/af_inet.c26
-rw-r--r--net/ipv4/inet_connection_sock.c275
-rw-r--r--net/ipv4/inet_hashtables.c268
-rw-r--r--net/ipv4/tcp.c11
-rw-r--r--net/ipv4/tcp_ipv4.c23
-rw-r--r--net/ipv6/tcp_ipv6.c17
12 files changed, 700 insertions, 94 deletions
diff --git a/include/net/inet_connection_sock.h b/include/net/inet_connection_sock.h
index ee88f0f1350f..c2b15f7e5516 100644
--- a/include/net/inet_connection_sock.h
+++ b/include/net/inet_connection_sock.h
@@ -25,6 +25,7 @@
 #undef INET_CSK_CLEAR_TIMERS
 
 struct inet_bind_bucket;
+struct inet_bind2_bucket;
 struct tcp_congestion_ops;
 
 /*
@@ -57,6 +58,7 @@ struct inet_connection_sock_af_ops {
  *
  * @icsk_accept_queue:	   FIFO of established children
  * @icsk_bind_hash:	   Bind node
+ * @icsk_bind2_hash:	   Bind node in the bhash2 table
  * @icsk_timeout:	   Timeout
  * @icsk_retransmit_timer: Resend (no ack)
  * @icsk_rto:		   Retransmit timeout
@@ -83,6 +85,7 @@ struct inet_connection_sock {
 	struct inet_sock	  icsk_inet;
 	struct request_sock_queue icsk_accept_queue;
 	struct inet_bind_bucket	  *icsk_bind_hash;
+	struct inet_bind2_bucket  *icsk_bind2_hash;
 	unsigned long		  icsk_timeout;
  	struct timer_list	  icsk_retransmit_timer;
  	struct timer_list	  icsk_delack_timer;
diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index e9cf2157ed8a..44a419b9e3d5 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -23,6 +23,7 @@
 
 #include <net/inet_connection_sock.h>
 #include <net/inet_sock.h>
+#include <net/ip.h>
 #include <net/sock.h>
 #include <net/route.h>
 #include <net/tcp_states.h>
@@ -90,7 +91,28 @@ struct inet_bind_bucket {
 	struct hlist_head	owners;
 };
 
-static inline struct net *ib_net(struct inet_bind_bucket *ib)
+struct inet_bind2_bucket {
+	possible_net_t		ib_net;
+	int			l3mdev;
+	unsigned short		port;
+	union {
+#if IS_ENABLED(CONFIG_IPV6)
+		struct in6_addr		v6_rcv_saddr;
+#endif
+		__be32			rcv_saddr;
+	};
+	/* Node in the bhash2 inet_bind_hashbucket chain */
+	struct hlist_node	node;
+	/* List of sockets hashed to this bucket */
+	struct hlist_head	owners;
+};
+
+static inline struct net *ib_net(const struct inet_bind_bucket *ib)
+{
+	return read_pnet(&ib->ib_net);
+}
+
+static inline struct net *ib2_net(const struct inet_bind2_bucket *ib)
 {
 	return read_pnet(&ib->ib_net);
 }
@@ -133,7 +155,14 @@ struct inet_hashinfo {
 	 * TCP hash as well as the others for fast bind/connect.
 	 */
 	struct kmem_cache		*bind_bucket_cachep;
+	/* This bind table is hashed by local port */
 	struct inet_bind_hashbucket	*bhash;
+	struct kmem_cache		*bind2_bucket_cachep;
+	/* This bind table is hashed by local port and sk->sk_rcv_saddr (ipv4)
+	 * or sk->sk_v6_rcv_saddr (ipv6). This 2nd bind table is used
+	 * primarily for expediting bind conflict resolution.
+	 */
+	struct inet_bind_hashbucket	*bhash2;
 	unsigned int			bhash_size;
 
 	/* The 2nd listener table hashed by local port and address */
@@ -182,14 +211,61 @@ inet_bind_bucket_create(struct kmem_cache *cachep, struct net *net,
 void inet_bind_bucket_destroy(struct kmem_cache *cachep,
 			      struct inet_bind_bucket *tb);
 
+bool inet_bind_bucket_match(const struct inet_bind_bucket *tb,
+			    const struct net *net, unsigned short port,
+			    int l3mdev);
+
+struct inet_bind2_bucket *
+inet_bind2_bucket_create(struct kmem_cache *cachep, struct net *net,
+			 struct inet_bind_hashbucket *head,
+			 unsigned short port, int l3mdev,
+			 const struct sock *sk);
+
+void inet_bind2_bucket_destroy(struct kmem_cache *cachep,
+			       struct inet_bind2_bucket *tb);
+
+struct inet_bind2_bucket *
+inet_bind2_bucket_find(const struct inet_bind_hashbucket *head,
+		       const struct net *net,
+		       unsigned short port, int l3mdev,
+		       const struct sock *sk);
+
+bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb,
+				      const struct net *net, unsigned short port,
+				      int l3mdev, const struct sock *sk);
+
 static inline u32 inet_bhashfn(const struct net *net, const __u16 lport,
 			       const u32 bhash_size)
 {
 	return (lport + net_hash_mix(net)) & (bhash_size - 1);
 }
 
+static inline struct inet_bind_hashbucket *
+inet_bhashfn_portaddr(const struct inet_hashinfo *hinfo, const struct sock *sk,
+		      const struct net *net, unsigned short port)
+{
+	u32 hash;
+
+#if IS_ENABLED(CONFIG_IPV6)
+	if (sk->sk_family == AF_INET6)
+		hash = ipv6_portaddr_hash(net, &sk->sk_v6_rcv_saddr, port);
+	else
+#endif
+		hash = ipv4_portaddr_hash(net, sk->sk_rcv_saddr, port);
+	return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
+}
+
+struct inet_bind_hashbucket *
+inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port);
+
+/* This should be called whenever a socket's sk_rcv_saddr (ipv4) or
+ * sk_v6_rcv_saddr (ipv6) changes after it has been binded. The socket's
+ * rcv_saddr field should already have been updated when this is called.
+ */
+int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk);
+
 void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
-		    const unsigned short snum);
+		    struct inet_bind2_bucket *tb2, unsigned short port);
 
 /* Caller must disable local BH processing. */
 int __inet_inherit_port(const struct sock *sk, struct sock *child);
diff --git a/include/net/sock.h b/include/net/sock.h
index d08cfe190a78..ca469980e006 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -348,6 +348,7 @@ struct sk_filter;
   *	@sk_txtime_report_errors: set report errors mode for SO_TXTIME
   *	@sk_txtime_unused: unused txtime flags
   *	@ns_tracker: tracker for netns reference
+  *	@sk_bind2_node: bind node in the bhash2 table
   */
 struct sock {
 	/*
@@ -537,6 +538,7 @@ struct sock {
 #endif
 	struct rcu_head		sk_rcu;
 	netns_tracker		ns_tracker;
+	struct hlist_node	sk_bind2_node;
 };
 
 enum sk_pacing {
@@ -870,6 +872,16 @@ static inline void sk_add_bind_node(struct sock *sk,
 	hlist_add_head(&sk->sk_bind_node, list);
 }
 
+static inline void __sk_del_bind2_node(struct sock *sk)
+{
+	__hlist_del(&sk->sk_bind2_node);
+}
+
+static inline void sk_add_bind2_node(struct sock *sk, struct hlist_head *list)
+{
+	hlist_add_head(&sk->sk_bind2_node, list);
+}
+
 #define sk_for_each(__sk, list) \
 	hlist_for_each_entry(__sk, list, sk_node)
 #define sk_for_each_rcu(__sk, list) \
@@ -887,6 +899,8 @@ static inline void sk_add_bind_node(struct sock *sk,
 	hlist_for_each_entry_safe(__sk, tmp, list, sk_node)
 #define sk_for_each_bound(__sk, list) \
 	hlist_for_each_entry(__sk, list, sk_bind_node)
+#define sk_for_each_bound_bhash2(__sk, list) \
+	hlist_for_each_entry(__sk, list, sk_bind2_node)
 
 /**
  * sk_for_each_entry_offset_rcu - iterate over a list at a given struct offset
diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c
index da6e3b20cd75..6a6e121dc00c 100644
--- a/net/dccp/ipv4.c
+++ b/net/dccp/ipv4.c
@@ -45,10 +45,11 @@ static unsigned int dccp_v4_pernet_id __read_mostly;
 int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 {
 	const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
+	struct inet_bind_hashbucket *prev_addr_hashbucket = NULL;
+	__be32 daddr, nexthop, prev_sk_rcv_saddr;
 	struct inet_sock *inet = inet_sk(sk);
 	struct dccp_sock *dp = dccp_sk(sk);
 	__be16 orig_sport, orig_dport;
-	__be32 daddr, nexthop;
 	struct flowi4 *fl4;
 	struct rtable *rt;
 	int err;
@@ -89,9 +90,29 @@ int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	if (inet_opt == NULL || !inet_opt->opt.srr)
 		daddr = fl4->daddr;
 
-	if (inet->inet_saddr == 0)
+	if (inet->inet_saddr == 0) {
+		if (inet_csk(sk)->icsk_bind2_hash) {
+			prev_addr_hashbucket =
+				inet_bhashfn_portaddr(&dccp_hashinfo, sk,
+						      sock_net(sk),
+						      inet->inet_num);
+			prev_sk_rcv_saddr = sk->sk_rcv_saddr;
+		}
 		inet->inet_saddr = fl4->saddr;
+	}
+
 	sk_rcv_saddr_set(sk, inet->inet_saddr);
+
+	if (prev_addr_hashbucket) {
+		err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+		if (err) {
+			inet->inet_saddr = 0;
+			sk_rcv_saddr_set(sk, prev_sk_rcv_saddr);
+			ip_rt_put(rt);
+			return err;
+		}
+	}
+
 	inet->inet_dport = usin->sin_port;
 	sk_daddr_set(sk, daddr);
 
diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c
index fd44638ec16b..e57b43006074 100644
--- a/net/dccp/ipv6.c
+++ b/net/dccp/ipv6.c
@@ -934,8 +934,26 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	}
 
 	if (saddr == NULL) {
+		struct inet_bind_hashbucket *prev_addr_hashbucket = NULL;
+		struct in6_addr prev_v6_rcv_saddr;
+
+		if (icsk->icsk_bind2_hash) {
+			prev_addr_hashbucket = inet_bhashfn_portaddr(&dccp_hashinfo,
+								     sk, sock_net(sk),
+								     inet->inet_num);
+			prev_v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+		}
+
 		saddr = &fl6.saddr;
 		sk->sk_v6_rcv_saddr = *saddr;
+
+		if (prev_addr_hashbucket) {
+			err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+			if (err) {
+				sk->sk_v6_rcv_saddr = prev_v6_rcv_saddr;
+				goto failure;
+			}
+		}
 	}
 
 	/* set the source address */
diff --git a/net/dccp/proto.c b/net/dccp/proto.c
index e13641c65f88..7cd4a6cc99fc 100644
--- a/net/dccp/proto.c
+++ b/net/dccp/proto.c
@@ -1120,6 +1120,12 @@ static int __init dccp_init(void)
 				  SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, NULL);
 	if (!dccp_hashinfo.bind_bucket_cachep)
 		goto out_free_hashinfo2;
+	dccp_hashinfo.bind2_bucket_cachep =
+		kmem_cache_create("dccp_bind2_bucket",
+				  sizeof(struct inet_bind2_bucket), 0,
+				  SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, NULL);
+	if (!dccp_hashinfo.bind2_bucket_cachep)
+		goto out_free_bind_bucket_cachep;
 
 	/*
 	 * Size and allocate the main established and bind bucket
@@ -1150,7 +1156,7 @@ static int __init dccp_init(void)
 
 	if (!dccp_hashinfo.ehash) {
 		DCCP_CRIT("Failed to allocate DCCP established hash table");
-		goto out_free_bind_bucket_cachep;
+		goto out_free_bind2_bucket_cachep;
 	}
 
 	for (i = 0; i <= dccp_hashinfo.ehash_mask; i++)
@@ -1176,14 +1182,24 @@ static int __init dccp_init(void)
 		goto out_free_dccp_locks;
 	}
 
+	dccp_hashinfo.bhash2 = (struct inet_bind_hashbucket *)
+		__get_free_pages(GFP_ATOMIC | __GFP_NOWARN, bhash_order);
+
+	if (!dccp_hashinfo.bhash2) {
+		DCCP_CRIT("Failed to allocate DCCP bind2 hash table");
+		goto out_free_dccp_bhash;
+	}
+
 	for (i = 0; i < dccp_hashinfo.bhash_size; i++) {
 		spin_lock_init(&dccp_hashinfo.bhash[i].lock);
 		INIT_HLIST_HEAD(&dccp_hashinfo.bhash[i].chain);
+		spin_lock_init(&dccp_hashinfo.bhash2[i].lock);
+		INIT_HLIST_HEAD(&dccp_hashinfo.bhash2[i].chain);
 	}
 
 	rc = dccp_mib_init();
 	if (rc)
-		goto out_free_dccp_bhash;
+		goto out_free_dccp_bhash2;
 
 	rc = dccp_ackvec_init();
 	if (rc)
@@ -1207,30 +1223,38 @@ out_ackvec_exit:
 	dccp_ackvec_exit();
 out_free_dccp_mib:
 	dccp_mib_exit();
+out_free_dccp_bhash2:
+	free_pages((unsigned long)dccp_hashinfo.bhash2, bhash_order);
 out_free_dccp_bhash:
 	free_pages((unsigned long)dccp_hashinfo.bhash, bhash_order);
 out_free_dccp_locks:
 	inet_ehash_locks_free(&dccp_hashinfo);
 out_free_dccp_ehash:
 	free_pages((unsigned long)dccp_hashinfo.ehash, ehash_order);
+out_free_bind2_bucket_cachep:
+	kmem_cache_destroy(dccp_hashinfo.bind2_bucket_cachep);
 out_free_bind_bucket_cachep:
 	kmem_cache_destroy(dccp_hashinfo.bind_bucket_cachep);
 out_free_hashinfo2:
 	inet_hashinfo2_free_mod(&dccp_hashinfo);
 out_fail:
 	dccp_hashinfo.bhash = NULL;
+	dccp_hashinfo.bhash2 = NULL;
 	dccp_hashinfo.ehash = NULL;
 	dccp_hashinfo.bind_bucket_cachep = NULL;
+	dccp_hashinfo.bind2_bucket_cachep = NULL;
 	return rc;
 }
 
 static void __exit dccp_fini(void)
 {
+	int bhash_order = get_order(dccp_hashinfo.bhash_size *
+				    sizeof(struct inet_bind_hashbucket));
+
 	ccid_cleanup_builtins();
 	dccp_mib_exit();
-	free_pages((unsigned long)dccp_hashinfo.bhash,
-		   get_order(dccp_hashinfo.bhash_size *
-			     sizeof(struct inet_bind_hashbucket)));
+	free_pages((unsigned long)dccp_hashinfo.bhash, bhash_order);
+	free_pages((unsigned long)dccp_hashinfo.bhash2, bhash_order);
 	free_pages((unsigned long)dccp_hashinfo.ehash,
 		   get_order((dccp_hashinfo.ehash_mask + 1) *
 			     sizeof(struct inet_ehash_bucket)));
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 3ca0cc467886..2e6a3cb06815 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -1219,6 +1219,7 @@ EXPORT_SYMBOL(inet_unregister_protosw);
 
 static int inet_sk_reselect_saddr(struct sock *sk)
 {
+	struct inet_bind_hashbucket *prev_addr_hashbucket;
 	struct inet_sock *inet = inet_sk(sk);
 	__be32 old_saddr = inet->inet_saddr;
 	__be32 daddr = inet->inet_daddr;
@@ -1226,6 +1227,7 @@ static int inet_sk_reselect_saddr(struct sock *sk)
 	struct rtable *rt;
 	__be32 new_saddr;
 	struct ip_options_rcu *inet_opt;
+	int err;
 
 	inet_opt = rcu_dereference_protected(inet->inet_opt,
 					     lockdep_sock_is_held(sk));
@@ -1240,20 +1242,34 @@ static int inet_sk_reselect_saddr(struct sock *sk)
 	if (IS_ERR(rt))
 		return PTR_ERR(rt);
 
-	sk_setup_caps(sk, &rt->dst);
-
 	new_saddr = fl4->saddr;
 
-	if (new_saddr == old_saddr)
+	if (new_saddr == old_saddr) {
+		sk_setup_caps(sk, &rt->dst);
 		return 0;
+	}
+
+	prev_addr_hashbucket =
+		inet_bhashfn_portaddr(sk->sk_prot->h.hashinfo, sk,
+				      sock_net(sk), inet->inet_num);
+
+	inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
+
+	err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+	if (err) {
+		inet->inet_saddr = old_saddr;
+		inet->inet_rcv_saddr = old_saddr;
+		ip_rt_put(rt);
+		return err;
+	}
+
+	sk_setup_caps(sk, &rt->dst);
 
 	if (READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_dynaddr) > 1) {
 		pr_info("%s(): shifting inet->saddr from %pI4 to %pI4\n",
 			__func__, &old_saddr, &new_saddr);
 	}
 
-	inet->inet_saddr = inet->inet_rcv_saddr = new_saddr;
-
 	/*
 	 * XXX The only one ugly spot where we need to
 	 * XXX really change the sockets identity after
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index eb31c7158b39..f0038043b661 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -130,14 +130,75 @@ void inet_get_local_port_range(struct net *net, int *low, int *high)
 }
 EXPORT_SYMBOL(inet_get_local_port_range);
 
+static bool inet_use_bhash2_on_bind(const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+	if (sk->sk_family == AF_INET6) {
+		int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
+
+		return addr_type != IPV6_ADDR_ANY &&
+			addr_type != IPV6_ADDR_MAPPED;
+	}
+#endif
+	return sk->sk_rcv_saddr != htonl(INADDR_ANY);
+}
+
+static bool inet_bind_conflict(const struct sock *sk, struct sock *sk2,
+			       kuid_t sk_uid, bool relax,
+			       bool reuseport_cb_ok, bool reuseport_ok)
+{
+	int bound_dev_if2;
+
+	if (sk == sk2)
+		return false;
+
+	bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
+
+	if (!sk->sk_bound_dev_if || !bound_dev_if2 ||
+	    sk->sk_bound_dev_if == bound_dev_if2) {
+		if (sk->sk_reuse && sk2->sk_reuse &&
+		    sk2->sk_state != TCP_LISTEN) {
+			if (!relax || (!reuseport_ok && sk->sk_reuseport &&
+				       sk2->sk_reuseport && reuseport_cb_ok &&
+				       (sk2->sk_state == TCP_TIME_WAIT ||
+					uid_eq(sk_uid, sock_i_uid(sk2)))))
+				return true;
+		} else if (!reuseport_ok || !sk->sk_reuseport ||
+			   !sk2->sk_reuseport || !reuseport_cb_ok ||
+			   (sk2->sk_state != TCP_TIME_WAIT &&
+			    !uid_eq(sk_uid, sock_i_uid(sk2)))) {
+			return true;
+		}
+	}
+	return false;
+}
+
+static bool inet_bhash2_conflict(const struct sock *sk,
+				 const struct inet_bind2_bucket *tb2,
+				 kuid_t sk_uid,
+				 bool relax, bool reuseport_cb_ok,
+				 bool reuseport_ok)
+{
+	struct sock *sk2;
+
+	sk_for_each_bound_bhash2(sk2, &tb2->owners) {
+		if (sk->sk_family == AF_INET && ipv6_only_sock(sk2))
+			continue;
+
+		if (inet_bind_conflict(sk, sk2, sk_uid, relax,
+				       reuseport_cb_ok, reuseport_ok))
+			return true;
+	}
+	return false;
+}
+
+/* This should be called only when the tb and tb2 hashbuckets' locks are held */
 static int inet_csk_bind_conflict(const struct sock *sk,
 				  const struct inet_bind_bucket *tb,
+				  const struct inet_bind2_bucket *tb2, /* may be null */
 				  bool relax, bool reuseport_ok)
 {
-	struct sock *sk2;
 	bool reuseport_cb_ok;
-	bool reuse = sk->sk_reuse;
-	bool reuseport = !!sk->sk_reuseport;
 	struct sock_reuseport *reuseport_cb;
 	kuid_t uid = sock_i_uid((struct sock *)sk);
 
@@ -150,55 +211,87 @@ static int inet_csk_bind_conflict(const struct sock *sk,
 	/*
 	 * Unlike other sk lookup places we do not check
 	 * for sk_net here, since _all_ the socks listed
-	 * in tb->owners list belong to the same net - the
-	 * one this bucket belongs to.
+	 * in tb->owners and tb2->owners list belong
+	 * to the same net - the one this bucket belongs to.
 	 */
 
-	sk_for_each_bound(sk2, &tb->owners) {
-		int bound_dev_if2;
+	if (!inet_use_bhash2_on_bind(sk)) {
+		struct sock *sk2;
 
-		if (sk == sk2)
-			continue;
-		bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if);
-		if ((!sk->sk_bound_dev_if ||
-		     !bound_dev_if2 ||
-		     sk->sk_bound_dev_if == bound_dev_if2)) {
-			if (reuse && sk2->sk_reuse &&
-			    sk2->sk_state != TCP_LISTEN) {
-				if ((!relax ||
-				     (!reuseport_ok &&
-				      reuseport && sk2->sk_reuseport &&
-				      reuseport_cb_ok &&
-				      (sk2->sk_state == TCP_TIME_WAIT ||
-				       uid_eq(uid, sock_i_uid(sk2))))) &&
-				    inet_rcv_saddr_equal(sk, sk2, true))
-					break;
-			} else if (!reuseport_ok ||
-				   !reuseport || !sk2->sk_reuseport ||
-				   !reuseport_cb_ok ||
-				   (sk2->sk_state != TCP_TIME_WAIT &&
-				    !uid_eq(uid, sock_i_uid(sk2)))) {
-				if (inet_rcv_saddr_equal(sk, sk2, true))
-					break;
-			}
-		}
+		sk_for_each_bound(sk2, &tb->owners)
+			if (inet_bind_conflict(sk, sk2, uid, relax,
+					       reuseport_cb_ok, reuseport_ok) &&
+			    inet_rcv_saddr_equal(sk, sk2, true))
+				return true;
+
+		return false;
+	}
+
+	/* Conflicts with an existing IPV6_ADDR_ANY (if ipv6) or INADDR_ANY (if
+	 * ipv4) should have been checked already. We need to do these two
+	 * checks separately because their spinlocks have to be acquired/released
+	 * independently of each other, to prevent possible deadlocks
+	 */
+	return tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+					   reuseport_ok);
+}
+
+/* Determine if there is a bind conflict with an existing IPV6_ADDR_ANY (if ipv6) or
+ * INADDR_ANY (if ipv4) socket.
+ *
+ * Caller must hold bhash hashbucket lock with local bh disabled, to protect
+ * against concurrent binds on the port for addr any
+ */
+static bool inet_bhash2_addr_any_conflict(const struct sock *sk, int port, int l3mdev,
+					  bool relax, bool reuseport_ok)
+{
+	kuid_t uid = sock_i_uid((struct sock *)sk);
+	const struct net *net = sock_net(sk);
+	struct sock_reuseport *reuseport_cb;
+	struct inet_bind_hashbucket *head2;
+	struct inet_bind2_bucket *tb2;
+	bool reuseport_cb_ok;
+
+	rcu_read_lock();
+	reuseport_cb = rcu_dereference(sk->sk_reuseport_cb);
+	/* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */
+	reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks);
+	rcu_read_unlock();
+
+	head2 = inet_bhash2_addr_any_hashbucket(sk, net, port);
+
+	spin_lock(&head2->lock);
+
+	inet_bind_bucket_for_each(tb2, &head2->chain)
+		if (inet_bind2_bucket_match_addr_any(tb2, net, port, l3mdev, sk))
+			break;
+
+	if (tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok,
+					reuseport_ok)) {
+		spin_unlock(&head2->lock);
+		return true;
 	}
-	return sk2 != NULL;
+
+	spin_unlock(&head2->lock);
+	return false;
 }
 
 /*
  * Find an open port number for the socket.  Returns with the
- * inet_bind_hashbucket lock held.
+ * inet_bind_hashbucket locks held if successful.
  */
 static struct inet_bind_hashbucket *
-inet_csk_find_open_port(struct sock *sk, struct inet_bind_bucket **tb_ret, int *port_ret)
+inet_csk_find_open_port(const struct sock *sk, struct inet_bind_bucket **tb_ret,
+			struct inet_bind2_bucket **tb2_ret,
+			struct inet_bind_hashbucket **head2_ret, int *port_ret)
 {
 	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
 	int port = 0;
-	struct inet_bind_hashbucket *head;
+	struct inet_bind_hashbucket *head, *head2;
 	struct net *net = sock_net(sk);
 	bool relax = false;
 	int i, low, high, attempt_half;
+	struct inet_bind2_bucket *tb2;
 	struct inet_bind_bucket *tb;
 	u32 remaining, offset;
 	int l3mdev;
@@ -239,11 +332,20 @@ other_parity_scan:
 		head = &hinfo->bhash[inet_bhashfn(net, port,
 						  hinfo->bhash_size)];
 		spin_lock_bh(&head->lock);
+		if (inet_use_bhash2_on_bind(sk)) {
+			if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, relax, false))
+				goto next_port;
+		}
+
+		head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+		spin_lock(&head2->lock);
+		tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
 		inet_bind_bucket_for_each(tb, &head->chain)
-			if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
-			    tb->port == port) {
-				if (!inet_csk_bind_conflict(sk, tb, relax, false))
+			if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
+				if (!inet_csk_bind_conflict(sk, tb, tb2,
+							    relax, false))
 					goto success;
+				spin_unlock(&head2->lock);
 				goto next_port;
 			}
 		tb = NULL;
@@ -272,6 +374,8 @@ next_port:
 success:
 	*port_ret = port;
 	*tb_ret = tb;
+	*tb2_ret = tb2;
+	*head2_ret = head2;
 	return head;
 }
 
@@ -368,53 +472,95 @@ int inet_csk_get_port(struct sock *sk, unsigned short snum)
 	bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN;
 	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
 	int ret = 1, port = snum;
-	struct inet_bind_hashbucket *head;
 	struct net *net = sock_net(sk);
+	bool found_port = false, check_bind_conflict = true;
+	bool bhash_created = false, bhash2_created = false;
+	struct inet_bind_hashbucket *head, *head2;
+	struct inet_bind2_bucket *tb2 = NULL;
 	struct inet_bind_bucket *tb = NULL;
+	bool head2_lock_acquired = false;
 	int l3mdev;
 
 	l3mdev = inet_sk_bound_l3mdev(sk);
 
 	if (!port) {
-		head = inet_csk_find_open_port(sk, &tb, &port);
+		head = inet_csk_find_open_port(sk, &tb, &tb2, &head2, &port);
 		if (!head)
 			return ret;
+
+		head2_lock_acquired = true;
+
+		if (tb && tb2)
+			goto success;
+		found_port = true;
+	} else {
+		head = &hinfo->bhash[inet_bhashfn(net, port,
+						  hinfo->bhash_size)];
+		spin_lock_bh(&head->lock);
+		inet_bind_bucket_for_each(tb, &head->chain)
+			if (inet_bind_bucket_match(tb, net, port, l3mdev))
+				break;
+	}
+
+	if (!tb) {
+		tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, net,
+					     head, port, l3mdev);
 		if (!tb)
-			goto tb_not_found;
-		goto success;
+			goto fail_unlock;
+		bhash_created = true;
 	}
-	head = &hinfo->bhash[inet_bhashfn(net, port,
-					  hinfo->bhash_size)];
-	spin_lock_bh(&head->lock);
-	inet_bind_bucket_for_each(tb, &head->chain)
-		if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
-		    tb->port == port)
-			goto tb_found;
-tb_not_found:
-	tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
-				     net, head, port, l3mdev);
-	if (!tb)
-		goto fail_unlock;
-tb_found:
-	if (!hlist_empty(&tb->owners)) {
-		if (sk->sk_reuse == SK_FORCE_REUSE)
-			goto success;
 
-		if ((tb->fastreuse > 0 && reuse) ||
-		    sk_reuseport_match(tb, sk))
-			goto success;
-		if (inet_csk_bind_conflict(sk, tb, true, true))
+	if (!found_port) {
+		if (!hlist_empty(&tb->owners)) {
+			if (sk->sk_reuse == SK_FORCE_REUSE ||
+			    (tb->fastreuse > 0 && reuse) ||
+			    sk_reuseport_match(tb, sk))
+				check_bind_conflict = false;
+		}
+
+		if (check_bind_conflict && inet_use_bhash2_on_bind(sk)) {
+			if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, true, true))
+				goto fail_unlock;
+		}
+
+		head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+		spin_lock(&head2->lock);
+		head2_lock_acquired = true;
+		tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+	}
+
+	if (!tb2) {
+		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep,
+					       net, head2, port, l3mdev, sk);
+		if (!tb2)
 			goto fail_unlock;
+		bhash2_created = true;
 	}
+
+	if (!found_port && check_bind_conflict) {
+		if (inet_csk_bind_conflict(sk, tb, tb2, true, true))
+			goto fail_unlock;
+	}
+
 success:
 	inet_csk_update_fastreuse(tb, sk);
 
 	if (!inet_csk(sk)->icsk_bind_hash)
-		inet_bind_hash(sk, tb, port);
+		inet_bind_hash(sk, tb, tb2, port);
 	WARN_ON(inet_csk(sk)->icsk_bind_hash != tb);
+	WARN_ON(inet_csk(sk)->icsk_bind2_hash != tb2);
 	ret = 0;
 
 fail_unlock:
+	if (ret) {
+		if (bhash_created)
+			inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+		if (bhash2_created)
+			inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+						  tb2);
+	}
+	if (head2_lock_acquired)
+		spin_unlock(&head2->lock);
 	spin_unlock_bh(&head->lock);
 	return ret;
 }
@@ -962,6 +1108,7 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
 
 		inet_sk_set_state(newsk, TCP_SYN_RECV);
 		newicsk->icsk_bind_hash = NULL;
+		newicsk->icsk_bind2_hash = NULL;
 
 		inet_sk(newsk)->inet_dport = inet_rsk(req)->ir_rmt_port;
 		inet_sk(newsk)->inet_num = inet_rsk(req)->ir_num;
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index b9d995b5ce24..60d77e234a68 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -92,12 +92,75 @@ void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket
 	}
 }
 
+bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net,
+			    unsigned short port, int l3mdev)
+{
+	return net_eq(ib_net(tb), net) && tb->port == port &&
+		tb->l3mdev == l3mdev;
+}
+
+static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb,
+				   struct net *net,
+				   struct inet_bind_hashbucket *head,
+				   unsigned short port, int l3mdev,
+				   const struct sock *sk)
+{
+	write_pnet(&tb->ib_net, net);
+	tb->l3mdev    = l3mdev;
+	tb->port      = port;
+#if IS_ENABLED(CONFIG_IPV6)
+	if (sk->sk_family == AF_INET6)
+		tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+	else
+#endif
+		tb->rcv_saddr = sk->sk_rcv_saddr;
+	INIT_HLIST_HEAD(&tb->owners);
+	hlist_add_head(&tb->node, &head->chain);
+}
+
+struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
+						   struct net *net,
+						   struct inet_bind_hashbucket *head,
+						   unsigned short port,
+						   int l3mdev,
+						   const struct sock *sk)
+{
+	struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
+
+	if (tb)
+		inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk);
+
+	return tb;
+}
+
+/* Caller must hold hashbucket lock for this tb with local BH disabled */
+void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
+{
+	if (hlist_empty(&tb->owners)) {
+		__hlist_del(&tb->node);
+		kmem_cache_free(cachep, tb);
+	}
+}
+
+static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2,
+					 const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+	if (sk->sk_family == AF_INET6)
+		return ipv6_addr_equal(&tb2->v6_rcv_saddr,
+				       &sk->sk_v6_rcv_saddr);
+#endif
+	return tb2->rcv_saddr == sk->sk_rcv_saddr;
+}
+
 void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
-		    const unsigned short snum)
+		    struct inet_bind2_bucket *tb2, unsigned short port)
 {
-	inet_sk(sk)->inet_num = snum;
+	inet_sk(sk)->inet_num = port;
 	sk_add_bind_node(sk, &tb->owners);
 	inet_csk(sk)->icsk_bind_hash = tb;
+	sk_add_bind2_node(sk, &tb2->owners);
+	inet_csk(sk)->icsk_bind2_hash = tb2;
 }
 
 /*
@@ -109,6 +172,9 @@ static void __inet_put_port(struct sock *sk)
 	const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
 			hashinfo->bhash_size);
 	struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
+	struct inet_bind_hashbucket *head2 =
+		inet_bhashfn_portaddr(hashinfo, sk, sock_net(sk),
+				      inet_sk(sk)->inet_num);
 	struct inet_bind_bucket *tb;
 
 	spin_lock(&head->lock);
@@ -117,6 +183,17 @@ static void __inet_put_port(struct sock *sk)
 	inet_csk(sk)->icsk_bind_hash = NULL;
 	inet_sk(sk)->inet_num = 0;
 	inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
+
+	spin_lock(&head2->lock);
+	if (inet_csk(sk)->icsk_bind2_hash) {
+		struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash;
+
+		__sk_del_bind2_node(sk);
+		inet_csk(sk)->icsk_bind2_hash = NULL;
+		inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
+	}
+	spin_unlock(&head2->lock);
+
 	spin_unlock(&head->lock);
 }
 
@@ -135,12 +212,21 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
 	const int bhash = inet_bhashfn(sock_net(sk), port,
 			table->bhash_size);
 	struct inet_bind_hashbucket *head = &table->bhash[bhash];
+	struct inet_bind_hashbucket *head2 =
+		inet_bhashfn_portaddr(table, child, sock_net(sk), port);
+	bool created_inet_bind_bucket = false;
+	bool update_fastreuse = false;
+	struct net *net = sock_net(sk);
+	struct inet_bind2_bucket *tb2;
 	struct inet_bind_bucket *tb;
 	int l3mdev;
 
 	spin_lock(&head->lock);
+	spin_lock(&head2->lock);
 	tb = inet_csk(sk)->icsk_bind_hash;
-	if (unlikely(!tb)) {
+	tb2 = inet_csk(sk)->icsk_bind2_hash;
+	if (unlikely(!tb || !tb2)) {
+		spin_unlock(&head2->lock);
 		spin_unlock(&head->lock);
 		return -ENOENT;
 	}
@@ -153,25 +239,49 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
 		 * as that of the child socket. We have to look up or
 		 * create a new bind bucket for the child here. */
 		inet_bind_bucket_for_each(tb, &head->chain) {
-			if (net_eq(ib_net(tb), sock_net(sk)) &&
-			    tb->l3mdev == l3mdev && tb->port == port)
+			if (inet_bind_bucket_match(tb, net, port, l3mdev))
 				break;
 		}
 		if (!tb) {
 			tb = inet_bind_bucket_create(table->bind_bucket_cachep,
-						     sock_net(sk), head, port,
-						     l3mdev);
+						     net, head, port, l3mdev);
 			if (!tb) {
+				spin_unlock(&head2->lock);
 				spin_unlock(&head->lock);
 				return -ENOMEM;
 			}
+			created_inet_bind_bucket = true;
+		}
+		update_fastreuse = true;
+
+		goto bhash2_find;
+	} else if (!inet_bind2_bucket_addr_match(tb2, child)) {
+		l3mdev = inet_sk_bound_l3mdev(sk);
+
+bhash2_find:
+		tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child);
+		if (!tb2) {
+			tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
+						       net, head2, port,
+						       l3mdev, child);
+			if (!tb2)
+				goto error;
 		}
-		inet_csk_update_fastreuse(tb, child);
 	}
-	inet_bind_hash(child, tb, port);
+	if (update_fastreuse)
+		inet_csk_update_fastreuse(tb, child);
+	inet_bind_hash(child, tb, tb2, port);
+	spin_unlock(&head2->lock);
 	spin_unlock(&head->lock);
 
 	return 0;
+
+error:
+	if (created_inet_bind_bucket)
+		inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
+	spin_unlock(&head2->lock);
+	spin_unlock(&head->lock);
+	return -ENOMEM;
 }
 EXPORT_SYMBOL_GPL(__inet_inherit_port);
 
@@ -675,6 +785,112 @@ void inet_unhash(struct sock *sk)
 }
 EXPORT_SYMBOL_GPL(inet_unhash);
 
+static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb,
+				    const struct net *net, unsigned short port,
+				    int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+	if (sk->sk_family == AF_INET6)
+		return net_eq(ib2_net(tb), net) && tb->port == port &&
+			tb->l3mdev == l3mdev &&
+			ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr);
+	else
+#endif
+		return net_eq(ib2_net(tb), net) && tb->port == port &&
+			tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
+}
+
+bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net,
+				      unsigned short port, int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+	struct in6_addr addr_any = {};
+
+	if (sk->sk_family == AF_INET6)
+		return net_eq(ib2_net(tb), net) && tb->port == port &&
+			tb->l3mdev == l3mdev &&
+			ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any);
+	else
+#endif
+		return net_eq(ib2_net(tb), net) && tb->port == port &&
+			tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
+}
+
+/* The socket's bhash2 hashbucket spinlock must be held when this is called */
+struct inet_bind2_bucket *
+inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net,
+		       unsigned short port, int l3mdev, const struct sock *sk)
+{
+	struct inet_bind2_bucket *bhash2 = NULL;
+
+	inet_bind_bucket_for_each(bhash2, &head->chain)
+		if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
+			break;
+
+	return bhash2;
+}
+
+struct inet_bind_hashbucket *
+inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port)
+{
+	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+	u32 hash;
+#if IS_ENABLED(CONFIG_IPV6)
+	struct in6_addr addr_any = {};
+
+	if (sk->sk_family == AF_INET6)
+		hash = ipv6_portaddr_hash(net, &addr_any, port);
+	else
+#endif
+		hash = ipv4_portaddr_hash(net, 0, port);
+
+	return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
+}
+
+int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk)
+{
+	struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+	struct inet_bind2_bucket *tb2, *new_tb2;
+	int l3mdev = inet_sk_bound_l3mdev(sk);
+	struct inet_bind_hashbucket *head2;
+	int port = inet_sk(sk)->inet_num;
+	struct net *net = sock_net(sk);
+
+	/* Allocate a bind2 bucket ahead of time to avoid permanently putting
+	 * the bhash2 table in an inconsistent state if a new tb2 bucket
+	 * allocation fails.
+	 */
+	new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC);
+	if (!new_tb2)
+		return -ENOMEM;
+
+	head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+
+	if (prev_saddr) {
+		spin_lock_bh(&prev_saddr->lock);
+		__sk_del_bind2_node(sk);
+		inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+					  inet_csk(sk)->icsk_bind2_hash);
+		spin_unlock_bh(&prev_saddr->lock);
+	}
+
+	spin_lock_bh(&head2->lock);
+	tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+	if (!tb2) {
+		tb2 = new_tb2;
+		inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk);
+	}
+	sk_add_bind2_node(sk, &tb2->owners);
+	inet_csk(sk)->icsk_bind2_hash = tb2;
+	spin_unlock_bh(&head2->lock);
+
+	if (tb2 != new_tb2)
+		kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2);
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr);
+
 /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
  * Note that we use 32bit integers (vs RFC 'short integers')
  * because 2^16 is not a multiple of num_ephemeral and this
@@ -694,11 +910,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
 			struct sock *, __u16, struct inet_timewait_sock **))
 {
 	struct inet_hashinfo *hinfo = death_row->hashinfo;
+	struct inet_bind_hashbucket *head, *head2;
 	struct inet_timewait_sock *tw = NULL;
-	struct inet_bind_hashbucket *head;
 	int port = inet_sk(sk)->inet_num;
 	struct net *net = sock_net(sk);
+	struct inet_bind2_bucket *tb2;
 	struct inet_bind_bucket *tb;
+	bool tb_created = false;
 	u32 remaining, offset;
 	int ret, i, low, high;
 	int l3mdev;
@@ -755,8 +973,7 @@ other_parity_scan:
 		 * the established check is already unique enough.
 		 */
 		inet_bind_bucket_for_each(tb, &head->chain) {
-			if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
-			    tb->port == port) {
+			if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
 				if (tb->fastreuse >= 0 ||
 				    tb->fastreuseport >= 0)
 					goto next_port;
@@ -774,6 +991,7 @@ other_parity_scan:
 			spin_unlock_bh(&head->lock);
 			return -ENOMEM;
 		}
+		tb_created = true;
 		tb->fastreuse = -1;
 		tb->fastreuseport = -1;
 		goto ok;
@@ -789,6 +1007,20 @@ next_port:
 	return -EADDRNOTAVAIL;
 
 ok:
+	/* Find the corresponding tb2 bucket since we need to
+	 * add the socket to the bhash2 table as well
+	 */
+	head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+	spin_lock(&head2->lock);
+
+	tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+	if (!tb2) {
+		tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
+					       head2, port, l3mdev, sk);
+		if (!tb2)
+			goto error;
+	}
+
 	/* Here we want to add a little bit of randomness to the next source
 	 * port that will be chosen. We use a max() with a random here so that
 	 * on low contention the randomness is maximal and on high contention
@@ -798,7 +1030,10 @@ ok:
 	WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2);
 
 	/* Head lock still held and bh's disabled */
-	inet_bind_hash(sk, tb, port);
+	inet_bind_hash(sk, tb, tb2, port);
+
+	spin_unlock(&head2->lock);
+
 	if (sk_unhashed(sk)) {
 		inet_sk(sk)->inet_sport = htons(port);
 		inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
@@ -810,6 +1045,13 @@ ok:
 		inet_twsk_deschedule_put(tw);
 	local_bh_enable();
 	return 0;
+
+error:
+	spin_unlock(&head2->lock);
+	if (tb_created)
+		inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+	spin_unlock_bh(&head->lock);
+	return -ENOMEM;
 }
 
 /*
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index ba62f8e8bd15..baf6adb723ad 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -4742,6 +4742,12 @@ void __init tcp_init(void)
 				  SLAB_HWCACHE_ALIGN | SLAB_PANIC |
 				  SLAB_ACCOUNT,
 				  NULL);
+	tcp_hashinfo.bind2_bucket_cachep =
+		kmem_cache_create("tcp_bind2_bucket",
+				  sizeof(struct inet_bind2_bucket), 0,
+				  SLAB_HWCACHE_ALIGN | SLAB_PANIC |
+				  SLAB_ACCOUNT,
+				  NULL);
 
 	/* Size and allocate the main established and bind bucket
 	 * hash tables.
@@ -4765,7 +4771,7 @@ void __init tcp_init(void)
 		panic("TCP: failed to alloc ehash_locks");
 	tcp_hashinfo.bhash =
 		alloc_large_system_hash("TCP bind",
-					sizeof(struct inet_bind_hashbucket),
+					2 * sizeof(struct inet_bind_hashbucket),
 					tcp_hashinfo.ehash_mask + 1,
 					17, /* one slot per 128 KB of memory */
 					0,
@@ -4774,9 +4780,12 @@ void __init tcp_init(void)
 					0,
 					64 * 1024);
 	tcp_hashinfo.bhash_size = 1U << tcp_hashinfo.bhash_size;
+	tcp_hashinfo.bhash2 = tcp_hashinfo.bhash + tcp_hashinfo.bhash_size;
 	for (i = 0; i < tcp_hashinfo.bhash_size; i++) {
 		spin_lock_init(&tcp_hashinfo.bhash[i].lock);
 		INIT_HLIST_HEAD(&tcp_hashinfo.bhash[i].chain);
+		spin_lock_init(&tcp_hashinfo.bhash2[i].lock);
+		INIT_HLIST_HEAD(&tcp_hashinfo.bhash2[i].chain);
 	}
 
 
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 0c83780dc9bf..cc2ad67f75be 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -199,11 +199,12 @@ static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr,
 /* This will initiate an outgoing connection. */
 int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 {
+	struct inet_bind_hashbucket *prev_addr_hashbucket = NULL;
 	struct sockaddr_in *usin = (struct sockaddr_in *)uaddr;
+	__be32 daddr, nexthop, prev_sk_rcv_saddr;
 	struct inet_sock *inet = inet_sk(sk);
 	struct tcp_sock *tp = tcp_sk(sk);
 	__be16 orig_sport, orig_dport;
-	__be32 daddr, nexthop;
 	struct flowi4 *fl4;
 	struct rtable *rt;
 	int err;
@@ -246,10 +247,28 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 	if (!inet_opt || !inet_opt->opt.srr)
 		daddr = fl4->daddr;
 
-	if (!inet->inet_saddr)
+	if (!inet->inet_saddr) {
+		if (inet_csk(sk)->icsk_bind2_hash) {
+			prev_addr_hashbucket = inet_bhashfn_portaddr(&tcp_hashinfo,
+								     sk, sock_net(sk),
+								     inet->inet_num);
+			prev_sk_rcv_saddr = sk->sk_rcv_saddr;
+		}
 		inet->inet_saddr = fl4->saddr;
+	}
+
 	sk_rcv_saddr_set(sk, inet->inet_saddr);
 
+	if (prev_addr_hashbucket) {
+		err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+		if (err) {
+			inet->inet_saddr = 0;
+			sk_rcv_saddr_set(sk, prev_sk_rcv_saddr);
+			ip_rt_put(rt);
+			return err;
+		}
+	}
+
 	if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) {
 		/* Reset inherited state */
 		tp->rx_opt.ts_recent	   = 0;
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index e54eee80ce5f..ff5c4fc135fc 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -287,8 +287,25 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 	}
 
 	if (!saddr) {
+		struct inet_bind_hashbucket *prev_addr_hashbucket = NULL;
+		struct in6_addr prev_v6_rcv_saddr;
+
+		if (icsk->icsk_bind2_hash) {
+			prev_addr_hashbucket = inet_bhashfn_portaddr(&tcp_hashinfo,
+								     sk, sock_net(sk),
+								     inet->inet_num);
+			prev_v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+		}
 		saddr = &fl6.saddr;
 		sk->sk_v6_rcv_saddr = *saddr;
+
+		if (prev_addr_hashbucket) {
+			err = inet_bhash2_update_saddr(prev_addr_hashbucket, sk);
+			if (err) {
+				sk->sk_v6_rcv_saddr = prev_v6_rcv_saddr;
+				goto failure;
+			}
+		}
 	}
 
 	/* set the source address */