Diffstat (limited to 'net/tls')
-rw-r--r--  net/tls/Kconfig    |   1
-rw-r--r--  net/tls/tls_main.c |  62
-rw-r--r--  net/tls/tls_sw.c   | 587
3 files changed, 582 insertions(+), 68 deletions(-)
diff --git a/net/tls/Kconfig b/net/tls/Kconfig
index eb583038c67e..89b8745a986f 100644
--- a/net/tls/Kconfig
+++ b/net/tls/Kconfig
@@ -7,6 +7,7 @@ config TLS
 	select CRYPTO
 	select CRYPTO_AES
 	select CRYPTO_GCM
+	select STREAM_PARSER
 	default n
 	---help---
 	Enable kernel support for TLS protocol. This allows symmetric
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index c405beeda765..6f5c1146da4a 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -54,12 +54,15 @@ enum {
 enum {
 	TLS_BASE,
 	TLS_SW_TX,
+	TLS_SW_RX,
+	TLS_SW_RXTX,
 	TLS_NUM_CONFIG,
 };
 
 static struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG];
+static struct proto_ops tls_sw_proto_ops;
 
 static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
@@ -261,9 +264,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 
 	kfree(ctx->tx.rec_seq);
 	kfree(ctx->tx.iv);
+	kfree(ctx->rx.rec_seq);
+	kfree(ctx->rx.iv);
 
-	if (ctx->conf == TLS_SW_TX)
-		tls_sw_free_tx_resources(sk);
+	if (ctx->conf == TLS_SW_TX ||
+	    ctx->conf == TLS_SW_RX ||
+	    ctx->conf == TLS_SW_RXTX) {
+		tls_sw_free_resources(sk);
+	}
 
 skip_tx_cleanup:
 	release_sock(sk);
@@ -365,8 +373,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
 	return do_tls_getsockopt(sk, optname, optval, optlen);
 }
 
-static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
-				unsigned int optlen)
+static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
+				  unsigned int optlen, int tx)
 {
 	struct tls_crypto_info *crypto_info;
 	struct tls_context *ctx = tls_get_ctx(sk);
@@ -378,7 +386,11 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
 		goto out;
 	}
 
-	crypto_info = &ctx->crypto_send;
+	if (tx)
+		crypto_info = &ctx->crypto_send;
+	else
+		crypto_info = &ctx->crypto_recv;
+
 	/* Setting crypto info more than once is currently not supported */
 	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
 		rc = -EBUSY;
@@ -417,15 +429,31 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval,
 	}
 
 	/* currently SW is the default; ethtool-based setup will come later */
-	rc = tls_set_sw_offload(sk, ctx);
-	conf = TLS_SW_TX;
+	if (tx) {
+		rc = tls_set_sw_offload(sk, ctx, 1);
+		if (ctx->conf == TLS_SW_RX)
+			conf = TLS_SW_RXTX;
+		else
+			conf = TLS_SW_TX;
+	} else {
+		rc = tls_set_sw_offload(sk, ctx, 0);
+		if (ctx->conf == TLS_SW_TX)
+			conf = TLS_SW_RXTX;
+		else
+			conf = TLS_SW_RX;
+	}
+
 	if (rc)
 		goto err_crypto_info;
 
 	ctx->conf = conf;
 	update_sk_prot(sk, ctx);
-	ctx->sk_write_space = sk->sk_write_space;
-	sk->sk_write_space = tls_write_space;
+	if (tx) {
+		ctx->sk_write_space = sk->sk_write_space;
+		sk->sk_write_space = tls_write_space;
+	} else {
+		sk->sk_socket->ops = &tls_sw_proto_ops;
+	}
 	goto out;
 
 err_crypto_info:
@@ -441,8 +469,10 @@ static int do_tls_setsockopt(struct sock *sk, int optname,
 
 	switch (optname) {
 	case TLS_TX:
+	case TLS_RX:
 		lock_sock(sk);
-		rc = do_tls_setsockopt_tx(sk, optval, optlen);
+		rc = do_tls_setsockopt_conf(sk, optval, optlen,
+					    optname == TLS_TX);
 		release_sock(sk);
 		break;
 	default:
@@ -473,6 +503,14 @@ static void build_protos(struct proto *prot, struct proto *base)
 	prot[TLS_SW_TX] = prot[TLS_BASE];
 	prot[TLS_SW_TX].sendmsg		= tls_sw_sendmsg;
 	prot[TLS_SW_TX].sendpage	= tls_sw_sendpage;
+
+	prot[TLS_SW_RX] = prot[TLS_BASE];
+	prot[TLS_SW_RX].recvmsg		= tls_sw_recvmsg;
+	prot[TLS_SW_RX].close		= tls_sk_proto_close;
+
+	prot[TLS_SW_RXTX] = prot[TLS_SW_TX];
+	prot[TLS_SW_RXTX].recvmsg	= tls_sw_recvmsg;
+	prot[TLS_SW_RXTX].close		= tls_sk_proto_close;
 }
 
 static int tls_init(struct sock *sk)
@@ -531,6 +569,10 @@ static int __init tls_register(void)
 {
 	build_protos(tls_prots[TLSV4], &tcp_prot);
 
+	tls_sw_proto_ops = inet_stream_ops;
+	tls_sw_proto_ops.poll = tls_sw_poll;
+	tls_sw_proto_ops.splice_read = tls_sw_splice_read;
+
 	tcp_register_ulp(&tcp_tls_ulp_ops);
 
 	return 0;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 1c79d9ad1731..4dc766b03f00 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -34,11 +34,60 @@
  * SOFTWARE.
  */
 
+#include <linux/sched/signal.h>
 #include <linux/module.h>
 #include <crypto/aead.h>
 
+#include <net/strparser.h>
 #include <net/tls.h>
 
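+/* Synchronously decrypt one record with the rx AEAD.  On success the
+ * strparser message is advanced past the TLS header and the rx record
+ * sequence number is bumped.
+ */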
+static int tls_do_decryption(struct sock *sk,
+			     struct scatterlist *sgin,
+			     struct scatterlist *sgout,
+			     char *iv_recv,
+			     size_t data_len,
+			     struct sk_buff *skb,
+			     gfp_t flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	struct strp_msg *rxm = strp_msg(skb);
+	struct aead_request *aead_req;
+
+	int ret;
+	unsigned int req_size = sizeof(struct aead_request) +
+		crypto_aead_reqsize(ctx->aead_recv);
+
+	aead_req = kzalloc(req_size, flags);
+	if (!aead_req)
+		return -ENOMEM;
+
+	aead_request_set_tfm(aead_req, ctx->aead_recv);
+	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+	aead_request_set_crypt(aead_req, sgin, sgout,
+			       data_len + tls_ctx->rx.tag_size,
+			       (u8 *)iv_recv);
+	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
+				  crypto_req_done, &ctx->async_wait);
+
+	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+
+	if (ret < 0)
+		goto out;
+
+	rxm->offset += tls_ctx->rx.prepend_size;
+	rxm->full_len -= tls_ctx->rx.overhead_size;
+	tls_advance_record_sn(sk, &tls_ctx->rx);
+
+	ctx->decrypted = true;
+
+	ctx->saved_data_ready(sk);
+
+out:
+	kfree(aead_req);
+	return ret;
+}
+
 static void trim_sg(struct sock *sk, struct scatterlist *sg,
 		    int *sg_num_elem, unsigned int *sg_size, int target_size)
 {
@@ -581,13 +630,404 @@ sendpage_end:
 	return ret;
 }
 
-void tls_sw_free_tx_resources(struct sock *sk)
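+/* Wait until the strparser has queued a complete record in ctx->recv_pkt,
+ * honouring MSG_DONTWAIT, socket errors and pending signals.
+ */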
+static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
+				     long timeo, int *err)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	struct sk_buff *skb;
+	DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+	while (!(skb = ctx->recv_pkt)) {
+		if (sk->sk_err) {
+			*err = sock_error(sk);
+			return NULL;
+		}
+
+		if (sock_flag(sk, SOCK_DONE))
+			return NULL;
+
+		if ((flags & MSG_DONTWAIT) || !timeo) {
+			*err = -EAGAIN;
+			return NULL;
+		}
+
+		add_wait_queue(sk_sleep(sk), &wait);
+		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
+		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		remove_wait_queue(sk_sleep(sk), &wait);
+
+		/* Handle signals */
+		if (signal_pending(current)) {
+			*err = sock_intr_errno(timeo);
+			return NULL;
+		}
+	}
+
+	return skb;
+}
+
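+/* Build a scatterlist covering the AAD scratch space and the record
+ * payload, then decrypt.  With sgout == NULL the record is decrypted in
+ * place inside the skb; otherwise it is decrypted into the caller's
+ * pages (the zerocopy receive path).
+ */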
+static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+		       struct scatterlist *sgout)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + tls_ctx->rx.iv_size];
+	struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
+	struct scatterlist *sgin = &sgin_arr[0];
+	struct strp_msg *rxm = strp_msg(skb);
+	int ret, nsg = ARRAY_SIZE(sgin_arr);
+	char aad_recv[TLS_AAD_SPACE_SIZE];
+	struct sk_buff *unused;
+
+	ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			    tls_ctx->rx.iv_size);
+	if (ret < 0)
+		return ret;
+
+	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	if (!sgout) {
+		nsg = skb_cow_data(skb, 0, &unused) + 1;
+		sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
+		if (!sgin)
+			return -ENOMEM;
+		sgout = sgin;
+	}
+
+	sg_init_table(sgin, nsg);
+	sg_set_buf(&sgin[0], aad_recv, sizeof(aad_recv));
+
+	nsg = skb_to_sgvec(skb, &sgin[1],
+			   rxm->offset + tls_ctx->rx.prepend_size,
+			   rxm->full_len - tls_ctx->rx.prepend_size);
+	if (nsg < 0) {
+		if (sgin != &sgin_arr[0])
+			kfree(sgin);
+		return nsg;
+	}
+
+	tls_make_aad(aad_recv,
+		     rxm->full_len - tls_ctx->rx.overhead_size,
+		     tls_ctx->rx.rec_seq,
+		     tls_ctx->rx.rec_seq_size,
+		     ctx->control);
+
+	ret = tls_do_decryption(sk, sgin, sgout, iv,
+				rxm->full_len - tls_ctx->rx.overhead_size,
+				skb, sk->sk_allocation);
+
+	if (sgin != &sgin_arr[0])
+		kfree(sgin);
+
+	return ret;
+}
+
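+/* Consume @len bytes of the current record; returns true once the record
+ * is fully consumed and record parsing has been resumed.
+ */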
+static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
+			       unsigned int len)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	struct strp_msg *rxm = strp_msg(skb);
+
+	if (len < rxm->full_len) {
+		rxm->offset += len;
+		rxm->full_len -= len;
+
+		return false;
+	}
+
+	/* Finished with message */
+	ctx->recv_pkt = NULL;
+	kfree_skb(skb);
+	strp_unpause(&ctx->strp);
+
+	return true;
+}
+
+int tls_sw_recvmsg(struct sock *sk,
+		   struct msghdr *msg,
+		   size_t len,
+		   int nonblock,
+		   int flags,
+		   int *addr_len)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	unsigned char control;
+	struct strp_msg *rxm;
+	struct sk_buff *skb;
+	ssize_t copied = 0;
+	bool cmsg = false;
+	int err = 0;
+	long timeo;
+
+	flags |= nonblock;
+
+	if (unlikely(flags & MSG_ERRQUEUE))
+		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
+
+	lock_sock(sk);
+
+	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	do {
+		bool zc = false;
+		int chunk = 0;
+
+		skb = tls_wait_data(sk, flags, timeo, &err);
+		if (!skb)
+			goto recv_end;
+
+		rxm = strp_msg(skb);
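+		/* Report the record type to userspace once per record via a
+		 * TLS_GET_RECORD_TYPE cmsg; if the record type changes, stop
+		 * so that types are never mixed within one read.
+		 */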
+		if (!cmsg) {
+			int cerr;
+
+			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+					sizeof(ctx->control), &ctx->control);
+			cmsg = true;
+			control = ctx->control;
+			if (ctx->control != TLS_RECORD_TYPE_DATA) {
+				if (cerr || msg->msg_flags & MSG_CTRUNC) {
+					err = -EIO;
+					goto recv_end;
+				}
+			}
+		} else if (control != ctx->control) {
+			goto recv_end;
+		}
+
+		if (!ctx->decrypted) {
+			int page_count;
+			int to_copy;
+
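+			/* Zerocopy: decrypt straight into the user buffer
+			 * when the whole record fits and MSG_PEEK is unset;
+			 * otherwise decrypt in place inside the skb.
+			 */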
+			page_count = iov_iter_npages(&msg->msg_iter,
+						     MAX_SKB_FRAGS);
+			to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
+			if (to_copy <= len && page_count < MAX_SKB_FRAGS &&
+			    likely(!(flags & MSG_PEEK)))  {
+				struct scatterlist sgin[MAX_SKB_FRAGS + 1];
+				char unused[21];
+				int pages = 0;
+
+				zc = true;
+				sg_init_table(sgin, MAX_SKB_FRAGS + 1);
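+				/* On decrypt the AEAD copies the 13 AAD
+				 * bytes (TLS_AAD_SPACE_SIZE) from source to
+				 * destination, so reserve scratch space at
+				 * the head of the output list.
+				 */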
+				sg_set_buf(&sgin[0], unused, 13);
+
+				err = zerocopy_from_iter(sk, &msg->msg_iter,
+							 to_copy, &pages,
+							 &chunk, &sgin[1],
+							 MAX_SKB_FRAGS,	false);
+				if (err < 0)
+					goto fallback_to_reg_recv;
+
+				err = decrypt_skb(sk, skb, sgin);
+				for (; pages > 0; pages--)
+					put_page(sg_page(&sgin[pages]));
+				if (err < 0) {
+					tls_err_abort(sk, EBADMSG);
+					goto recv_end;
+				}
+			} else {
+fallback_to_reg_recv:
+				err = decrypt_skb(sk, skb, NULL);
+				if (err < 0) {
+					tls_err_abort(sk, EBADMSG);
+					goto recv_end;
+				}
+			}
+			ctx->decrypted = true;
+		}
+
+		if (!zc) {
+			chunk = min_t(unsigned int, rxm->full_len, len);
+			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
+						    chunk);
+			if (err < 0)
+				goto recv_end;
+		}
+
+		copied += chunk;
+		len -= chunk;
+		if (likely(!(flags & MSG_PEEK))) {
+			u8 control = ctx->control;
+
+			if (tls_sw_advance_skb(sk, skb, chunk)) {
+				/* Return full control message to
+				 * userspace before trying to parse
+				 * another message type
+				 */
+				msg->msg_flags |= MSG_EOR;
+				if (control != TLS_RECORD_TYPE_DATA)
+					goto recv_end;
+			}
+		}
+	} while (len);
+
+recv_end:
+	release_sock(sk);
+	return copied ? : err;
+}
+
+ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
+			   struct pipe_inode_info *pipe,
+			   size_t len, unsigned int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	struct strp_msg *rxm = NULL;
+	struct sock *sk = sock->sk;
+	struct sk_buff *skb;
+	ssize_t copied = 0;
+	int err = 0;
+	long timeo;
+	int chunk;
+
+	lock_sock(sk);
+
+	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+
+	skb = tls_wait_data(sk, flags, timeo, &err);
+	if (!skb)
+		goto splice_read_end;
+
+	/* splice does not support reading control messages */
+	if (ctx->control != TLS_RECORD_TYPE_DATA) {
+		err = -ENOTSUPP;
+		goto splice_read_end;
+	}
+
+	if (!ctx->decrypted) {
+		err = decrypt_skb(sk, skb, NULL);
+
+		if (err < 0) {
+			tls_err_abort(sk, EBADMSG);
+			goto splice_read_end;
+		}
+		ctx->decrypted = true;
+	}
+	rxm = strp_msg(skb);
+
+	chunk = min_t(unsigned int, rxm->full_len, len);
+	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
+	if (copied < 0)
+		goto splice_read_end;
+
+	if (likely(!(flags & MSG_PEEK)))
+		tls_sw_advance_skb(sk, skb, copied);
+
+splice_read_end:
+	release_sock(sk);
+	return copied ? : err;
+}
+
+unsigned int tls_sw_poll(struct file *file, struct socket *sock,
+			 struct poll_table_struct *wait)
+{
+	unsigned int ret;
+	struct sock *sk = sock->sk;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+
+	/* Grab POLLOUT and POLLHUP from the underlying socket */
+	ret = ctx->sk_poll(file, sock, wait);
+
+	/* Clear POLLIN bits, and set based on recv_pkt */
+	ret &= ~(POLLIN | POLLRDNORM);
+	if (ctx->recv_pkt)
+		ret |= POLLIN | POLLRDNORM;
+
+	return ret;
+}
+
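+/* strparser parse_msg callback: return the full record length once the
+ * header has arrived, 0 to wait for more data, or a negative error to
+ * abort parsing.
+ */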
+static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	char header[tls_ctx->rx.prepend_size];
+	struct strp_msg *rxm = strp_msg(skb);
+	size_t cipher_overhead;
+	size_t data_len = 0;
+	int ret;
+
+	/* Verify that we have a full TLS header, or wait for more data */
+	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
+		return 0;
+
+	/* Linearize header to local buffer */
+	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
+
+	if (ret < 0)
+		goto read_failure;
+
+	ctx->control = header[0];
+
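+	/* The record length is a 16-bit big-endian field occupying
+	 * bytes 3 and 4 of the TLS record header.
+	 */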
+	data_len = ((header[4] & 0xFF) | ((header[3] & 0xFF) << 8));
+
+	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
+
+	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
+		ret = -EMSGSIZE;
+		goto read_failure;
+	}
+	if (data_len < cipher_overhead) {
+		ret = -EBADMSG;
+		goto read_failure;
+	}
+
+	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) ||
+	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) {
+		ret = -EINVAL;
+		goto read_failure;
+	}
+
+	return data_len + TLS_HEADER_SIZE;
+
+read_failure:
+	tls_err_abort(strp->sk, ret);
+
+	return ret;
+}
+
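+/* strparser rcv_msg callback: stash the completed record and pause the
+ * parser until recvmsg()/splice_read() has consumed it.
+ */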
+static void tls_queue(struct strparser *strp, struct sk_buff *skb)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+	struct strp_msg *rxm;
+
+	rxm = strp_msg(skb);
+
+	ctx->decrypted = false;
+
+	ctx->recv_pkt = skb;
+	strp_pause(strp);
+
+	strp->sk->sk_state_change(strp->sk);
+}
+
+static void tls_data_ready(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+
+	strp_data_ready(&ctx->strp);
+}
+
+void tls_sw_free_resources(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 
 	if (ctx->aead_send)
 		crypto_free_aead(ctx->aead_send);
+	if (ctx->aead_recv) {
+		if (ctx->recv_pkt) {
+			kfree_skb(ctx->recv_pkt);
+			ctx->recv_pkt = NULL;
+		}
+		crypto_free_aead(ctx->aead_recv);
+		strp_stop(&ctx->strp);
+		write_lock_bh(&sk->sk_callback_lock);
+		sk->sk_data_ready = ctx->saved_data_ready;
+		write_unlock_bh(&sk->sk_callback_lock);
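+		/* strp_done() flushes strparser work that may itself take
+		 * the socket lock, so drop the lock around the call.
+		 */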
+		release_sock(sk);
+		strp_done(&ctx->strp);
+		lock_sock(sk);
+	}
 
 	tls_free_both_sg(sk);
 
@@ -595,12 +1035,15 @@ void tls_sw_free_tx_resources(struct sock *sk)
 	kfree(tls_ctx);
 }
 
-int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
 	char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
 	struct tls_crypto_info *crypto_info;
 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
 	struct tls_sw_context *sw_ctx;
+	struct cipher_context *cctx;
+	struct crypto_aead **aead;
+	struct strp_callbacks cb;
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
 	char *iv, *rec_seq;
 	int rc = 0;
@@ -610,22 +1053,29 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 		goto out;
 	}
 
-	if (ctx->priv_ctx) {
-		rc = -EEXIST;
-		goto out;
-	}
-
-	sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
-	if (!sw_ctx) {
-		rc = -ENOMEM;
-		goto out;
+	if (!ctx->priv_ctx) {
+		sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
+		if (!sw_ctx) {
+			rc = -ENOMEM;
+			goto out;
+		}
+		crypto_init_wait(&sw_ctx->async_wait);
+	} else {
+		sw_ctx = ctx->priv_ctx;
 	}
 
-	crypto_init_wait(&sw_ctx->async_wait);
-
 	ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
 
-	crypto_info = &ctx->crypto_send;
+	if (tx) {
+		crypto_info = &ctx->crypto_send;
+		cctx = &ctx->tx;
+		aead = &sw_ctx->aead_send;
+	} else {
+		crypto_info = &ctx->crypto_recv;
+		cctx = &ctx->rx;
+		aead = &sw_ctx->aead_recv;
+	}
+
 	switch (crypto_info->cipher_type) {
 	case TLS_CIPHER_AES_GCM_128: {
 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
@@ -644,48 +1094,49 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 		goto free_priv;
 	}
 
-	ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
-	ctx->tx.tag_size = tag_size;
-	ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
-	ctx->tx.iv_size = iv_size;
-	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			     GFP_KERNEL);
-	if (!ctx->tx.iv) {
+	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
+	cctx->tag_size = tag_size;
+	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
+	cctx->iv_size = iv_size;
+	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			   GFP_KERNEL);
+	if (!cctx->iv) {
 		rc = -ENOMEM;
 		goto free_priv;
 	}
-	memcpy(ctx->tx.iv, gcm_128_info->salt,
-	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
-	ctx->tx.rec_seq_size = rec_seq_size;
-	ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
-	if (!ctx->tx.rec_seq) {
+	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+	cctx->rec_seq_size = rec_seq_size;
+	cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+	if (!cctx->rec_seq) {
 		rc = -ENOMEM;
 		goto free_iv;
 	}
-	memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size);
-
-	sg_init_table(sw_ctx->sg_encrypted_data,
-		      ARRAY_SIZE(sw_ctx->sg_encrypted_data));
-	sg_init_table(sw_ctx->sg_plaintext_data,
-		      ARRAY_SIZE(sw_ctx->sg_plaintext_data));
-
-	sg_init_table(sw_ctx->sg_aead_in, 2);
-	sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
-		   sizeof(sw_ctx->aad_space));
-	sg_unmark_end(&sw_ctx->sg_aead_in[1]);
-	sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
-	sg_init_table(sw_ctx->sg_aead_out, 2);
-	sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
-		   sizeof(sw_ctx->aad_space));
-	sg_unmark_end(&sw_ctx->sg_aead_out[1]);
-	sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
-
-	if (!sw_ctx->aead_send) {
-		sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0);
-		if (IS_ERR(sw_ctx->aead_send)) {
-			rc = PTR_ERR(sw_ctx->aead_send);
-			sw_ctx->aead_send = NULL;
+	memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
+
+	if (tx) {
+		sg_init_table(sw_ctx->sg_encrypted_data,
+			      ARRAY_SIZE(sw_ctx->sg_encrypted_data));
+		sg_init_table(sw_ctx->sg_plaintext_data,
+			      ARRAY_SIZE(sw_ctx->sg_plaintext_data));
+
+		sg_init_table(sw_ctx->sg_aead_in, 2);
+		sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
+			   sizeof(sw_ctx->aad_space));
+		sg_unmark_end(&sw_ctx->sg_aead_in[1]);
+		sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
+		sg_init_table(sw_ctx->sg_aead_out, 2);
+		sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
+			   sizeof(sw_ctx->aad_space));
+		sg_unmark_end(&sw_ctx->sg_aead_out[1]);
+		sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
+	}
+
+	if (!*aead) {
+		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
+		if (IS_ERR(*aead)) {
+			rc = PTR_ERR(*aead);
+			*aead = NULL;
 			goto free_rec_seq;
 		}
 	}
@@ -694,21 +1145,41 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 
 	memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
 
-	rc = crypto_aead_setkey(sw_ctx->aead_send, keyval,
+	rc = crypto_aead_setkey(*aead, keyval,
 				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
 	if (rc)
 		goto free_aead;
 
-	rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tx.tag_size);
-	if (!rc)
-		return 0;
+	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
+	if (rc)
+		goto free_aead;
+
+	if (!tx) {
+		/* Set up strparser */
+		memset(&cb, 0, sizeof(cb));
+		cb.rcv_msg = tls_queue;
+		cb.parse_msg = tls_read_size;
+
+		strp_init(&sw_ctx->strp, sk, &cb);
+
+		write_lock_bh(&sk->sk_callback_lock);
+		sw_ctx->saved_data_ready = sk->sk_data_ready;
+		sk->sk_data_ready = tls_data_ready;
+		write_unlock_bh(&sk->sk_callback_lock);
+
+		sw_ctx->sk_poll = sk->sk_socket->ops->poll;
+
+		strp_check_rcv(&sw_ctx->strp);
+	}
+
+	goto out;
 
 free_aead:
-	crypto_free_aead(sw_ctx->aead_send);
-	sw_ctx->aead_send = NULL;
+	crypto_free_aead(*aead);
+	*aead = NULL;
 free_rec_seq:
-	kfree(ctx->tx.rec_seq);
-	ctx->tx.rec_seq = NULL;
+	kfree(cctx->rec_seq);
+	cctx->rec_seq = NULL;
 free_iv:
-	kfree(ctx->tx.iv);
-	ctx->tx.iv = NULL;
+	kfree(cctx->iv);
+	cctx->iv = NULL;