summary refs log tree commit diff
path: root/net/netlink/af_netlink.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r--net/netlink/af_netlink.c95
1 files changed, 59 insertions, 36 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index 05919bf3f670..19909d0786a2 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -116,6 +116,8 @@ static ATOMIC_NOTIFIER_HEAD(netlink_chain);
 static DEFINE_SPINLOCK(netlink_tap_lock);
 static struct list_head netlink_tap_all __read_mostly;
 
+static const struct rhashtable_params netlink_rhashtable_params;
+
 static inline u32 netlink_group_mask(u32 group)
 {
 	return group ? 1 << (group - 1) : 0;
@@ -970,41 +972,50 @@ netlink_unlock_table(void)
 
 struct netlink_compare_arg
 {
-	struct net *net;
+	possible_net_t pnet;
 	u32 portid;
 };
 
-static bool netlink_compare(void *ptr, void *arg)
+/* Doing sizeof directly may yield 4 extra bytes on 64-bit. */
+#define netlink_compare_arg_len \
+	(offsetof(struct netlink_compare_arg, portid) + sizeof(u32))
+
+static inline int netlink_compare(struct rhashtable_compare_arg *arg,
+				  const void *ptr)
 {
-	struct netlink_compare_arg *x = arg;
-	struct sock *sk = ptr;
+	const struct netlink_compare_arg *x = arg->key;
+	const struct netlink_sock *nlk = ptr;
 
-	return nlk_sk(sk)->portid == x->portid &&
-	       net_eq(sock_net(sk), x->net);
+	return nlk->portid != x->portid ||
+	       !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet));
+}
+
+static void netlink_compare_arg_init(struct netlink_compare_arg *arg,
+				     struct net *net, u32 portid)
+{
+	memset(arg, 0, sizeof(*arg));
+	write_pnet(&arg->pnet, net);
+	arg->portid = portid;
 }
 
 static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
 				     struct net *net)
 {
-	struct netlink_compare_arg arg = {
-		.net = net,
-		.portid = portid,
-	};
+	struct netlink_compare_arg arg;
 
-	return rhashtable_lookup_compare(&table->hash, &portid,
-					 &netlink_compare, &arg);
+	netlink_compare_arg_init(&arg, net, portid);
+	return rhashtable_lookup_fast(&table->hash, &arg,
+				      netlink_rhashtable_params);
 }
 
-static bool __netlink_insert(struct netlink_table *table, struct sock *sk)
+static int __netlink_insert(struct netlink_table *table, struct sock *sk)
 {
-	struct netlink_compare_arg arg = {
-		.net = sock_net(sk),
-		.portid = nlk_sk(sk)->portid,
-	};
+	struct netlink_compare_arg arg;
 
-	return rhashtable_lookup_compare_insert(&table->hash,
-						&nlk_sk(sk)->node,
-						&netlink_compare, &arg);
+	netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid);
+	return rhashtable_lookup_insert_key(&table->hash, &arg,
+					    &nlk_sk(sk)->node,
+					    netlink_rhashtable_params);
 }
 
 static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
@@ -1066,9 +1077,10 @@ static int netlink_insert(struct sock *sk, u32 portid)
 	nlk_sk(sk)->portid = portid;
 	sock_hold(sk);
 
-	err = 0;
-	if (!__netlink_insert(table, sk)) {
-		err = -EADDRINUSE;
+	err = __netlink_insert(table, sk);
+	if (err) {
+		if (err == -EEXIST)
+			err = -EADDRINUSE;
 		sock_put(sk);
 	}
 
@@ -1082,7 +1094,8 @@ static void netlink_remove(struct sock *sk)
 	struct netlink_table *table;
 
 	table = &nl_table[sk->sk_protocol];
-	if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node)) {
+	if (!rhashtable_remove_fast(&table->hash, &nlk_sk(sk)->node,
+				    netlink_rhashtable_params)) {
 		WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
 		__sock_put(sk);
 	}
@@ -2256,8 +2269,7 @@ static void netlink_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb)
 	put_cmsg(msg, SOL_NETLINK, NETLINK_PKTINFO, sizeof(info), &info);
 }
 
-static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
-			   struct msghdr *msg, size_t len)
+static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 {
 	struct sock *sk = sock->sk;
 	struct netlink_sock *nlk = nlk_sk(sk);
@@ -2346,8 +2358,7 @@ out:
 	return err;
 }
 
-static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock,
-			   struct msghdr *msg, size_t len,
+static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 			   int flags)
 {
 	struct scm_cookie scm;
@@ -3116,17 +3127,28 @@ static struct pernet_operations __net_initdata netlink_net_ops = {
 	.exit = netlink_net_exit,
 };
 
+static inline u32 netlink_hash(const void *data, u32 len, u32 seed)
+{
+	const struct netlink_sock *nlk = data;
+	struct netlink_compare_arg arg;
+
+	netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid);
+	return jhash2((u32 *)&arg, netlink_compare_arg_len / sizeof(u32), seed);
+}
+
+static const struct rhashtable_params netlink_rhashtable_params = {
+	.head_offset = offsetof(struct netlink_sock, node),
+	.key_len = netlink_compare_arg_len,
+	.obj_hashfn = netlink_hash,
+	.obj_cmpfn = netlink_compare,
+	.max_size = 65536,
+	.automatic_shrinking = true,
+};
+
 static int __init netlink_proto_init(void)
 {
 	int i;
 	int err = proto_register(&netlink_proto, 0);
-	struct rhashtable_params ht_params = {
-		.head_offset = offsetof(struct netlink_sock, node),
-		.key_offset = offsetof(struct netlink_sock, portid),
-		.key_len = sizeof(u32), /* portid */
-		.hashfn = jhash,
-		.max_shift = 16, /* 64K */
-	};
 
 	if (err != 0)
 		goto out;
@@ -3138,7 +3160,8 @@ static int __init netlink_proto_init(void)
 		goto panic;
 
 	for (i = 0; i < MAX_LINKS; i++) {
-		if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
+		if (rhashtable_init(&nl_table[i].hash,
+				    &netlink_rhashtable_params) < 0) {
 			while (--i > 0)
 				rhashtable_destroy(&nl_table[i].hash);
 			kfree(nl_table);