summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--include/linux/skmsg.h41
-rw-r--r--net/tls/tls_sw.c439
2 files changed, 414 insertions, 66 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 4e84b3c2eff8..0b919f0bc6d6 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -29,7 +29,11 @@ struct sk_msg_sg {
 	u32				size;
 	u32				copybreak;
 	bool				copy[MAX_MSG_FRAGS];
-	struct scatterlist		data[MAX_MSG_FRAGS];
+	/* The extra element is used for chaining the front and sections when
+	 * the list becomes partitioned (e.g. end < start). The crypto APIs
+	 * require the chaining.
+	 */
+	struct scatterlist		data[MAX_MSG_FRAGS + 1];
 };
 
 struct sk_msg {
@@ -112,6 +116,7 @@ void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
 				  u32 bytes);
 
 void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
+void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
 
 int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
 			      struct sk_msg *msg, u32 bytes);
@@ -161,8 +166,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)
 
 static inline void sk_msg_init(struct sk_msg *msg)
 {
+	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
 	memset(msg, 0, sizeof(*msg));
-	sg_init_marker(msg->sg.data, ARRAY_SIZE(msg->sg.data));
+	sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
 }
 
 static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
@@ -174,6 +180,12 @@ static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
 	src->sg.data[which].offset += size;
 }
 
+static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
+{
+	memcpy(dst, src, sizeof(*src));
+	sk_msg_init(src);
+}
+
 static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
 {
 	return msg->sg.end >= msg->sg.start ?
@@ -229,6 +241,26 @@ static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
 	sk_msg_iter_next(msg, end);
 }
 
+static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
+{
+	do {
+		msg->sg.copy[i] = copy_state;
+		sk_msg_iter_var_next(i);
+		if (i == msg->sg.end)
+			break;
+	} while (1);
+}
+
+static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
+{
+	sk_msg_sg_copy(msg, start, true);
+}
+
+static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
+{
+	sk_msg_sg_copy(msg, start, false);
+}
+
 static inline struct sk_psock *sk_psock(const struct sock *sk)
 {
 	return rcu_dereference_sk_user_data(sk);
@@ -245,6 +277,11 @@ static inline void sk_psock_queue_msg(struct sk_psock *psock,
 	list_add_tail(&msg->list, &psock->ingress_msg);
 }
 
+static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
+{
+	return psock ? list_empty(&psock->ingress_msg) : true;
+}
+
 static inline void sk_psock_report_error(struct sk_psock *psock, int err)
 {
 	struct sock *sk = psock->sk;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 3b75e0dd51a2..a525fc4c2a4b 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -4,6 +4,7 @@
  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
  *
  * This software is available to you under a choice of one of two
  * licenses.  You may choose to be licensed under the terms of the GNU
@@ -258,21 +259,58 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
 	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
 }
 
-static void tls_free_open_rec(struct sock *sk)
+static struct tls_rec *tls_get_rec(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl, *msg_en;
+	struct tls_rec *rec;
+	int mem_size;
 
-	/* Return if there is no open record */
+	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+
+	rec = kzalloc(mem_size, sk->sk_allocation);
 	if (!rec)
-		return;
+		return NULL;
 
+	msg_pl = &rec->msg_plaintext;
+	msg_en = &rec->msg_encrypted;
+
+	sk_msg_init(msg_pl);
+	sk_msg_init(msg_en);
+
+	sg_init_table(rec->sg_aead_in, 2);
+	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
+		   sizeof(rec->aad_space));
+	sg_unmark_end(&rec->sg_aead_in[1]);
+
+	sg_init_table(rec->sg_aead_out, 2);
+	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
+		   sizeof(rec->aad_space));
+	sg_unmark_end(&rec->sg_aead_out[1]);
+
+	return rec;
+}
+
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
+{
 	sk_msg_free(sk, &rec->msg_encrypted);
 	sk_msg_free(sk, &rec->msg_plaintext);
 	kfree(rec);
 }
 
