summary refs log tree commit diff
path: root/drivers/infiniband/sw/rxe/rxe_verbs.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/infiniband/sw/rxe/rxe_verbs.c')
-rw-r--r--drivers/infiniband/sw/rxe/rxe_verbs.c101
1 files changed, 50 insertions, 51 deletions
diff --git a/drivers/infiniband/sw/rxe/rxe_verbs.c b/drivers/infiniband/sw/rxe/rxe_verbs.c
index f4bab2cd0ec2..2cb52fd48cf1 100644
--- a/drivers/infiniband/sw/rxe/rxe_verbs.c
+++ b/drivers/infiniband/sw/rxe/rxe_verbs.c
@@ -77,40 +77,6 @@ out:
 	return rc;
 }
 
-static int rxe_query_gid(struct ib_device *device,
-			 u8 port_num, int index, union ib_gid *gid)
-{
-	int ret;
-
-	if (index > RXE_PORT_GID_TBL_LEN)
-		return -EINVAL;
-
-	ret = ib_get_cached_gid(device, port_num, index, gid, NULL);
-	if (ret == -EAGAIN) {
-		memcpy(gid, &zgid, sizeof(*gid));
-		return 0;
-	}
-
-	return ret;
-}
-
-static int rxe_add_gid(struct ib_device *device, u8 port_num, unsigned int
-		       index, const union ib_gid *gid,
-		       const struct ib_gid_attr *attr, void **context)
-{
-	if (index >= RXE_PORT_GID_TBL_LEN)
-		return -EINVAL;
-	return 0;
-}
-
-static int rxe_del_gid(struct ib_device *device, u8 port_num, unsigned int
-		       index, void **context)
-{
-	if (index >= RXE_PORT_GID_TBL_LEN)
-		return -EINVAL;
-	return 0;
-}
-
 static struct net_device *rxe_get_netdev(struct ib_device *device,
 					 u8 port_num)
 {
@@ -273,9 +239,7 @@ static int rxe_init_av(struct rxe_dev *rxe, struct rdma_ah_attr *attr,
 
 	rxe_av_from_attr(rdma_ah_get_port_num(attr), av, attr);
 	rxe_av_fill_ip_info(av, attr, &sgid_attr, &sgid);
-
-	if (sgid_attr.ndev)
-		dev_put(sgid_attr.ndev);
+	dev_put(sgid_attr.ndev);
 	return 0;
 }
 
@@ -407,6 +371,13 @@ static struct ib_srq *rxe_create_srq(struct ib_pd *ibpd,
 	struct rxe_pd *pd = to_rpd(ibpd);
 	struct rxe_srq *srq;
 	struct ib_ucontext *context = udata ? ibpd->uobject->context : NULL;
+	struct rxe_create_srq_resp __user *uresp = NULL;
+
+	if (udata) {
+		if (udata->outlen < sizeof(*uresp))
+			return ERR_PTR(-EINVAL);
+		uresp = udata->outbuf;
+	}
 
 	err = rxe_srq_chk_attr(rxe, NULL, &init->attr, IB_SRQ_INIT_MASK);
 	if (err)
@@ -422,7 +393,7 @@ static struct ib_srq *rxe_create_srq(struct ib_pd *ibpd,
 	rxe_add_ref(pd);
 	srq->pd = pd;
 
-	err = rxe_srq_from_init(rxe, srq, init, context, udata);
+	err = rxe_srq_from_init(rxe, srq, init, context, uresp);
 	if (err)
 		goto err2;
 
@@ -443,12 +414,22 @@ static int rxe_modify_srq(struct ib_srq *ibsrq, struct ib_srq_attr *attr,
 	int err;
 	struct rxe_srq *srq = to_rsrq(ibsrq);
 	struct rxe_dev *rxe = to_rdev(ibsrq->device);
+	struct rxe_modify_srq_cmd ucmd = {};
+
+	if (udata) {
+		if (udata->inlen < sizeof(ucmd))
+			return -EINVAL;
+
+		err = ib_copy_from_udata(&ucmd, udata, sizeof(ucmd));
+		if (err)
+			return err;
+	}
 
 	err = rxe_srq_chk_attr(rxe, srq, attr, mask);
 	if (err)
 		goto err1;
 
-	err = rxe_srq_from_attr(rxe, srq, attr, mask, udata);
+	err = rxe_srq_from_attr(rxe, srq, attr, mask, &ucmd);
 	if (err)
 		goto err1;
 
@@ -517,6 +498,13 @@ static struct ib_qp *rxe_create_qp(struct ib_pd *ibpd,
 	struct rxe_dev *rxe = to_rdev(ibpd->device);
 	struct rxe_pd *pd = to_rpd(ibpd);
 	struct rxe_qp *qp;
+	struct rxe_create_qp_resp __user *uresp = NULL;
+
+	if (udata) {
+		if (udata->outlen < sizeof(*uresp))
+			return ERR_PTR(-EINVAL);
+		uresp = udata->outbuf;
+	}
 
 	err = rxe_qp_chk_init(rxe, init);
 	if (err)
@@ -538,7 +526,7 @@ static struct ib_qp *rxe_create_qp(struct ib_pd *ibpd,
 
 	rxe_add_index(qp);
 
-	err = rxe_qp_from_init(rxe, qp, pd, init, udata, ibpd);
+	err = rxe_qp_from_init(rxe, qp, pd, init, uresp, ibpd);
 	if (err)
 		goto err3;
 
@@ -711,9 +699,8 @@ static int init_send_wqe(struct rxe_qp *qp, struct ib_send_wr *ibwr,
 		memcpy(wqe->dma.sge, ibwr->sg_list,
 		       num_sge * sizeof(struct ib_sge));
 
-	wqe->iova		= (mask & WR_ATOMIC_MASK) ?
-					atomic_wr(ibwr)->remote_addr :
-					rdma_wr(ibwr)->remote_addr;
+	wqe->iova = mask & WR_ATOMIC_MASK ? atomic_wr(ibwr)->remote_addr :
+		mask & WR_READ_OR_WRITE_MASK ? rdma_wr(ibwr)->remote_addr : 0;
 	wqe->mask		= mask;
 	wqe->dma.length		= length;
 	wqe->dma.resid		= length;
@@ -889,11 +876,18 @@ static struct ib_cq *rxe_create_cq(struct ib_device *dev,
 	int err;
 	struct rxe_dev *rxe = to_rdev(dev);
 	struct rxe_cq *cq;
+	struct rxe_create_cq_resp __user *uresp = NULL;
+
+	if (udata) {
+		if (udata->outlen < sizeof(*uresp))
+			return ERR_PTR(-EINVAL);
+		uresp = udata->outbuf;
+	}
 
 	if (attr->flags)
 		return ERR_PTR(-EINVAL);
 
-	err = rxe_cq_chk_attr(rxe, NULL, attr->cqe, attr->comp_vector, udata);
+	err = rxe_cq_chk_attr(rxe, NULL, attr->cqe, attr->comp_vector);
 	if (err)
 		goto err1;
 
@@ -904,7 +898,7 @@ static struct ib_cq *rxe_create_cq(struct ib_device *dev,
 	}
 
 	err = rxe_cq_from_init(rxe, cq, attr->cqe, attr->comp_vector,
-			       context, udata);
+			       context, uresp);
 	if (err)
 		goto err2;
 
@@ -931,12 +925,19 @@ static int rxe_resize_cq(struct ib_cq *ibcq, int cqe, struct ib_udata *udata)
 	int err;
 	struct rxe_cq *cq = to_rcq(ibcq);
 	struct rxe_dev *rxe = to_rdev(ibcq->device);
+	struct rxe_resize_cq_resp __user *uresp = NULL;
+
+	if (udata) {
+		if (udata->outlen < sizeof(*uresp))
+			return -EINVAL;
+		uresp = udata->outbuf;
+	}
 
-	err = rxe_cq_chk_attr(rxe, cq, cqe, 0, udata);
+	err = rxe_cq_chk_attr(rxe, cq, cqe, 0);
 	if (err)
 		goto err1;
 
-	err = rxe_cq_resize_queue(cq, cqe, udata);
+	err = rxe_cq_resize_queue(cq, cqe, uresp);
 	if (err)
 		goto err1;
 
@@ -1207,7 +1208,7 @@ int rxe_register_device(struct rxe_dev *rxe)
 			    rxe->ndev->dev_addr);
 	dev->dev.dma_ops = &dma_virt_ops;
 	dma_coerce_mask_and_coherent(&dev->dev,
-				     dma_get_required_mask(dev->dev.parent));
+				     dma_get_required_mask(&dev->dev));
 
 	dev->uverbs_abi_ver = RXE_UVERBS_ABI_VERSION;
 	dev->uverbs_cmd_mask = BIT_ULL(IB_USER_VERBS_CMD_GET_CONTEXT)
@@ -1248,10 +1249,7 @@ int rxe_register_device(struct rxe_dev *rxe)
 	dev->query_port = rxe_query_port;
 	dev->modify_port = rxe_modify_port;
 	dev->get_link_layer = rxe_get_link_layer;
-	dev->query_gid = rxe_query_gid;
 	dev->get_netdev = rxe_get_netdev;
-	dev->add_gid = rxe_add_gid;
-	dev->del_gid = rxe_del_gid;
 	dev->query_pkey = rxe_query_pkey;
 	dev->alloc_ucontext = rxe_alloc_ucontext;
 	dev->dealloc_ucontext = rxe_dealloc_ucontext;
@@ -1298,6 +1296,7 @@ int rxe_register_device(struct rxe_dev *rxe)
 	}
 	rxe->tfm = tfm;
 
+	dev->driver_id = RDMA_DRIVER_RXE;
 	err = ib_register_device(dev, NULL);
 	if (err) {
 		pr_warn("%s failed with error %d\n", __func__, err);