summary refs log tree commit diff
path: root/net/tls/tls_main.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--net/tls/tls_main.c33
1 files changed, 28 insertions, 5 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 48f1c26459d0..f208f8455ef2 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -328,7 +328,10 @@ static void tls_sk_proto_unhash(struct sock *sk)
 
 	ctx = tls_get_ctx(sk);
 	tls_sk_proto_cleanup(sk, ctx, timeo);
+	write_lock_bh(&sk->sk_callback_lock);
 	icsk->icsk_ulp_data = NULL;
+	sk->sk_prot = ctx->sk_proto;
+	write_unlock_bh(&sk->sk_callback_lock);
 
 	if (ctx->sk_proto->unhash)
 		ctx->sk_proto->unhash(sk);
@@ -337,7 +340,7 @@ static void tls_sk_proto_unhash(struct sock *sk)
 
 static void tls_sk_proto_close(struct sock *sk, long timeout)
 {
-	void (*sk_proto_close)(struct sock *sk, long timeout);
+	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct tls_context *ctx = tls_get_ctx(sk);
 	long timeo = sock_sndtimeo(sk, 0);
 	bool free_ctx;
@@ -347,12 +350,15 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 
 	lock_sock(sk);
 	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
-	sk_proto_close = ctx->sk_proto_close;
 
 	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
 		tls_sk_proto_cleanup(sk, ctx, timeo);
 
+	write_lock_bh(&sk->sk_callback_lock);
+	if (free_ctx)
+		icsk->icsk_ulp_data = NULL;
 	sk->sk_prot = ctx->sk_proto;
+	write_unlock_bh(&sk->sk_callback_lock);
 	release_sock(sk);
 	if (ctx->tx_conf == TLS_SW)
 		tls_sw_free_ctx_tx(ctx);
@@ -360,7 +366,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 		tls_sw_strparser_done(ctx);
 	if (ctx->rx_conf == TLS_SW)
 		tls_sw_free_ctx_rx(ctx);
-	sk_proto_close(sk, timeout);
+	ctx->sk_proto_close(sk, timeout);
 
 	if (free_ctx)
 		tls_ctx_free(ctx);
@@ -827,7 +833,7 @@ static int tls_init(struct sock *sk)
 	int rc = 0;
 
 	if (tls_hw_prot(sk))
-		goto out;
+		return 0;
 
 	/* The TLS ulp is currently supported only for TCP sockets
 	 * in ESTABLISHED state.
@@ -838,22 +844,38 @@ static int tls_init(struct sock *sk)
 	if (sk->sk_state != TCP_ESTABLISHED)
 		return -ENOTSUPP;
 
+	tls_build_proto(sk);
+
 	/* allocate tls context */
+	write_lock_bh(&sk->sk_callback_lock);
 	ctx = create_ctx(sk);
 	if (!ctx) {
 		rc = -ENOMEM;
 		goto out;
 	}
 
-	tls_build_proto(sk);
 	ctx->tx_conf = TLS_BASE;
 	ctx->rx_conf = TLS_BASE;
 	ctx->sk_proto = sk->sk_prot;
 	update_sk_prot(sk, ctx);
 out:
+	write_unlock_bh(&sk->sk_callback_lock);
 	return rc;
 }
 
+static void tls_update(struct sock *sk, struct proto *p)
+{
+	struct tls_context *ctx;
+
+	ctx = tls_get_ctx(sk);
+	if (likely(ctx)) {
+		ctx->sk_proto_close = p->close;
+		ctx->sk_proto = p;
+	} else {
+		sk->sk_prot = p;
+	}
+}
+
 void tls_register_device(struct tls_device *device)
 {
 	spin_lock_bh(&device_spinlock);
@@ -874,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 	.name			= "tls",
 	.owner			= THIS_MODULE,
 	.init			= tls_init,
+	.update			= tls_update,
 };
 
 static int __init tls_register(void)