+static void tls_free_open_rec(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+
+	if (rec) {
+		tls_free_rec(sk, rec);
+		ctx->open_rec = NULL;
+	}
+}
+
 int tls_tx_records(struct sock *sk, int flags)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -439,16 +477,135 @@ static int tls_do_encryption(struct sock *sk,
 	return rc;
 }
 
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
+				 struct tls_rec **to, struct sk_msg *msg_opl,
+				 struct sk_msg *msg_oen, u32 split_point,
+				 u32 tx_overhead_size, u32 *orig_end)
+{
+	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
+	struct scatterlist *sge, *osge, *nsge;
+	u32 orig_size = msg_opl->sg.size;
+	struct scatterlist tmp = { };
+	struct sk_msg *msg_npl;
+	struct tls_rec *new;
+	int ret;
+
+	new = tls_get_rec(sk);
+	if (!new)
+		return -ENOMEM;
+	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
+			   tx_overhead_size, 0);
+	if (ret < 0) {
+		tls_free_rec(sk, new);
+		return ret;
+	}
+
+	*orig_end = msg_opl->sg.end;
+	i = msg_opl->sg.start;
+	sge = sk_msg_elem(msg_opl, i);
+	while (apply && sge->length) {
+		if (sge->length > apply) {
+			u32 len = sge->length - apply;
+
+			get_page(sg_page(sge));
+			sg_set_page(&tmp, sg_page(sge), len,
+				    sge->offset + apply);
+			sge->length = apply;
+			bytes += apply;
+			apply = 0;
+		} else {
+			apply -= sge->length;
+			bytes += sge->length;
+		}
+
+		sk_msg_iter_var_next(i);
+		if (i == msg_opl->sg.end)
+			break;
+		sge = sk_msg_elem(msg_opl, i);
+	}
+
+	msg_opl->sg.end = i;
+	msg_opl->sg.curr = i;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = 0;
+	msg_opl->sg.size = bytes;
+
+	msg_npl = &new->msg_plaintext;
+	msg_npl->apply_bytes = apply;
+	msg_npl->sg.size = orig_size - bytes;
+
+	j = msg_npl->sg.start;
+	nsge = sk_msg_elem(msg_npl, j);
+	if (tmp.length) {
+		memcpy(nsge, &tmp, sizeof(*nsge));
+		sk_msg_iter_var_next(j);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	osge = sk_msg_elem(msg_opl, i);
+	while (osge->length) {
+		memcpy(nsge, osge, sizeof(*nsge));
+		sg_unmark_end(nsge);
+		sk_msg_iter_var_next(i);
+		sk_msg_iter_var_next(j);
+		if (i == *orig_end)
+			break;
+		osge = sk_msg_elem(msg_opl, i);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	msg_npl->sg.end = j;
+	msg_npl->sg.curr = j;
+	msg_npl->sg.copybreak = 0;
+
+	*to = new;
+	return 0;
+}
+
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
+				  struct tls_rec *from, u32 orig_end)
+{
+	struct sk_msg *msg_npl = &from->msg_plaintext;
+	struct sk_msg *msg_opl = &to->msg_plaintext;
+	struct scatterlist *osge, *nsge;
+	u32 i, j;
+
+	i = msg_opl->sg.end;
+	sk_msg_iter_var_prev(i);
+	j = msg_npl->sg.start;
+
+	osge = sk_msg_elem(msg_opl, i);
+	nsge = sk_msg_elem(msg_npl, j);
+
+	if (sg_page(osge) == sg_page(nsge) &&
+	    osge->offset + osge->length == nsge->offset) {
+		osge->length += nsge->length;
+		put_page(sg_page(nsge));
+	}
+
+	msg_opl->sg.end = orig_end;
+	msg_opl->sg.curr = orig_end;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
+	msg_opl->sg.size += msg_npl->sg.size;
+
+	sk_msg_free(sk, &to->msg_encrypted);
+	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
+
+	kfree(from);
+}
+
 static int tls_push_record(struct sock *sk, int flags,
 			   unsigned char record_type)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct tls_rec *rec = ctx->open_rec;
