summary refs log tree commit diff
path: root/drivers/infiniband/sw/rxe/rxe_net.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/infiniband/sw/rxe/rxe_net.c')
-rw-r--r--drivers/infiniband/sw/rxe/rxe_net.c56
1 files changed, 49 insertions, 7 deletions
diff --git a/drivers/infiniband/sw/rxe/rxe_net.c b/drivers/infiniband/sw/rxe/rxe_net.c
index 159246b03867..9da6e37fb70c 100644
--- a/drivers/infiniband/sw/rxe/rxe_net.c
+++ b/drivers/infiniband/sw/rxe/rxe_net.c
@@ -182,11 +182,39 @@ static struct dst_entry *rxe_find_route6(struct net_device *ndev,
 
 #endif
 
+/*
+ * Derive the net_device from the av.
+ * For physical devices, this will just return rxe->ndev.
+ * But for VLAN devices, it will return the vlan dev.
+ * Caller should dev_put() the returned net_device.
+ */
+static struct net_device *rxe_netdev_from_av(struct rxe_dev *rxe,
+					     int port_num,
+					     struct rxe_av *av)
+{
+	union ib_gid gid;
+	struct ib_gid_attr attr;
+	struct net_device *ndev = rxe->ndev;
+
+	if (ib_get_cached_gid(&rxe->ib_dev, port_num, av->grh.sgid_index,
+			      &gid, &attr) == 0 &&
+	    attr.ndev && attr.ndev != ndev)
+		ndev = attr.ndev;
+	else
+		/* Only to ensure that caller may call dev_put() */
+		dev_hold(ndev);
+
+	return ndev;
+}
+
 static struct dst_entry *rxe_find_route(struct rxe_dev *rxe,
 					struct rxe_qp *qp,
 					struct rxe_av *av)
 {
 	struct dst_entry *dst = NULL;
+	struct net_device *ndev;
+
+	ndev = rxe_netdev_from_av(rxe, qp->attr.port_num, av);
 
 	if (qp_type(qp) == IB_QPT_RC)
 		dst = sk_dst_get(qp->sk->sk);
@@ -201,14 +229,14 @@ static struct dst_entry *rxe_find_route(struct rxe_dev *rxe,
 
 			saddr = &av->sgid_addr._sockaddr_in.sin_addr;
 			daddr = &av->dgid_addr._sockaddr_in.sin_addr;
-			dst = rxe_find_route4(rxe->ndev, saddr, daddr);
+			dst = rxe_find_route4(ndev, saddr, daddr);
 		} else if (av->network_type == RDMA_NETWORK_IPV6) {
 			struct in6_addr *saddr6;
 			struct in6_addr *daddr6;
 
 			saddr6 = &av->sgid_addr._sockaddr_in6.sin6_addr;
 			daddr6 = &av->dgid_addr._sockaddr_in6.sin6_addr;
-			dst = rxe_find_route6(rxe->ndev, saddr6, daddr6);
+			dst = rxe_find_route6(ndev, saddr6, daddr6);
 #if IS_ENABLED(CONFIG_IPV6)
 			if (dst)
 				qp->dst_cookie =
@@ -217,6 +245,7 @@ static struct dst_entry *rxe_find_route(struct rxe_dev *rxe,
 		}
 	}
 
+	dev_put(ndev);
 	return dst;
 }
 
@@ -224,9 +253,14 @@ static int rxe_udp_encap_recv(struct sock *sk, struct sk_buff *skb)
 {
 	struct udphdr *udph;
 	struct net_device *ndev = skb->dev;
+	struct net_device *rdev = ndev;
 	struct rxe_dev *rxe = net_to_rxe(ndev);
 	struct rxe_pkt_info *pkt = SKB_TO_PKT(skb);
 
+	if (!rxe && is_vlan_dev(rdev)) {
+		rdev = vlan_dev_real_dev(ndev);
+		rxe = net_to_rxe(rdev);
+	}
 	if (!rxe)
 		goto drop;
 
@@ -450,7 +484,7 @@ static void rxe_skb_tx_dtor(struct sk_buff *skb)
 	rxe_drop_ref(qp);
 }
 
-int rxe_send(struct rxe_dev *rxe, struct rxe_pkt_info *pkt, struct sk_buff *skb)
+int rxe_send(struct rxe_pkt_info *pkt, struct sk_buff *skb)
 {
 	struct rxe_av *av;
 	int err;
@@ -498,6 +532,10 @@ struct sk_buff *rxe_init_packet(struct rxe_dev *rxe, struct rxe_av *av,
 {
 	unsigned int hdr_len;
 	struct sk_buff *skb;
+	struct net_device *ndev;
+	const int port_num = 1;
+
+	ndev = rxe_netdev_from_av(rxe, port_num, av);
 
 	if (av->network_type == RDMA_NETWORK_IPV4)
 		hdr_len = ETH_HLEN + sizeof(struct udphdr) +
@@ -506,26 +544,30 @@ struct sk_buff *rxe_init_packet(struct rxe_dev *rxe, struct rxe_av *av,
 		hdr_len = ETH_HLEN + sizeof(struct udphdr) +
 			sizeof(struct ipv6hdr);
 
-	skb = alloc_skb(paylen + hdr_len + LL_RESERVED_SPACE(rxe->ndev),
+	skb = alloc_skb(paylen + hdr_len + LL_RESERVED_SPACE(ndev),
 			GFP_ATOMIC);
-	if (unlikely(!skb))
+
+	if (unlikely(!skb)) {
+		dev_put(ndev);
 		return NULL;
+	}
 
 	skb_reserve(skb, hdr_len + LL_RESERVED_SPACE(rxe->ndev));
 
-	skb->dev	= rxe->ndev;
+	skb->dev	= ndev;
 	if (av->network_type == RDMA_NETWORK_IPV4)
 		skb->protocol = htons(ETH_P_IP);
 	else
 		skb->protocol = htons(ETH_P_IPV6);
 
 	pkt->rxe	= rxe;
-	pkt->port_num	= 1;
+	pkt->port_num	= port_num;
 	pkt->hdr	= skb_put(skb, paylen);
 	pkt->mask	|= RXE_GRH_MASK;
 
 	memset(pkt->hdr, 0, paylen);
 
+	dev_put(ndev);
 	return skb;
 }