summary refs log tree commit diff
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/core/sock.c18
-rw-r--r--net/sched/cls_cgroup.c50
-rw-r--r--net/socket.c9
3 files changed, 61 insertions, 16 deletions
diff --git a/net/core/sock.c b/net/core/sock.c
index bf88a167c8f2..a05ae7f9771e 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -123,6 +123,7 @@
 #include <linux/net_tstamp.h>
 #include <net/xfrm.h>
 #include <linux/ipsec.h>
+#include <net/cls_cgroup.h>
 
 #include <linux/filter.h>
 
@@ -217,6 +218,11 @@ __u32 sysctl_rmem_default __read_mostly = SK_RMEM_MAX;
 int sysctl_optmem_max __read_mostly = sizeof(unsigned long)*(2*UIO_MAXIOV+512);
 EXPORT_SYMBOL(sysctl_optmem_max);
 
+#if defined(CONFIG_CGROUPS) && !defined(CONFIG_NET_CLS_CGROUP)
+int net_cls_subsys_id = -1;
+EXPORT_SYMBOL_GPL(net_cls_subsys_id);
+#endif
+
 static int sock_set_timeout(long *timeo_p, char __user *optval, int optlen)
 {
 	struct timeval tv;
@@ -1050,6 +1056,16 @@ static void sk_prot_free(struct proto *prot, struct sock *sk)
 	module_put(owner);
 }
 
+#ifdef CONFIG_CGROUPS
+void sock_update_classid(struct sock *sk)
+{
+	u32 classid = task_cls_classid(current);
+
+	if (classid && classid != sk->sk_classid)
+		sk->sk_classid = classid;
+}
+#endif
+
 /**
  *	sk_alloc - All socket objects are allocated here
  *	@net: the applicable net namespace
@@ -1073,6 +1089,8 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
 		sock_lock_init(sk);
 		sock_net_set(sk, get_net(net));
 		atomic_set(&sk->sk_wmem_alloc, 1);
+
+		sock_update_classid(sk);
 	}
 
 	return sk;
diff --git a/net/sched/cls_cgroup.c b/net/sched/cls_cgroup.c
index 221180384fd7..78ef2c5e130b 100644
--- a/net/sched/cls_cgroup.c
+++ b/net/sched/cls_cgroup.c
@@ -16,14 +16,11 @@
 #include <linux/errno.h>
 #include <linux/skbuff.h>
 #include <linux/cgroup.h>
+#include <linux/rcupdate.h>
 #include <net/rtnetlink.h>
 #include <net/pkt_cls.h>
-
-struct cgroup_cls_state
-{
-	struct cgroup_subsys_state css;
-	u32 classid;
-};
+#include <net/sock.h>
+#include <net/cls_cgroup.h>
 
 static struct cgroup_subsys_state *cgrp_create(struct cgroup_subsys *ss,
 					       struct cgroup *cgrp);
@@ -112,6 +109,10 @@ static int cls_cgroup_classify(struct sk_buff *skb, struct tcf_proto *tp,
 	struct cls_cgroup_head *head = tp->root;
 	u32 classid;
 
+	rcu_read_lock();
+	classid = task_cls_state(current)->classid;
+	rcu_read_unlock();
+
 	/*
 	 * Due to the nature of the classifier it is required to ignore all
 	 * packets originating from softirq context as accessing `current'
@@ -122,12 +123,12 @@ static int cls_cgroup_classify(struct sk_buff *skb, struct tcf_proto *tp,
 	 * calls by looking at the number of nested bh disable calls because
 	 * softirqs always disables bh.
 	 */
-	if (softirq_count() != SOFTIRQ_OFFSET)
-		return -1;
-
-	rcu_read_lock();
-	classid = task_cls_state(current)->classid;
-	rcu_read_unlock();
+	if (softirq_count() != SOFTIRQ_OFFSET) {
+		/* If there is an sk_classid we'll use that. */
+		if (!skb->sk)
+			return -1;
+		classid = skb->sk->sk_classid;
+	}
 
 	if (!classid)
 		return -1;
@@ -289,18 +290,35 @@ static struct tcf_proto_ops cls_cgroup_ops __read_mostly = {
 
 static int __init init_cgroup_cls(void)
 {
-	int ret = register_tcf_proto_ops(&cls_cgroup_ops);
-	if (ret)
-		return ret;
+	int ret;
+
 	ret = cgroup_load_subsys(&net_cls_subsys);
 	if (ret)
-		unregister_tcf_proto_ops(&cls_cgroup_ops);
+		goto out;
+
+#ifndef CONFIG_NET_CLS_CGROUP
+	/* We can't use rcu_assign_pointer because this is an int. */
+	smp_wmb();
+	net_cls_subsys_id = net_cls_subsys.subsys_id;
+#endif
+
+	ret = register_tcf_proto_ops(&cls_cgroup_ops);
+	if (ret)
+		cgroup_unload_subsys(&net_cls_subsys);
+
+out:
 	return ret;
 }
 
 static void __exit exit_cgroup_cls(void)
 {
 	unregister_tcf_proto_ops(&cls_cgroup_ops);
+
+#ifndef CONFIG_NET_CLS_CGROUP
+	net_cls_subsys_id = -1;
+	synchronize_rcu();
+#endif
+
 	cgroup_unload_subsys(&net_cls_subsys);
 }
 
diff --git a/net/socket.c b/net/socket.c
index f9f7d0872cac..367d5477d00f 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -94,6 +94,7 @@
 
 #include <net/compat.h>
 #include <net/wext.h>
+#include <net/cls_cgroup.h>
 
 #include <net/sock.h>
 #include <linux/netfilter.h>
@@ -558,6 +559,8 @@ static inline int __sock_sendmsg(struct kiocb *iocb, struct socket *sock,
 	struct sock_iocb *si = kiocb_to_siocb(iocb);
 	int err;
 
+	sock_update_classid(sock->sk);
+
 	si->sock = sock;
 	si->scm = NULL;
 	si->msg = msg;
@@ -684,6 +687,8 @@ static inline int __sock_recvmsg_nosec(struct kiocb *iocb, struct socket *sock,
 {
 	struct sock_iocb *si = kiocb_to_siocb(iocb);
 
+	sock_update_classid(sock->sk);
+
 	si->sock = sock;
 	si->scm = NULL;
 	si->msg = msg;
@@ -777,6 +782,8 @@ static ssize_t sock_splice_read(struct file *file, loff_t *ppos,
 	if (unlikely(!sock->ops->splice_read))
 		return -EINVAL;
 
+	sock_update_classid(sock->sk);
+
 	return sock->ops->splice_read(sock, ppos, pipe, len, flags);
 }
 
@@ -3069,6 +3076,8 @@ int kernel_setsockopt(struct socket *sock, int level, int optname,
 int kernel_sendpage(struct socket *sock, struct page *page, int offset,
 		    size_t size, int flags)
 {
+	sock_update_classid(sock->sk);
+
 	if (sock->ops->sendpage)
 		return sock->ops->sendpage(sock, page, offset, size, flags);