summary refs log tree commit diff
path: root/net/vmw_vsock
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock')
-rw-r--r--net/vmw_vsock/hyperv_transport.c21
1 files changed, 17 insertions, 4 deletions
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index bbac023e70d1..5583df708b8c 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -310,11 +310,15 @@ static void hvs_close_connection(struct vmbus_channel *chan)
 	struct sock *sk = get_per_channel_state(chan);
 	struct vsock_sock *vsk = vsock_sk(sk);
 
+	lock_sock(sk);
+
 	sk->sk_state = TCP_CLOSE;
 	sock_set_flag(sk, SOCK_DONE);
 	vsk->peer_shutdown |= SEND_SHUTDOWN | RCV_SHUTDOWN;
 
 	sk->sk_state_change(sk);
+
+	release_sock(sk);
 }
 
 static void hvs_open_connection(struct vmbus_channel *chan)
@@ -344,6 +348,7 @@ static void hvs_open_connection(struct vmbus_channel *chan)
 	if (!sk)
 		return;
 
+	lock_sock(sk);
 	if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
 	    (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
 		goto out;
@@ -395,9 +400,7 @@ static void hvs_open_connection(struct vmbus_channel *chan)
 
 		vsock_insert_connected(vnew);
 
-		lock_sock(sk);
 		vsock_enqueue_accept(sk, new);
-		release_sock(sk);
 	} else {
 		sk->sk_state = TCP_ESTABLISHED;
 		sk->sk_socket->state = SS_CONNECTED;
@@ -410,6 +413,8 @@ static void hvs_open_connection(struct vmbus_channel *chan)
 out:
 	/* Release refcnt obtained when we called vsock_find_bound_socket() */
 	sock_put(sk);
+
+	release_sock(sk);
 }
 
 static u32 hvs_get_local_cid(void)
@@ -476,13 +481,21 @@ out:
 
 static void hvs_release(struct vsock_sock *vsk)
 {
+	struct sock *sk = sk_vsock(vsk);
 	struct hvsock *hvs = vsk->trans;
-	struct vmbus_channel *chan = hvs->chan;
+	struct vmbus_channel *chan;
+
+	lock_sock(sk);
+
+	sk->sk_state = SS_DISCONNECTING;
+	vsock_remove_sock(vsk);
+
+	release_sock(sk);
 
+	chan = hvs->chan;
 	if (chan)
 		hvs_shutdown(vsk, RCV_SHUTDOWN | SEND_SHUTDOWN);
 
-	vsock_remove_sock(vsk);
 }
 
 static void hvs_destruct(struct vsock_sock *vsk)