summary refs log tree commit diff
path: root/net/ipv4/fou.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/ipv4/fou.c')
-rw-r--r--net/ipv4/fou.c144
1 files changed, 88 insertions, 56 deletions
diff --git a/net/ipv4/fou.c b/net/ipv4/fou.c
index eeec7d60e5fd..5f9207c039e7 100644
--- a/net/ipv4/fou.c
+++ b/net/ipv4/fou.c
@@ -21,6 +21,7 @@ struct fou {
 	u8 protocol;
 	u8 flags;
 	__be16 port;
+	u8 family;
 	u16 type;
 	struct list_head list;
 	struct rcu_head rcu;
@@ -47,14 +48,17 @@ static inline struct fou *fou_from_sock(struct sock *sk)
 	return sk->sk_user_data;
 }
 
-static int fou_recv_pull(struct sk_buff *skb, size_t len)
+static int fou_recv_pull(struct sk_buff *skb, struct fou *fou, size_t len)
 {
-	struct iphdr *iph = ip_hdr(skb);
-
 	/* Remove 'len' bytes from the packet (UDP header and
 	 * FOU header if present).
 	 */
-	iph->tot_len = htons(ntohs(iph->tot_len) - len);
+	if (fou->family == AF_INET)
+		ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
+	else
+		ipv6_hdr(skb)->payload_len =
+		    htons(ntohs(ipv6_hdr(skb)->payload_len) - len);
+
 	__skb_pull(skb, len);
 	skb_postpull_rcsum(skb, udp_hdr(skb), len);
 	skb_reset_transport_header(skb);
@@ -68,7 +72,7 @@ static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
 	if (!fou)
 		return 1;
 
-	if (fou_recv_pull(skb, sizeof(struct udphdr)))
+	if (fou_recv_pull(skb, fou, sizeof(struct udphdr)))
 		goto drop;
 
 	return -fou->protocol;
@@ -141,7 +145,11 @@ static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
 
 	hdrlen = sizeof(struct guehdr) + optlen;
 
-	ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
+	if (fou->family == AF_INET)
+		ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
+	else
+		ipv6_hdr(skb)->payload_len =
+		    htons(ntohs(ipv6_hdr(skb)->payload_len) - len);
 
 	/* Pull csum through the guehdr now . This can be used if
 	 * there is a remote checksum offload.
@@ -426,7 +434,8 @@ static int fou_add_to_port_list(struct net *net, struct fou *fou)
 
 	mutex_lock(&fn->fou_lock);
 	list_for_each_entry(fout, &fn->fou_list, list) {
-		if (fou->port == fout->port) {
+		if (fou->port == fout->port &&
+		    fou->family == fout->family) {
 			mutex_unlock(&fn->fou_lock);
 			return -EALREADY;
 		}
@@ -448,31 +457,13 @@ static void fou_release(struct fou *fou)
 	kfree_rcu(fou, rcu);
 }
 
-static int fou_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
-{
-	udp_sk(sk)->encap_rcv = fou_udp_recv;
-	udp_sk(sk)->gro_receive = fou_gro_receive;
-	udp_sk(sk)->gro_complete = fou_gro_complete;
-	fou_from_sock(sk)->protocol = cfg->protocol;
-
-	return 0;
-}
-
-static int gue_encap_init(struct sock *sk, struct fou *fou, struct fou_cfg *cfg)
-{
-	udp_sk(sk)->encap_rcv = gue_udp_recv;
-	udp_sk(sk)->gro_receive = gue_gro_receive;
-	udp_sk(sk)->gro_complete = gue_gro_complete;
-
-	return 0;
-}
-
 static int fou_create(struct net *net, struct fou_cfg *cfg,
 		      struct socket **sockp)
 {
 	struct socket *sock = NULL;
 	struct fou *fou = NULL;
 	struct sock *sk;
+	struct udp_tunnel_sock_cfg tunnel_cfg;
 	int err;
 
 	/* Open UDP socket */
@@ -489,35 +480,36 @@ static int fou_create(struct net *net, struct fou_cfg *cfg,
 
 	sk = sock->sk;
 
-	fou->flags = cfg->flags;
 	fou->port = cfg->udp_config.local_udp_port;
+	fou->family = cfg->udp_config.family;
+	fou->flags = cfg->flags;
+	fou->type = cfg->type;
+	fou->sock = sock;
+
+	memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
+	tunnel_cfg.encap_type = 1;
+	tunnel_cfg.sk_user_data = fou;
+	tunnel_cfg.encap_destroy = NULL;
 
 	/* Initial for fou type */
 	switch (cfg->type) {
 	case FOU_ENCAP_DIRECT:
-		err = fou_encap_init(sk, fou, cfg);
-		if (err)
-			goto error;
+		tunnel_cfg.encap_rcv = fou_udp_recv;
+		tunnel_cfg.gro_receive = fou_gro_receive;
+		tunnel_cfg.gro_complete = fou_gro_complete;
+		fou->protocol = cfg->protocol;
 		break;
 	case FOU_ENCAP_GUE:
-		err = gue_encap_init(sk, fou, cfg);
-		if (err)
-			goto error;
+		tunnel_cfg.encap_rcv = gue_udp_recv;
+		tunnel_cfg.gro_receive = gue_gro_receive;
+		tunnel_cfg.gro_complete = gue_gro_complete;
 		break;
 	default:
 		err = -EINVAL;
 		goto error;
 	}
 
-	fou->type = cfg->type;
-
-	udp_sk(sk)->encap_type = 1;
-	udp_encap_enable();
-
-	sk->sk_user_data = fou;
-	fou->sock = sock;
-
-	inet_inc_convert_csum(sk);
+	setup_udp_tunnel_sock(net, sock, &tunnel_cfg);
 
 	sk->sk_allocation = GFP_ATOMIC;
 
@@ -542,12 +534,13 @@ static int fou_destroy(struct net *net, struct fou_cfg *cfg)
 {
 	struct fou_net *fn = net_generic(net, fou_net_id);
 	__be16 port = cfg->udp_config.local_udp_port;
+	u8 family = cfg->udp_config.family;
 	int err = -EINVAL;
 	struct fou *fou;
 
 	mutex_lock(&fn->fou_lock);
 	list_for_each_entry(fou, &fn->fou_list, list) {
-		if (fou->port == port) {
+		if (fou->port == port && fou->family == family) {
 			fou_release(fou);
 			err = 0;
 			break;
@@ -585,8 +578,15 @@ static int parse_nl_config(struct genl_info *info,
 	if (info->attrs[FOU_ATTR_AF]) {
 		u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
 
-		if (family != AF_INET)
-			return -EINVAL;
+		switch (family) {
+		case AF_INET:
+			break;
+		case AF_INET6:
+			cfg->udp_config.ipv6_v6only = 1;
+			break;
+		default:
+			return -EAFNOSUPPORT;
+		}
 
 		cfg->udp_config.family = family;
 	}
@@ -677,6 +677,7 @@ static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
 	struct fou_cfg cfg;
 	struct fou *fout;
 	__be16 port;
+	u8 family;
 	int ret;
 
 	ret = parse_nl_config(info, &cfg);
@@ -686,6 +687,10 @@ static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
 	if (port == 0)
 		return -EINVAL;
 
+	family = cfg.udp_config.family;
+	if (family != AF_INET && family != AF_INET6)
+		return -EINVAL;
+
 	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
 	if (!msg)
 		return -ENOMEM;
@@ -693,7 +698,7 @@ static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
 	ret = -ESRCH;
 	mutex_lock(&fn->fou_lock);
 	list_for_each_entry(fout, &fn->fou_list, list) {
-		if (port == fout->port) {
+		if (port == fout->port && family == fout->family) {
 			ret = fou_dump_info(fout, info->snd_portid,
 					    info->snd_seq, 0, msg,
 					    info->genlhdr->cmd);
@@ -798,6 +803,22 @@ static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
 	*protocol = IPPROTO_UDP;
 }
 
+int __fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
+		       u8 *protocol, __be16 *sport, int type)
+{
+	int err;
+
+	err = iptunnel_handle_offloads(skb, type);
+	if (err)
+		return err;
+
+	*sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
+						skb, 0, 0, false);
+
+	return 0;
+}
+EXPORT_SYMBOL(__fou_build_header);
+
 int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
 		     u8 *protocol, struct flowi4 *fl4)
 {
@@ -806,26 +827,21 @@ int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
 	__be16 sport;
 	int err;
 
-	err = iptunnel_handle_offloads(skb, type);
+	err = __fou_build_header(skb, e, protocol, &sport, type);
 	if (err)
 		return err;
 
-	sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
-					       skb, 0, 0, false);
 	fou_build_udp(skb, e, fl4, protocol, sport);
 
 	return 0;
 }
 EXPORT_SYMBOL(fou_build_header);
 
-int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
-		     u8 *protocol, struct flowi4 *fl4)
+int __gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
+		       u8 *protocol, __be16 *sport, int type)
 {
-	int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
-						       SKB_GSO_UDP_TUNNEL;
 	struct guehdr *guehdr;
 	size_t hdrlen, optlen = 0;
-	__be16 sport;
 	void *data;
 	bool need_priv = false;
 	int err;
@@ -844,8 +860,8 @@ int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
 		return err;
 
 	/* Get source port (based on flow hash) before skb_push */
-	sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
-					       skb, 0, 0, false);
+	*sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
+						skb, 0, 0, false);
 
 	hdrlen = sizeof(struct guehdr) + optlen;
 
@@ -890,6 +906,22 @@ int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
 
 	}
 
+	return 0;
+}
+EXPORT_SYMBOL(__gue_build_header);
+
+int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
+		     u8 *protocol, struct flowi4 *fl4)
+{
+	int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
+						       SKB_GSO_UDP_TUNNEL;
+	__be16 sport;
+	int err;
+
+	err = __gue_build_header(skb, e, protocol, &sport, type);
+	if (err)
+		return err;
+
 	fou_build_udp(skb, e, fl4, protocol, sport);
 
 	return 0;