+	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
+	u32 i, split_point, uninitialized_var(orig_end);
 	struct sk_msg *msg_pl, *msg_en;
 	struct aead_request *req;
+	bool split;
 	int rc;
-	u32 i;
 
 	if (!rec)
 		return 0;
@@ -456,6 +613,18 @@ static int tls_push_record(struct sock *sk, int flags,
 	msg_pl = &rec->msg_plaintext;
 	msg_en = &rec->msg_encrypted;
 
+	split_point = msg_pl->apply_bytes;
+	split = split_point && split_point < msg_pl->sg.size;
+	if (split) {
+		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
+					   split_point, tls_ctx->tx.overhead_size,
+					   &orig_end);
+		if (rc < 0)
+			return rc;
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+			    tls_ctx->tx.overhead_size);
+	}
+
 	rec->tx_flags = flags;
 	req = &rec->aead_req;
 
@@ -487,57 +656,139 @@ static int tls_push_record(struct sock *sk, int flags,
 
 	rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i);
 	if (rc < 0) {
-		if (rc != -EINPROGRESS)
+		if (rc != -EINPROGRESS) {
 			tls_err_abort(sk, EBADMSG);
+			if (split) {
+				tls_ctx->pending_open_record_frags = true;
+				tls_merge_open_record(sk, rec, tmp, orig_end);
+			}
+		}
 		return rc;
+	} else if (split) {
+		msg_pl = &tmp->msg_plaintext;
+		msg_en = &tmp->msg_encrypted;
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+			    tls_ctx->tx.overhead_size);
+		tls_ctx->pending_open_record_frags = true;
+		ctx->open_rec = tmp;
 	}
 
 	return tls_tx_records(sk, flags);
 }
 
-static int tls_sw_push_pending_record(struct sock *sk, int flags)
-{
-	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
-}
-
-static struct tls_rec *get_rec(struct sock *sk)
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
+			       bool full_record, u8 record_type,
+			       size_t *copied, int flags)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct sk_msg *msg_pl, *msg_en;
+	struct sk_msg msg_redir = { };
+	struct sk_psock *psock;
+	struct sock *sk_redir;
 	struct tls_rec *rec;
-	int mem_size;
+	int err = 0, send;
+	bool enospc;
+
+	psock = sk_psock_get(sk);
+	if (!psock)
+		return tls_push_record(sk, flags, record_type);
+more_data:
+	enospc = sk_msg_full(msg);
+	if (psock->eval == __SK_NONE)
+		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
+	    !enospc && !full_record) {
+		err = -ENOSPC;
+		goto out_err;
+	}
+	msg->cork_bytes = 0;
+	send = msg->sg.size;
+	if (msg->apply_bytes && msg->apply_bytes < send)
+		send = msg->apply_bytes;
+
+	switch (psock->eval) {
+	case __SK_PASS:
+		err = tls_push_record(sk, flags, record_type);
+		if (err < 0) {
+			*copied -= sk_msg_free(sk, msg);
+			tls_free_open_rec(sk);
+			goto out_err;
+		}
+		break;
+	case __SK_REDIRECT:
+		sk_redir = psock->sk_redir;
+		memcpy(&msg_redir, msg, sizeof(*msg));
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		sk_msg_return_zero(sk, msg, send);
+		msg->sg.size -= send;
+		release_sock(sk);
+		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+		lock_sock(sk);
+		if (err < 0) {
+			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
+			msg->sg.size = 0;
+		}
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		break;
+	case __SK_DROP:
+	default:
+		sk_msg_free_partial(sk, msg, send);
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		*copied -= send;
+		err = -EACCES;
+	}
 
