summary refs log tree commit diff
path: root/drivers/usb/usbip/stub_dev.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/usb/usbip/stub_dev.c')
-rw-r--r--drivers/usb/usbip/stub_dev.c42
1 files changed, 35 insertions, 7 deletions
diff --git a/drivers/usb/usbip/stub_dev.c b/drivers/usb/usbip/stub_dev.c
index 2305d425e6c9..8f1de1fbbeed 100644
--- a/drivers/usb/usbip/stub_dev.c
+++ b/drivers/usb/usbip/stub_dev.c
@@ -46,6 +46,8 @@ static ssize_t usbip_sockfd_store(struct device *dev, struct device_attribute *a
 	int sockfd = 0;
 	struct socket *socket;
 	int rv;
+	struct task_struct *tcp_rx = NULL;
+	struct task_struct *tcp_tx = NULL;
 
 	if (!sdev) {
 		dev_err(dev, "sdev is null\n");
@@ -69,23 +71,47 @@ static ssize_t usbip_sockfd_store(struct device *dev, struct device_attribute *a
 		}
 
 		socket = sockfd_lookup(sockfd, &err);
-		if (!socket)
+		if (!socket) {
+			dev_err(dev, "failed to lookup sock");
 			goto err;
+		}
 
-		sdev->ud.tcp_socket = socket;
-		sdev->ud.sockfd = sockfd;
+		if (socket->type != SOCK_STREAM) {
+			dev_err(dev, "Expecting SOCK_STREAM - found %d",
+				socket->type);
+			goto sock_err;
+		}
 
+		/* unlock and create threads and get tasks */
 		spin_unlock_irq(&sdev->ud.lock);
+		tcp_rx = kthread_create(stub_rx_loop, &sdev->ud, "stub_rx");
+		if (IS_ERR(tcp_rx)) {
+			sockfd_put(socket);
+			return -EINVAL;
+		}
+		tcp_tx = kthread_create(stub_tx_loop, &sdev->ud, "stub_tx");
+		if (IS_ERR(tcp_tx)) {
+			kthread_stop(tcp_rx);
+			sockfd_put(socket);
+			return -EINVAL;
+		}
 
-		sdev->ud.tcp_rx = kthread_get_run(stub_rx_loop, &sdev->ud,
-						  "stub_rx");
-		sdev->ud.tcp_tx = kthread_get_run(stub_tx_loop, &sdev->ud,
-						  "stub_tx");
+		/* get task structs now */
+		get_task_struct(tcp_rx);
+		get_task_struct(tcp_tx);
 
+		/* lock and update sdev->ud state */
 		spin_lock_irq(&sdev->ud.lock);
+		sdev->ud.tcp_socket = socket;
+		sdev->ud.sockfd = sockfd;
+		sdev->ud.tcp_rx = tcp_rx;
+		sdev->ud.tcp_tx = tcp_tx;
 		sdev->ud.status = SDEV_ST_USED;
 		spin_unlock_irq(&sdev->ud.lock);
 
+		wake_up_process(sdev->ud.tcp_rx);
+		wake_up_process(sdev->ud.tcp_tx);
+
 	} else {
 		dev_info(dev, "stub down\n");
 
@@ -100,6 +126,8 @@ static ssize_t usbip_sockfd_store(struct device *dev, struct device_attribute *a
 
 	return count;
 
+sock_err:
+	sockfd_put(socket);
 err:
 	spin_unlock_irq(&sdev->ud.lock);
 	return -EINVAL;