summary refs log tree commit diff
path: root/net/packet/af_packet.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/packet/af_packet.c')
-rw-r--r--net/packet/af_packet.c438
1 files changed, 374 insertions, 64 deletions
diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c
index c0c3cda19712..c698cec0a445 100644
--- a/net/packet/af_packet.c
+++ b/net/packet/af_packet.c
@@ -187,9 +187,11 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg);
 
 static void packet_flush_mclist(struct sock *sk);
 
+struct packet_fanout;
 struct packet_sock {
 	/* struct sock has to be the first member of packet_sock */
 	struct sock		sk;
+	struct packet_fanout	*fanout;
 	struct tpacket_stats	stats;
 	struct packet_ring_buffer	rx_ring;
 	struct packet_ring_buffer	tx_ring;
@@ -212,6 +214,24 @@ struct packet_sock {
 	struct packet_type	prot_hook ____cacheline_aligned_in_smp;
 };
 
+#define PACKET_FANOUT_MAX	256
+
+struct packet_fanout {
+#ifdef CONFIG_NET_NS
+	struct net		*net;
+#endif
+	unsigned int		num_members;
+	u16			id;
+	u8			type;
+	u8			defrag;
+	atomic_t		rr_cur;
+	struct list_head	list;
+	struct sock		*arr[PACKET_FANOUT_MAX];
+	spinlock_t		lock;
+	atomic_t		sk_ref;
+	struct packet_type	prot_hook ____cacheline_aligned_in_smp;
+};
+
 struct packet_skb_cb {
 	unsigned int origlen;
 	union {
@@ -222,6 +242,64 @@ struct packet_skb_cb {
 
 #define PACKET_SKB_CB(__skb)	((struct packet_skb_cb *)((__skb)->cb))
 
+static inline struct packet_sock *pkt_sk(struct sock *sk)
+{
+	return (struct packet_sock *)sk;
+}
+
+static void __fanout_unlink(struct sock *sk, struct packet_sock *po);
+static void __fanout_link(struct sock *sk, struct packet_sock *po);
+
+/* register_prot_hook must be invoked with the po->bind_lock held,
+ * or from a context in which asynchronous accesses to the packet
+ * socket is not possible (packet_create()).
+ */
+static void register_prot_hook(struct sock *sk)
+{
+	struct packet_sock *po = pkt_sk(sk);
+	if (!po->running) {
+		if (po->fanout)
+			__fanout_link(sk, po);
+		else
+			dev_add_pack(&po->prot_hook);
+		sock_hold(sk);
+		po->running = 1;
+	}
+}
+
+/* {,__}unregister_prot_hook() must be invoked with the po->bind_lock
+ * held.   If the sync parameter is true, we will temporarily drop
+ * the po->bind_lock and do a synchronize_net to make sure no
+ * asynchronous packet processing paths still refer to the elements
+ * of po->prot_hook.  If the sync parameter is false, it is the
+ * callers responsibility to take care of this.
+ */
+static void __unregister_prot_hook(struct sock *sk, bool sync)
+{
+	struct packet_sock *po = pkt_sk(sk);
+
+	po->running = 0;
+	if (po->fanout)
+		__fanout_unlink(sk, po);
+	else
+		__dev_remove_pack(&po->prot_hook);
+	__sock_put(sk);
+
+	if (sync) {
+		spin_unlock(&po->bind_lock);
+		synchronize_net();
+		spin_lock(&po->bind_lock);
+	}
+}
+
+static void unregister_prot_hook(struct sock *sk, bool sync)
+{
+	struct packet_sock *po = pkt_sk(sk);
+
+	if (po->running)
+		__unregister_prot_hook(sk, sync);
+}
+
 static inline __pure struct page *pgv_to_page(void *addr)
 {
 	if (is_vmalloc_addr(addr))
@@ -324,11 +402,6 @@ static inline void packet_increment_head(struct packet_ring_buffer *buff)
 	buff->head = buff->head != buff->frame_max ? buff->head+1 : 0;
 }
 
-static inline struct packet_sock *pkt_sk(struct sock *sk)
-{
-	return (struct packet_sock *)sk;
-}
-
 static void packet_sock_destruct(struct sock *sk)
 {
 	skb_queue_purge(&sk->sk_error_queue);
@@ -344,6 +417,240 @@ static void packet_sock_destruct(struct sock *sk)
 	sk_refcnt_debug_dec(sk);
 }
 
+static int fanout_rr_next(struct packet_fanout *f, unsigned int num)
+{
+	int x = atomic_read(&f->rr_cur) + 1;
+
+	if (x >= num)
+		x = 0;
+
+	return x;
+}
+
+static struct sock *fanout_demux_hash(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+{
+	u32 idx, hash = skb->rxhash;
+
+	idx = ((u64)hash * num) >> 32;
+
+	return f->arr[idx];
+}
+
+static struct sock *fanout_demux_lb(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+{
+	int cur, old;
+
+	cur = atomic_read(&f->rr_cur);
+	while ((old = atomic_cmpxchg(&f->rr_cur, cur,
+				     fanout_rr_next(f, num))) != cur)
+		cur = old;
+	return f->arr[cur];
+}
+
+static struct sock *fanout_demux_cpu(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+{
+	unsigned int cpu = smp_processor_id();
+
+	return f->arr[cpu % num];
+}
+
+static struct sk_buff *fanout_check_defrag(struct sk_buff *skb)
+{
+#ifdef CONFIG_INET
+	const struct iphdr *iph;
+	u32 len;
+
+	if (skb->protocol != htons(ETH_P_IP))
+		return skb;
+
+	if (!pskb_may_pull(skb, sizeof(struct iphdr)))
+		return skb;
+
+	iph = ip_hdr(skb);
+	if (iph->ihl < 5 || iph->version != 4)
+		return skb;
+	if (!pskb_may_pull(skb, iph->ihl*4))
+		return skb;
+	iph = ip_hdr(skb);
+	len = ntohs(iph->tot_len);
+	if (skb->len < len || len < (iph->ihl * 4))
+		return skb;
+
+	if (ip_is_fragment(ip_hdr(skb))) {
+		skb = skb_share_check(skb, GFP_ATOMIC);
+		if (skb) {
+			if (pskb_trim_rcsum(skb, len))
+				return skb;
+			memset(IPCB(skb), 0, sizeof(struct inet_skb_parm));
+			if (ip_defrag(skb, IP_DEFRAG_AF_PACKET))
+				return NULL;
+			skb->rxhash = 0;
+		}
+	}
+#endif
+	return skb;
+}
+
+static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
+			     struct packet_type *pt, struct net_device *orig_dev)
+{
+	struct packet_fanout *f = pt->af_packet_priv;
+	unsigned int num = f->num_members;
+	struct packet_sock *po;
+	struct sock *sk;
+
+	if (!net_eq(dev_net(dev), read_pnet(&f->net)) ||
+	    !num) {
+		kfree_skb(skb);
+		return 0;
+	}
+
+	switch (f->type) {
+	case PACKET_FANOUT_HASH:
+	default:
+		if (f->defrag) {
+			skb = fanout_check_defrag(skb);
+			if (!skb)
+				return 0;
+		}
+		skb_get_rxhash(skb);
+		sk = fanout_demux_hash(f, skb, num);
+		break;
+	case PACKET_FANOUT_LB:
+		sk = fanout_demux_lb(f, skb, num);
+		break;
+	case PACKET_FANOUT_CPU:
+		sk = fanout_demux_cpu(f, skb, num);
+		break;
+	}
+
+	po = pkt_sk(sk);
+
+	return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev);
+}
+
+static DEFINE_MUTEX(fanout_mutex);
+static LIST_HEAD(fanout_list);
+
+static void __fanout_link(struct sock *sk, struct packet_sock *po)
+{
+	struct packet_fanout *f = po->fanout;
+
+	spin_lock(&f->lock);
+	f->arr[f->num_members] = sk;
+	smp_wmb();
+	f->num_members++;
+	spin_unlock(&f->lock);
+}
+
+static void __fanout_unlink(struct sock *sk, struct packet_sock *po)
+{
+	struct packet_fanout *f = po->fanout;
+	int i;
+
+	spin_lock(&f->lock);
+	for (i = 0; i < f->num_members; i++) {
+		if (f->arr[i] == sk)
+			break;
+	}
+	BUG_ON(i >= f->num_members);
+	f->arr[i] = f->arr[f->num_members - 1];
+	f->num_members--;
+	spin_unlock(&f->lock);
+}
+
+static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
+{
+	struct packet_sock *po = pkt_sk(sk);
+	struct packet_fanout *f, *match;
+	u8 type = type_flags & 0xff;
+	u8 defrag = (type_flags & PACKET_FANOUT_FLAG_DEFRAG) ? 1 : 0;
+	int err;
+
+	switch (type) {
+	case PACKET_FANOUT_HASH:
+	case PACKET_FANOUT_LB:
+	case PACKET_FANOUT_CPU:
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	if (!po->running)
+		return -EINVAL;
+
+	if (po->fanout)
+		return -EALREADY;
+
+	mutex_lock(&fanout_mutex);
+	match = NULL;
+	list_for_each_entry(f, &fanout_list, list) {
+		if (f->id == id &&
+		    read_pnet(&f->net) == sock_net(sk)) {
+			match = f;
+			break;
+		}
+	}
+	err = -EINVAL;
+	if (match && match->defrag != defrag)
+		goto out;
+	if (!match) {
+		err = -ENOMEM;
+		match = kzalloc(sizeof(*match), GFP_KERNEL);
+		if (!match)
+			goto out;
+		write_pnet(&match->net, sock_net(sk));
+		match->id = id;
+		match->type = type;
+		match->defrag = defrag;
+		atomic_set(&match->rr_cur, 0);
+		INIT_LIST_HEAD(&match->list);
+		spin_lock_init(&match->lock);
+		atomic_set(&match->sk_ref, 0);
+		match->prot_hook.type = po->prot_hook.type;
+		match->prot_hook.dev = po->prot_hook.dev;
+		match->prot_hook.func = packet_rcv_fanout;
+		match->prot_hook.af_packet_priv = match;
+		dev_add_pack(&match->prot_hook);
+		list_add(&match->list, &fanout_list);
+	}
+	err = -EINVAL;
+	if (match->type == type &&
+	    match->prot_hook.type == po->prot_hook.type &&
+	    match->prot_hook.dev == po->prot_hook.dev) {
+		err = -ENOSPC;
+		if (atomic_read(&match->sk_ref) < PACKET_FANOUT_MAX) {
+			__dev_remove_pack(&po->prot_hook);
+			po->fanout = match;
+			atomic_inc(&match->sk_ref);
+			__fanout_link(sk, po);
+			err = 0;
+		}
+	}
+out:
+	mutex_unlock(&fanout_mutex);
+	return err;
+}
+
+static void fanout_release(struct sock *sk)
+{
+	struct packet_sock *po = pkt_sk(sk);
+	struct packet_fanout *f;
+
+	f = po->fanout;
+	if (!f)
+		return;
+
+	po->fanout = NULL;
+
+	mutex_lock(&fanout_mutex);
+	if (atomic_dec_and_test(&f->sk_ref)) {
+		list_del(&f->list);
+		dev_remove_pack(&f->prot_hook);
+		kfree(f);
+	}
+	mutex_unlock(&fanout_mutex);
+}
 
 static const struct proto_ops packet_ops;
 
@@ -822,7 +1129,6 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
 	else
 		sll->sll_ifindex = dev->ifindex;
 
-	__packet_set_status(po, h.raw, status);
 	smp_mb();
 #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1
 	{
@@ -831,8 +1137,10 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
 		end = (u8 *)PAGE_ALIGN((unsigned long)h.raw + macoff + snaplen);
 		for (start = h.raw; start < end; start += PAGE_SIZE)
 			flush_dcache_page(pgv_to_page(start));
+		smp_wmb();
 	}
 #endif
+	__packet_set_status(po, h.raw, status);
 
 	sk->sk_data_ready(sk, 0);
 
@@ -975,7 +1283,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 	struct sk_buff *skb;
 	struct net_device *dev;
 	__be16 proto;
-	int ifindex, err, reserve = 0;
+	bool need_rls_dev = false;
+	int err, reserve = 0;
 	void *ph;
 	struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name;
 	int tp_len, size_max;
@@ -987,7 +1296,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 
 	err = -EBUSY;
 	if (saddr == NULL) {
-		ifindex	= po->ifindex;
+		dev = po->prot_hook.dev;
 		proto	= po->num;
 		addr	= NULL;
 	} else {
@@ -998,12 +1307,12 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 					+ offsetof(struct sockaddr_ll,
 						sll_addr)))
 			goto out;
-		ifindex	= saddr->sll_ifindex;
 		proto	= saddr->sll_protocol;
 		addr	= saddr->sll_addr;
+		dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex);
+		need_rls_dev = true;
 	}
 
-	dev = dev_get_by_index(sock_net(&po->sk), ifindex);
 	err = -ENXIO;
 	if (unlikely(dev == NULL))
 		goto out;
@@ -1089,7 +1398,8 @@ out_status:
 	__packet_set_status(po, ph, status);
 	kfree_skb(skb);
 out_put:
-	dev_put(dev);
+	if (need_rls_dev)
+		dev_put(dev);
 out:
 	mutex_unlock(&po->pg_vec_lock);
 	return err;
@@ -1127,8 +1437,9 @@ static int packet_snd(struct socket *sock,
 	struct sk_buff *skb;
 	struct net_device *dev;
 	__be16 proto;
+	bool need_rls_dev = false;
 	unsigned char *addr;
-	int ifindex, err, reserve = 0;
+	int err, reserve = 0;
 	struct virtio_net_hdr vnet_hdr = { 0 };
 	int offset = 0;
 	int vnet_hdr_len;
@@ -1140,7 +1451,7 @@ static int packet_snd(struct socket *sock,
 	 */
 
 	if (saddr == NULL) {
-		ifindex	= po->ifindex;
+		dev = po->prot_hook.dev;
 		proto	= po->num;
 		addr	= NULL;
 	} else {
@@ -1149,13 +1460,12 @@ static int packet_snd(struct socket *sock,
 			goto out;
 		if (msg->msg_namelen < (saddr->sll_halen + offsetof(struct sockaddr_ll, sll_addr)))
 			goto out;
-		ifindex	= saddr->sll_ifindex;
 		proto	= saddr->sll_protocol;
 		addr	= saddr->sll_addr;
+		dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex);
+		need_rls_dev = true;
 	}
 
-
-	dev = dev_get_by_index(sock_net(sk), ifindex);
 	err = -ENXIO;
 	if (dev == NULL)
 		goto out_unlock;
@@ -1286,14 +1596,15 @@ static int packet_snd(struct socket *sock,
 	if (err > 0 && (err = net_xmit_errno(err)) != 0)
 		goto out_unlock;
 
-	dev_put(dev);
+	if (need_rls_dev)
+		dev_put(dev);
 
 	return len;
 
 out_free:
 	kfree_skb(skb);
 out_unlock:
-	if (dev)
+	if (dev && need_rls_dev)
 		dev_put(dev);
 out:
 	return err;
@@ -1334,14 +1645,10 @@ static int packet_release(struct socket *sock)
 	spin_unlock_bh(&net->packet.sklist_lock);
 
 	spin_lock(&po->bind_lock);
-	if (po->running) {
-		/*
-		 * Remove from protocol table
-		 */
-		po->running = 0;
-		po->num = 0;
-		__dev_remove_pack(&po->prot_hook);
-		__sock_put(sk);
+	unregister_prot_hook(sk, false);
+	if (po->prot_hook.dev) {
+		dev_put(po->prot_hook.dev);
+		po->prot_hook.dev = NULL;
 	}
 	spin_unlock(&po->bind_lock);
 
@@ -1355,6 +1662,8 @@ static int packet_release(struct socket *sock)
 	if (po->tx_ring.pg_vec)
 		packet_set_ring(sk, &req, 1, 1);
 
+	fanout_release(sk);
+
 	synchronize_net();
 	/*
 	 *	Now the socket is dead. No more input will appear.
@@ -1378,24 +1687,18 @@ static int packet_release(struct socket *sock)
 static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protocol)
 {
 	struct packet_sock *po = pkt_sk(sk);
-	/*
-	 *	Detach an existing hook if present.
-	 */
+
+	if (po->fanout)
+		return -EINVAL;
 
 	lock_sock(sk);
 
 	spin_lock(&po->bind_lock);
-	if (po->running) {
-		__sock_put(sk);
-		po->running = 0;
-		po->num = 0;
-		spin_unlock(&po->bind_lock);
-		dev_remove_pack(&po->prot_hook);
-		spin_lock(&po->bind_lock);
-	}
-
+	unregister_prot_hook(sk, true);
 	po->num = protocol;
 	po->prot_hook.type = protocol;
+	if (po->prot_hook.dev)
+		dev_put(po->prot_hook.dev);
 	po->prot_hook.dev = dev;
 
 	po->ifindex = dev ? dev->ifindex : 0;
@@ -1404,9 +1707,7 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
 		goto out_unlock;
 
 	if (!dev || (dev->flags & IFF_UP)) {
-		dev_add_pack(&po->prot_hook);
-		sock_hold(sk);
-		po->running = 1;
+		register_prot_hook(sk);
 	} else {
 		sk->sk_err = ENETDOWN;
 		if (!sock_flag(sk, SOCK_DEAD))
@@ -1440,10 +1741,8 @@ static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
 	strlcpy(name, uaddr->sa_data, sizeof(name));
 
 	dev = dev_get_by_name(sock_net(sk), name);
-	if (dev) {
+	if (dev)
 		err = packet_do_bind(sk, dev, pkt_sk(sk)->num);
-		dev_put(dev);
-	}
 	return err;
 }
 
@@ -1471,8 +1770,6 @@ static int packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len
 			goto out;
 	}
 	err = packet_do_bind(sk, dev, sll->sll_protocol ? : pkt_sk(sk)->num);
-	if (dev)
-		dev_put(dev);
 
 out:
 	return err;
@@ -1537,9 +1834,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
 
 	if (proto) {
 		po->prot_hook.type = proto;
-		dev_add_pack(&po->prot_hook);
-		sock_hold(sk);
-		po->running = 1;
+		register_prot_hook(sk);
 	}
 
 	spin_lock_bh(&net->packet.sklist_lock);
@@ -1681,6 +1976,8 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
 			vnet_hdr.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;
 			vnet_hdr.csum_start = skb_checksum_start_offset(skb);
 			vnet_hdr.csum_offset = skb->csum_offset;
+		} else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
+			vnet_hdr.flags = VIRTIO_NET_HDR_F_DATA_VALID;
 		} /* else everything is zero */
 
 		err = memcpy_toiovec(msg->msg_iov, (void *)&vnet_hdr,
@@ -2102,6 +2399,17 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv
 		po->tp_tstamp = val;
 		return 0;
 	}
+	case PACKET_FANOUT:
+	{
+		int val;
+
+		if (optlen != sizeof(val))
+			return -EINVAL;
+		if (copy_from_user(&val, optval, sizeof(val)))
+			return -EFAULT;
+
+		return fanout_add(sk, val & 0xffff, val >> 16);
+	}
 	default:
 		return -ENOPROTOOPT;
 	}
@@ -2200,6 +2508,15 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
 		val = po->tp_tstamp;
 		data = &val;
 		break;
+	case PACKET_FANOUT:
+		if (len > sizeof(int))
+			len = sizeof(int);
+		val = (po->fanout ?
+		       ((u32)po->fanout->id |
+			((u32)po->fanout->type << 16)) :
+		       0);
+		data = &val;
+		break;
 	default:
 		return -ENOPROTOOPT;
 	}
@@ -2233,15 +2550,15 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
 			if (dev->ifindex == po->ifindex) {
 				spin_lock(&po->bind_lock);
 				if (po->running) {
-					__dev_remove_pack(&po->prot_hook);
-					__sock_put(sk);
-					po->running = 0;
+					__unregister_prot_hook(sk, false);
 					sk->sk_err = ENETDOWN;
 					if (!sock_flag(sk, SOCK_DEAD))
 						sk->sk_error_report(sk);
 				}
 				if (msg == NETDEV_UNREGISTER) {
 					po->ifindex = -1;
+					if (po->prot_hook.dev)
+						dev_put(po->prot_hook.dev);
 					po->prot_hook.dev = NULL;
 				}
 				spin_unlock(&po->bind_lock);
@@ -2250,11 +2567,8 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
 		case NETDEV_UP:
 			if (dev->ifindex == po->ifindex) {
 				spin_lock(&po->bind_lock);
-				if (po->num && !po->running) {
-					dev_add_pack(&po->prot_hook);
-					sock_hold(sk);
-					po->running = 1;
-				}
+				if (po->num)
+					register_prot_hook(sk);
 				spin_unlock(&po->bind_lock);
 			}
 			break;
@@ -2521,10 +2835,8 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,
 	was_running = po->running;
 	num = po->num;
 	if (was_running) {
-		__dev_remove_pack(&po->prot_hook);
 		po->num = 0;
-		po->running = 0;
-		__sock_put(sk);
+		__unregister_prot_hook(sk, false);
 	}
 	spin_unlock(&po->bind_lock);
 
@@ -2555,11 +2867,9 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,
 	mutex_unlock(&po->pg_vec_lock);
 
 	spin_lock(&po->bind_lock);
-	if (was_running && !po->running) {
-		sock_hold(sk);
-		po->running = 1;
+	if (was_running) {
 		po->num = num;
-		dev_add_pack(&po->prot_hook);
+		register_prot_hook(sk);
 	}
 	spin_unlock(&po->bind_lock);