-	/* Return if we already have an open record */
-	if (ctx->open_rec)
-		return ctx->open_rec;
+	if (likely(!err)) {
+		bool reset_eval = !ctx->open_rec;
 
-	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+		rec = ctx->open_rec;
+		if (rec) {
+			msg = &rec->msg_plaintext;
+			if (!msg->apply_bytes)
+				reset_eval = true;
+		}
+		if (reset_eval) {
+			psock->eval = __SK_NONE;
+			if (psock->sk_redir) {
+				sock_put(psock->sk_redir);
+				psock->sk_redir = NULL;
+			}
+		}
+		if (rec)
+			goto more_data;
+	}
+ out_err:
+	sk_psock_put(sk, psock);
+	return err;
+}
+
+static int tls_sw_push_pending_record(struct sock *sk, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl;
+	size_t copied;
 
-	rec = kzalloc(mem_size, sk->sk_allocation);
 	if (!rec)
-		return NULL;
+		return 0;
 
 	msg_pl = &rec->msg_plaintext;
-	msg_en = &rec->msg_encrypted;
-
-	sk_msg_init(msg_pl);
-	sk_msg_init(msg_en);
-
-	sg_init_table(rec->sg_aead_in, 2);
-	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
-		   sizeof(rec->aad_space));
-	sg_unmark_end(&rec->sg_aead_in[1]);
-
-	sg_init_table(rec->sg_aead_out, 2);
-	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
-		   sizeof(rec->aad_space));
-	sg_unmark_end(&rec->sg_aead_out[1]);
-
-	ctx->open_rec = rec;
-	rec->inplace_crypto = 1;
+	copied = msg_pl->sg.size;
+	if (!copied)
+		return 0;
 
-	return rec;
+	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
+				   &copied, flags);
 }
 
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
@@ -589,7 +840,10 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 			goto send_end;
 		}
 
-		rec = get_rec(sk);
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
 		if (!rec) {
 			ret = -ENOMEM;
 			goto send_end;
@@ -628,6 +882,8 @@ alloc_encrypted:
 		}
 
 		if (!is_kvec && (full_record || eor) && !async_capable) {
+			u32 first = msg_pl->sg.end;
+
 			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
 							msg_pl, try_to_copy);
 			if (ret)
@@ -637,15 +893,27 @@ alloc_encrypted:
 
 			num_zc++;
 			copied += try_to_copy;
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
+
+			sk_msg_sg_copy_set(msg_pl, first);
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
 			if (ret) {
 				if (ret == -EINPROGRESS)
 					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret == -ENOSPC)
+					goto rollback_iter;
 				else if (ret != -EAGAIN)
 					goto send_end;
 			}
 			continue;
-
+rollback_iter:
+			copied -= try_to_copy;
+			sk_msg_sg_copy_clear(msg_pl, first);
+			iov_iter_revert(&msg->msg_iter,
+					msg_pl->sg.size - orig_size);
 fallback_to_reg_send:
 			sk_msg_trim(sk, msg_pl, orig_size);
 		}
@@ -678,12 +946,19 @@ fallback_to_reg_send:
 		tls_ctx->pending_open_record_frags = true;
 		copied += try_to_copy;
 		if (full_record || eor) {
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
 			if (ret) {
 				if (ret == -EINPROGRESS)
 					num_async++;
-				else if (ret != -EAGAIN)
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
 					goto send_end;
+				}
 			}
 		}
 
@@ -742,10 +1017,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
-	size_t orig_size = size;
 	struct sk_msg *msg_pl;
 	struct tls_rec *rec;
 	int num_async = 0;
+	size_t copied = 0;
 	bool full_record;
 	int record_room;
 	int ret = 0;
@@ -778,7 +1053,10 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
 			goto sendpage_end;
 		}
 
