summary refs log tree commit diff
path: root/net/vmw_vsock/af_vsock.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/af_vsock.c')
-rw-r--r--net/vmw_vsock/af_vsock.c19
1 files changed, 13 insertions, 6 deletions
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index bbe65dcb9738..7fd1220fbfa0 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1557,6 +1557,8 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 	if (err < 0)
 		goto out;
 
+	prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+
 	while (total_written < len) {
 		ssize_t written;
 
@@ -1576,9 +1578,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 				goto out_wait;
 
 			release_sock(sk);
-			prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
 			timeout = schedule_timeout(timeout);
-			finish_wait(sk_sleep(sk), &wait);
 			lock_sock(sk);
 			if (signal_pending(current)) {
 				err = sock_intr_errno(timeout);
@@ -1588,6 +1588,8 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 				goto out_wait;
 			}
 
+			prepare_to_wait(sk_sleep(sk), &wait,
+					TASK_INTERRUPTIBLE);
 		}
 
 		/* These checks occur both as part of and after the loop
@@ -1633,6 +1635,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 out_wait:
 	if (total_written > 0)
 		err = total_written;
+	finish_wait(sk_sleep(sk), &wait);
 out:
 	release_sock(sk);
 	return err;
@@ -1713,6 +1716,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	if (err < 0)
 		goto out;
 
+	prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
 
 	while (1) {
 		s64 ready = vsock_stream_has_data(vsk);
@@ -1723,7 +1727,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 			 */
 
 			err = -ENOMEM;
-			goto out;
+			goto out_wait;
 		} else if (ready > 0) {
 			ssize_t read;
 
@@ -1746,7 +1750,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 					vsk, target, read,
 					!(flags & MSG_PEEK), &recv_data);
 			if (err < 0)
-				goto out;
+				goto out_wait;
 
 			if (read >= target || flags & MSG_PEEK)
 				break;
@@ -1769,9 +1773,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 				break;
 
 			release_sock(sk);
-			prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
 			timeout = schedule_timeout(timeout);
-			finish_wait(sk_sleep(sk), &wait);
 			lock_sock(sk);
 
 			if (signal_pending(current)) {
@@ -1781,6 +1783,9 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 				err = -EAGAIN;
 				break;
 			}
+
+			prepare_to_wait(sk_sleep(sk), &wait,
+					TASK_INTERRUPTIBLE);
 		}
 	}
 
@@ -1811,6 +1816,8 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 		err = copied;
 	}
 
+out_wait:
+	finish_wait(sk_sleep(sk), &wait);
 out:
 	release_sock(sk);
 	return err;