-		rec = get_rec(sk);
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
 		if (!rec) {
 			ret = -ENOMEM;
 			goto sendpage_end;
@@ -788,6 +1066,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
 
 		full_record = false;
 		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+		copied = 0;
 		copy = size;
 		if (copy >= record_room) {
 			copy = record_room;
@@ -818,16 +1097,23 @@ alloc_payload:
 
 		offset += copy;
 		size -= copy;
+		copied += copy;
 
 		tls_ctx->pending_open_record_frags = true;
 		if (full_record || eor || sk_msg_full(msg_pl)) {
 			rec->inplace_crypto = 0;
-			ret = tls_push_record(sk, flags, record_type);
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied, flags);
 			if (ret) {
 				if (ret == -EINPROGRESS)
 					num_async++;
-				else if (ret != -EAGAIN)
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
 					goto sendpage_end;
+				}
 			}
 		}
 		continue;
@@ -851,24 +1137,20 @@ wait_for_memory:
 		}
 	}
 sendpage_end:
-	if (orig_size > size)
-		ret = orig_size - size;
-	else
-		ret = sk_stream_error(sk, flags, ret);
-
+	ret = sk_stream_error(sk, flags, ret);
 	release_sock(sk);
-	return ret;
+	return copied ? copied : ret;
 }
 
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
-				     long timeo, int *err)
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
+				     int flags, long timeo, int *err)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct sk_buff *skb;
 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
-	while (!(skb = ctx->recv_pkt)) {
+	while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
 		if (sk->sk_err) {
 			*err = sock_error(sk);
 			return NULL;
@@ -887,7 +1169,10 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
 
 		add_wait_queue(sk_sleep(sk), &wait);
 		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
+		sk_wait_event(sk, &timeo,
+			      ctx->recv_pkt != skb ||
+			      !sk_psock_queue_empty(psock),
+			      &wait);
 		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 		remove_wait_queue(sk_sleep(sk), &wait);
 
@@ -1164,6 +1449,7 @@ int tls_sw_recvmsg(struct sock *sk,
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_psock *psock;
 	unsigned char control;
 	struct strp_msg *rxm;
 	struct sk_buff *skb;
@@ -1179,6 +1465,7 @@ int tls_sw_recvmsg(struct sock *sk,
 	if (unlikely(flags & MSG_ERRQUEUE))
 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
 
+	psock = sk_psock_get(sk);
 	lock_sock(sk);
 
 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
@@ -1188,9 +1475,19 @@ int tls_sw_recvmsg(struct sock *sk,
 		bool async = false;
 		int chunk = 0;
 
-		skb = tls_wait_data(sk, flags, timeo, &err);
-		if (!skb)
+		skb = tls_wait_data(sk, psock, flags, timeo, &err);
+		if (!skb) {
+			if (psock) {
+				int ret = __tcp_bpf_recvmsg(sk, psock, msg, len);
+
+				if (ret > 0) {
+					copied += ret;
+					len -= ret;
+					continue;
+				}
+			}
 			goto recv_end;
+		}
 
 		rxm = strp_msg(skb);
 
@@ -1296,6 +1593,8 @@ recv_end:
 	}
 
 	release_sock(sk);
+	if (psock)
+		sk_psock_put(sk, psock);
 	return copied ? : err;
 }
 
@@ -1318,7 +1617,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 
 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
-	skb = tls_wait_data(sk, flags, timeo, &err);
+	skb = tls_wait_data(sk, NULL, flags, timeo, &err);
 	if (!skb)
 		goto splice_read_end;
 
@@ -1356,11 +1655,16 @@ bool tls_sw_stream_read(const struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	bool ingress_empty = true;
+	struct sk_psock *psock;
 
-	if (ctx->recv_pkt)
-		return true;
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (psock)
+		ingress_empty = list_empty(&psock->ingress_msg);
+	rcu_read_unlock();
 
-	return false;
+	return !ingress_empty || ctx->recv_pkt;
 }
 
 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
@@ -1439,8 +1743,15 @@ static void tls_data_ready(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_psock *psock;
 
 	strp_data_ready(&ctx->strp);
+
+	psock = sk_psock_get(sk);
+	if (psock && !list_empty(&psock->ingress_msg)) {
+		ctx->saved_data_ready(sk);
+		sk_psock_put(sk, psock);
+	}
 }
 
 void tls_sw_free_resources_tx(struct sock *sk)