summary refs log tree commit diff
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/tls/Kconfig10
-rw-r--r--net/tls/Makefile2
-rw-r--r--net/tls/tls_device.c764
-rw-r--r--net/tls/tls_device_fallback.c450
-rw-r--r--net/tls/tls_main.c42
5 files changed, 1265 insertions, 3 deletions
diff --git a/net/tls/Kconfig b/net/tls/Kconfig
index 89b8745a986f..73f05ece53d0 100644
--- a/net/tls/Kconfig
+++ b/net/tls/Kconfig
@@ -14,3 +14,13 @@ config TLS
 	encryption handling of the TLS protocol to be done in-kernel.
 
 	If unsure, say N.
+
+config TLS_DEVICE
+	bool "Transport Layer Security HW offload"
+	depends on TLS
+	select SOCK_VALIDATE_XMIT
+	default n
+	help
+	Enable kernel support for HW offload of the TLS protocol.
+
+	If unsure, say N.
diff --git a/net/tls/Makefile b/net/tls/Makefile
index a930fd1c4f7b..4d6b728a67d0 100644
--- a/net/tls/Makefile
+++ b/net/tls/Makefile
@@ -5,3 +5,5 @@
 obj-$(CONFIG_TLS) += tls.o
 
 tls-y := tls_main.o tls_sw.o
+
+tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
new file mode 100644
index 000000000000..ac45d6224e41
--- /dev/null
+++ b/net/tls/tls_device.c
@@ -0,0 +1,764 @@
+/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses.  You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the
+ * OpenIB.org BSD license below:
+ *
+ *     Redistribution and use in source and binary forms, with or
+ *     without modification, are permitted provided that the following
+ *     conditions are met:
+ *
+ *      - Redistributions of source code must retain the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer.
+ *
+ *      - Redistributions in binary form must reproduce the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer in the documentation and/or other materials
+ *        provided with the distribution.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <crypto/aead.h>
+#include <linux/highmem.h>
+#include <linux/module.h>
+#include <linux/netdevice.h>
+#include <net/dst.h>
+#include <net/inet_connection_sock.h>
+#include <net/tcp.h>
+#include <net/tls.h>
+
+/* device_offload_lock is used to synchronize tls_dev_add
+ * against NETDEV_DOWN notifications.
+ */
+static DECLARE_RWSEM(device_offload_lock);
+
+static void tls_device_gc_task(struct work_struct *work);
+
+static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
+static LIST_HEAD(tls_device_gc_list);
+static LIST_HEAD(tls_device_list);
+static DEFINE_SPINLOCK(tls_device_lock);
+
+static void tls_device_free_ctx(struct tls_context *ctx)
+{
+	struct tls_offload_context *offload_ctx = tls_offload_ctx(ctx);
+
+	kfree(offload_ctx);
+	kfree(ctx);
+}
+
+static void tls_device_gc_task(struct work_struct *work)
+{
+	struct tls_context *ctx, *tmp;
+	unsigned long flags;
+	LIST_HEAD(gc_list);
+
+	spin_lock_irqsave(&tls_device_lock, flags);
+	list_splice_init(&tls_device_gc_list, &gc_list);
+	spin_unlock_irqrestore(&tls_device_lock, flags);
+
+	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
+		struct net_device *netdev = ctx->netdev;
+
+		if (netdev) {
+			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+							TLS_OFFLOAD_CTX_DIR_TX);
+			dev_put(netdev);
+		}
+
+		list_del(&ctx->list);
+		tls_device_free_ctx(ctx);
+	}
+}
+
+static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
+{
+	unsigned long flags;
+
+	spin_lock_irqsave(&tls_device_lock, flags);
+	list_move_tail(&ctx->list, &tls_device_gc_list);
+
+	/* schedule_work inside the spinlock
+	 * to make sure tls_device_down waits for that work.
+	 */
+	schedule_work(&tls_device_gc_work);
+
+	spin_unlock_irqrestore(&tls_device_lock, flags);
+}
+
+/* We assume that the socket is already connected */
+static struct net_device *get_netdev_for_sock(struct sock *sk)
+{
+	struct dst_entry *dst = sk_dst_get(sk);
+	struct net_device *netdev = NULL;
+
+	if (likely(dst)) {
+		netdev = dst->dev;
+		dev_hold(netdev);
+	}
+
+	dst_release(dst);
+
+	return netdev;
+}
+
+static void destroy_record(struct tls_record_info *record)
+{
+	int nr_frags = record->num_frags;
+	skb_frag_t *frag;
+
+	while (nr_frags-- > 0) {
+		frag = &record->frags[nr_frags];
+		__skb_frag_unref(frag);
+	}
+	kfree(record);
+}
+
+static void delete_all_records(struct tls_offload_context *offload_ctx)
+{
+	struct tls_record_info *info, *temp;
+
+	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
+		list_del(&info->list);
+		destroy_record(info);
+	}
+
+	offload_ctx->retransmit_hint = NULL;
+}
+
+static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_record_info *info, *temp;
+	struct tls_offload_context *ctx;
+	u64 deleted_records = 0;
+	unsigned long flags;
+
+	if (!tls_ctx)
+		return;
+
+	ctx = tls_offload_ctx(tls_ctx);
+
+	spin_lock_irqsave(&ctx->lock, flags);
+	info = ctx->retransmit_hint;
+	if (info && !before(acked_seq, info->end_seq)) {
+		ctx->retransmit_hint = NULL;
+		list_del(&info->list);
+		destroy_record(info);
+		deleted_records++;
+	}
+
+	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
+		if (before(acked_seq, info->end_seq))
+			break;
+		list_del(&info->list);
+
+		destroy_record(info);
+		deleted_records++;
+	}
+
+	ctx->unacked_record_sn += deleted_records;
+	spin_unlock_irqrestore(&ctx->lock, flags);
+}
+
+/* At this point, there should be no references on this
+ * socket and no in-flight SKBs associated with this
+ * socket, so it is safe to free all the resources.
+ */
+void tls_device_sk_destruct(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+
+	if (ctx->open_record)
+		destroy_record(ctx->open_record);
+
+	delete_all_records(ctx);
+	crypto_free_aead(ctx->aead_send);
+	ctx->sk_destruct(sk);
+	clean_acked_data_disable(inet_csk(sk));
+
+	if (refcount_dec_and_test(&tls_ctx->refcount))
+		tls_device_queue_ctx_destruction(tls_ctx);
+}
+EXPORT_SYMBOL(tls_device_sk_destruct);
+
+static void tls_append_frag(struct tls_record_info *record,
+			    struct page_frag *pfrag,
+			    int size)
+{
+	skb_frag_t *frag;
+
+	frag = &record->frags[record->num_frags - 1];
+	if (frag->page.p == pfrag->page &&
+	    frag->page_offset + frag->size == pfrag->offset) {
+		frag->size += size;
+	} else {
+		++frag;
+		frag->page.p = pfrag->page;
+		frag->page_offset = pfrag->offset;
+		frag->size = size;
+		++record->num_frags;
+		get_page(pfrag->page);
+	}
+
+	pfrag->offset += size;
+	record->len += size;
+}
+
+static int tls_push_record(struct sock *sk,
+			   struct tls_context *ctx,
+			   struct tls_offload_context *offload_ctx,
+			   struct tls_record_info *record,
+			   struct page_frag *pfrag,
+			   int flags,
+			   unsigned char record_type)
+{
+	struct tcp_sock *tp = tcp_sk(sk);
+	struct page_frag dummy_tag_frag;
+	skb_frag_t *frag;
+	int i;
+
+	/* fill prepend */
+	frag = &record->frags[0];
+	tls_fill_prepend(ctx,
+			 skb_frag_address(frag),
+			 record->len - ctx->tx.prepend_size,
+			 record_type);
+
+	/* HW doesn't care about the data in the tag, because it fills it. */
+	dummy_tag_frag.page = skb_frag_page(frag);
+	dummy_tag_frag.offset = 0;
+
+	tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size);
+	record->end_seq = tp->write_seq + record->len;
+	spin_lock_irq(&offload_ctx->lock);
+	list_add_tail(&record->list, &offload_ctx->records_list);
+	spin_unlock_irq(&offload_ctx->lock);
+	offload_ctx->open_record = NULL;
+	set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
+	tls_advance_record_sn(sk, &ctx->tx);
+
+	for (i = 0; i < record->num_frags; i++) {
+		frag = &record->frags[i];
+		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
+		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
+			    frag->size, frag->page_offset);
+		sk_mem_charge(sk, frag->size);
+		get_page(skb_frag_page(frag));
+	}
+	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
+
+	/* all ready, send */
+	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
+}
+
+static int tls_create_new_record(struct tls_offload_context *offload_ctx,
+				 struct page_frag *pfrag,
+				 size_t prepend_size)
+{
+	struct tls_record_info *record;
+	skb_frag_t *frag;
+
+	record = kmalloc(sizeof(*record), GFP_KERNEL);
+	if (!record)
+		return -ENOMEM;
+
+	frag = &record->frags[0];
+	__skb_frag_set_page(frag, pfrag->page);
+	frag->page_offset = pfrag->offset;
+	skb_frag_size_set(frag, prepend_size);
+
+	get_page(pfrag->page);
+	pfrag->offset += prepend_size;
+
+	record->num_frags = 1;
+	record->len = prepend_size;
+	offload_ctx->open_record = record;
+	return 0;
+}
+
+static int tls_do_allocation(struct sock *sk,
+			     struct tls_offload_context *offload_ctx,
+			     struct page_frag *pfrag,
+			     size_t prepend_size)
+{
+	int ret;
+
+	if (!offload_ctx->open_record) {
+		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
+						   sk->sk_allocation))) {
+			sk->sk_prot->enter_memory_pressure(sk);
+			sk_stream_moderate_sndbuf(sk);
+			return -ENOMEM;
+		}
+
+		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
+		if (ret)
+			return ret;
+
+		if (pfrag->size > pfrag->offset)
+			return 0;
+	}
+
+	if (!sk_page_frag_refill(sk, pfrag))
+		return -ENOMEM;
+
+	return 0;
+}
+
+static int tls_push_data(struct sock *sk,
+			 struct iov_iter *msg_iter,
+			 size_t size, int flags,
+			 unsigned char record_type)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
+	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
+	struct tls_record_info *record = ctx->open_record;
+	struct page_frag *pfrag;
+	size_t orig_size = size;
+	u32 max_open_record_len;
+	int copy, rc = 0;
+	bool done = false;
+	long timeo;
+
+	if (flags &
+	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
+		return -ENOTSUPP;
+
+	if (sk->sk_err)
+		return -sk->sk_err;
+
+	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+	rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
+	if (rc < 0)
+		return rc;
+
+	pfrag = sk_page_frag(sk);
+
+	/* TLS_HEADER_SIZE is not counted as part of the TLS record, and
+	 * we need to leave room for an authentication tag.
+	 */
+	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
+			      tls_ctx->tx.prepend_size;
+	do {
+		rc = tls_do_allocation(sk, ctx, pfrag,
+				       tls_ctx->tx.prepend_size);
+		if (rc) {
+			rc = sk_stream_wait_memory(sk, &timeo);
+			if (!rc)
+				continue;
+
+			record = ctx->open_record;
+			if (!record)
+				break;
+handle_error:
+			if (record_type != TLS_RECORD_TYPE_DATA) {
+				/* avoid sending partial
+				 * record with type !=
+				 * application_data
+				 */
+				size = orig_size;
+				destroy_record(record);
+				ctx->open_record = NULL;
+			} else if (record->len > tls_ctx->tx.prepend_size) {
+				goto last_record;
+			}
+
+			break;
+		}
+
+		record = ctx->open_record;
+		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
+		copy = min_t(size_t, copy, (max_open_record_len - record->len));
+
+		if (copy_from_iter_nocache(page_address(pfrag->page) +
+					       pfrag->offset,
+					   copy, msg_iter) != copy) {
+			rc = -EFAULT;
+			goto handle_error;
+		}
+		tls_append_frag(record, pfrag, copy);
+
+		size -= copy;
+		if (!size) {
+last_record:
+			tls_push_record_flags = flags;
+			if (more) {
+				tls_ctx->pending_open_record_frags =
+						record->num_frags;
+				break;
+			}
+
+			done = true;
+		}
+
+		if (done || record->len >= max_open_record_len ||
+		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
+			rc = tls_push_record(sk,
+					     tls_ctx,
+					     ctx,
+					     record,
+					     pfrag,
+					     tls_push_record_flags,
+					     record_type);
+			if (rc < 0)
+				break;
+		}
+	} while (!done);
+
+	if (orig_size - size > 0)
+		rc = orig_size - size;
+
+	return rc;
+}
+
+int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	int rc;
+
+	lock_sock(sk);
+
+	if (unlikely(msg->msg_controllen)) {
+		rc = tls_proccess_cmsg(sk, msg, &record_type);
+		if (rc)
+			goto out;
+	}
+
+	rc = tls_push_data(sk, &msg->msg_iter, size,
+			   msg->msg_flags, record_type);
+
+out:
+	release_sock(sk);
+	return rc;
+}
+
+int tls_device_sendpage(struct sock *sk, struct page *page,
+			int offset, size_t size, int flags)
+{
+	struct iov_iter	msg_iter;
+	char *kaddr = kmap(page);
+	struct kvec iov;
+	int rc;
+
+	if (flags & MSG_SENDPAGE_NOTLAST)
+		flags |= MSG_MORE;
+
+	lock_sock(sk);
+
+	if (flags & MSG_OOB) {
+		rc = -ENOTSUPP;
+		goto out;
+	}
+
+	iov.iov_base = kaddr + offset;
+	iov.iov_len = size;
+	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
+	rc = tls_push_data(sk, &msg_iter, size,
+			   flags, TLS_RECORD_TYPE_DATA);
+	kunmap(page);
+
+out:
+	release_sock(sk);
+	return rc;
+}
+
+struct tls_record_info *tls_get_record(struct tls_offload_context *context,
+				       u32 seq, u64 *p_record_sn)
+{
+	u64 record_sn = context->hint_record_sn;
+	struct tls_record_info *info;
+
+	info = context->retransmit_hint;
+	if (!info ||
+	    before(seq, info->end_seq - info->len)) {
+		/* if retransmit_hint is irrelevant start
+		 * from the beggining of the list
+		 */
+		info = list_first_entry(&context->records_list,
+					struct tls_record_info, list);
+		record_sn = context->unacked_record_sn;
+	}
+
+	list_for_each_entry_from(info, &context->records_list, list) {
+		if (before(seq, info->end_seq)) {
+			if (!context->retransmit_hint ||
+			    after(info->end_seq,
+				  context->retransmit_hint->end_seq)) {
+				context->hint_record_sn = record_sn;
+				context->retransmit_hint = info;
+			}
+			*p_record_sn = record_sn;
+			return info;
+		}
+		record_sn++;
+	}
+
+	return NULL;
+}
+EXPORT_SYMBOL(tls_get_record);
+
+static int tls_device_push_pending_record(struct sock *sk, int flags)
+{
+	struct iov_iter	msg_iter;
+
+	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
+	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
+}
+
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+{
+	u16 nonce_size, tag_size, iv_size, rec_seq_size;
+	struct tls_record_info *start_marker_record;
+	struct tls_offload_context *offload_ctx;
+	struct tls_crypto_info *crypto_info;
+	struct net_device *netdev;
+	char *iv, *rec_seq;
+	struct sk_buff *skb;
+	int rc = -EINVAL;
+	__be64 rcd_sn;
+
+	if (!ctx)
+		goto out;
+
+	if (ctx->priv_ctx_tx) {
+		rc = -EEXIST;
+		goto out;
+	}
+
+	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
+	if (!start_marker_record) {
+		rc = -ENOMEM;
+		goto out;
+	}
+
+	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
+	if (!offload_ctx) {
+		rc = -ENOMEM;
+		goto free_marker_record;
+	}
+
+	crypto_info = &ctx->crypto_send;
+	switch (crypto_info->cipher_type) {
+	case TLS_CIPHER_AES_GCM_128:
+		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
+		rec_seq =
+		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
+		break;
+	default:
+		rc = -EINVAL;
+		goto free_offload_ctx;
+	}
+
+	ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
+	ctx->tx.tag_size = tag_size;
+	ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
+	ctx->tx.iv_size = iv_size;
+	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			     GFP_KERNEL);
+	if (!ctx->tx.iv) {
+		rc = -ENOMEM;
+		goto free_offload_ctx;
+	}
+
+	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+
+	ctx->tx.rec_seq_size = rec_seq_size;
+	ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+	if (!ctx->tx.rec_seq) {
+		rc = -ENOMEM;
+		goto free_iv;
+	}
+	memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size);
+
+	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
+	if (rc)
+		goto free_rec_seq;
+
+	/* start at rec_seq - 1 to account for the start marker record */
+	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
+	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
+
+	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
+	start_marker_record->len = 0;
+	start_marker_record->num_frags = 0;
+
+	INIT_LIST_HEAD(&offload_ctx->records_list);
+	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
+	spin_lock_init(&offload_ctx->lock);
+
+	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
+	ctx->push_pending_record = tls_device_push_pending_record;
+	offload_ctx->sk_destruct = sk->sk_destruct;
+
+	/* TLS offload is greatly simplified if we don't send
+	 * SKBs where only part of the payload needs to be encrypted.
+	 * So mark the last skb in the write queue as end of record.
+	 */
+	skb = tcp_write_queue_tail(sk);
+	if (skb)
+		TCP_SKB_CB(skb)->eor = 1;
+
+	refcount_set(&ctx->refcount, 1);
+
+	/* We support starting offload on multiple sockets
+	 * concurrently, so we only need a read lock here.
+	 * This lock must precede get_netdev_for_sock to prevent races between
+	 * NETDEV_DOWN and setsockopt.
+	 */
+	down_read(&device_offload_lock);
+	netdev = get_netdev_for_sock(sk);
+	if (!netdev) {
+		pr_err_ratelimited("%s: netdev not found\n", __func__);
+		rc = -EINVAL;
+		goto release_lock;
+	}
+
+	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
+		rc = -ENOTSUPP;
+		goto release_netdev;
+	}
+
+	/* Avoid offloading if the device is down
+	 * We don't want to offload new flows after
+	 * the NETDEV_DOWN event
+	 */
+	if (!(netdev->flags & IFF_UP)) {
+		rc = -EINVAL;
+		goto release_netdev;
+	}
+
+	ctx->priv_ctx_tx = offload_ctx;
+	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
+					     &ctx->crypto_send,
+					     tcp_sk(sk)->write_seq);
+	if (rc)
+		goto release_netdev;
+
+	ctx->netdev = netdev;
+
+	spin_lock_irq(&tls_device_lock);
+	list_add_tail(&ctx->list, &tls_device_list);
+	spin_unlock_irq(&tls_device_lock);
+
+	sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
+	/* following this assignment tls_is_sk_tx_device_offloaded
+	 * will return true and the context might be accessed
+	 * by the netdev's xmit function.
+	 */
+	smp_store_release(&sk->sk_destruct,
+			  &tls_device_sk_destruct);
+	up_read(&device_offload_lock);
+	goto out;
+
+release_netdev:
+	dev_put(netdev);
+release_lock:
+	up_read(&device_offload_lock);
+	clean_acked_data_disable(inet_csk(sk));
+	crypto_free_aead(offload_ctx->aead_send);
+free_rec_seq:
+	kfree(ctx->tx.rec_seq);
+free_iv:
+	kfree(ctx->tx.iv);
+free_offload_ctx:
+	kfree(offload_ctx);
+	ctx->priv_ctx_tx = NULL;
+free_marker_record:
+	kfree(start_marker_record);
+out:
+	return rc;
+}
+
+static int tls_device_down(struct net_device *netdev)
+{
+	struct tls_context *ctx, *tmp;
+	unsigned long flags;
+	LIST_HEAD(list);
+
+	/* Request a write lock to block new offload attempts */
+	down_write(&device_offload_lock);
+
+	spin_lock_irqsave(&tls_device_lock, flags);
+	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
+		if (ctx->netdev != netdev ||
+		    !refcount_inc_not_zero(&ctx->refcount))
+			continue;
+
+		list_move(&ctx->list, &list);
+	}
+	spin_unlock_irqrestore(&tls_device_lock, flags);
+
+	list_for_each_entry_safe(ctx, tmp, &list, list)	{
+		netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
+						TLS_OFFLOAD_CTX_DIR_TX);
+		ctx->netdev = NULL;
+		dev_put(netdev);
+		list_del_init(&ctx->list);
+
+		if (refcount_dec_and_test(&ctx->refcount))
+			tls_device_free_ctx(ctx);
+	}
+
+	up_write(&device_offload_lock);
+
+	flush_work(&tls_device_gc_work);
+
+	return NOTIFY_DONE;
+}
+
+static int tls_dev_event(struct notifier_block *this, unsigned long event,
+			 void *ptr)
+{
+	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
+
+	if (!(dev->features & NETIF_F_HW_TLS_TX))
+		return NOTIFY_DONE;
+
+	switch (event) {
+	case NETDEV_REGISTER:
+	case NETDEV_FEAT_CHANGE:
+		if  (dev->tlsdev_ops &&
+		     dev->tlsdev_ops->tls_dev_add &&
+		     dev->tlsdev_ops->tls_dev_del)
+			return NOTIFY_DONE;
+		else
+			return NOTIFY_BAD;
+	case NETDEV_DOWN:
+		return tls_device_down(dev);
+	}
+	return NOTIFY_DONE;
+}
+
+static struct notifier_block tls_dev_notifier = {
+	.notifier_call	= tls_dev_event,
+};
+
+void __init tls_device_init(void)
+{
+	register_netdevice_notifier(&tls_dev_notifier);
+}
+
+void __exit tls_device_cleanup(void)
+{
+	unregister_netdevice_notifier(&tls_dev_notifier);
+	flush_work(&tls_device_gc_work);
+}
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
new file mode 100644
index 000000000000..748914abdb60
--- /dev/null
+++ b/net/tls/tls_device_fallback.c
@@ -0,0 +1,450 @@
+/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses.  You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the
+ * OpenIB.org BSD license below:
+ *
+ *     Redistribution and use in source and binary forms, with or
+ *     without modification, are permitted provided that the following
+ *     conditions are met:
+ *
+ *      - Redistributions of source code must retain the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer.
+ *
+ *      - Redistributions in binary form must reproduce the above
+ *        copyright notice, this list of conditions and the following
+ *        disclaimer in the documentation and/or other materials
+ *        provided with the distribution.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <net/tls.h>
+#include <crypto/aead.h>
+#include <crypto/scatterwalk.h>
+#include <net/ip6_checksum.h>
+
+static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
+{
+	struct scatterlist *src = walk->sg;
+	int diff = walk->offset - src->offset;
+
+	sg_set_page(sg, sg_page(src),
+		    src->length - diff, walk->offset);
+
+	scatterwalk_crypto_chain(sg, sg_next(src), 0, 2);
+}
+
+static int tls_enc_record(struct aead_request *aead_req,
+			  struct crypto_aead *aead, char *aad,
+			  char *iv, __be64 rcd_sn,
+			  struct scatter_walk *in,
+			  struct scatter_walk *out, int *in_len)
+{
+	unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
+	struct scatterlist sg_in[3];
+	struct scatterlist sg_out[3];
+	u16 len;
+	int rc;
+
+	len = min_t(int, *in_len, ARRAY_SIZE(buf));
+
+	scatterwalk_copychunks(buf, in, len, 0);
+	scatterwalk_copychunks(buf, out, len, 1);
+
+	*in_len -= len;
+	if (!*in_len)
+		return 0;
+
+	scatterwalk_pagedone(in, 0, 1);
+	scatterwalk_pagedone(out, 1, 1);
+
+	len = buf[4] | (buf[3] << 8);
+	len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
+
+	tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
+		     (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
+
+	memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
+	       TLS_CIPHER_AES_GCM_128_IV_SIZE);
+
+	sg_init_table(sg_in, ARRAY_SIZE(sg_in));
+	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
+	sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
+	sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
+	chain_to_walk(sg_in + 1, in);
+	chain_to_walk(sg_out + 1, out);
+
+	*in_len -= len;
+	if (*in_len < 0) {
+		*in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+		/* the input buffer doesn't contain the entire record.
+		 * trim len accordingly. The resulting authentication tag
+		 * will contain garbage, but we don't care, so we won't
+		 * include any of it in the output skb
+		 * Note that we assume the output buffer length
+		 * is larger then input buffer length + tag size
+		 */
+		if (*in_len < 0)
+			len += *in_len;
+
+		*in_len = 0;
+	}
+
+	if (*in_len) {
+		scatterwalk_copychunks(NULL, in, len, 2);
+		scatterwalk_pagedone(in, 0, 1);
+		scatterwalk_copychunks(NULL, out, len, 2);
+		scatterwalk_pagedone(out, 1, 1);
+	}
+
+	len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+	aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
+
+	rc = crypto_aead_encrypt(aead_req);
+
+	return rc;
+}
+
+static void tls_init_aead_request(struct aead_request *aead_req,
+				  struct crypto_aead *aead)
+{
+	aead_request_set_tfm(aead_req, aead);
+	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+}
+
+static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
+						   gfp_t flags)
+{
+	unsigned int req_size = sizeof(struct aead_request) +
+		crypto_aead_reqsize(aead);
+	struct aead_request *aead_req;
+
+	aead_req = kzalloc(req_size, flags);
+	if (aead_req)
+		tls_init_aead_request(aead_req, aead);
+	return aead_req;
+}
+
+static int tls_enc_records(struct aead_request *aead_req,
+			   struct crypto_aead *aead, struct scatterlist *sg_in,
+			   struct scatterlist *sg_out, char *aad, char *iv,
+			   u64 rcd_sn, int len)
+{
+	struct scatter_walk out, in;
+	int rc;
+
+	scatterwalk_start(&in, sg_in);
+	scatterwalk_start(&out, sg_out);
+
+	do {
+		rc = tls_enc_record(aead_req, aead, aad, iv,
+				    cpu_to_be64(rcd_sn), &in, &out, &len);
+		rcd_sn++;
+
+	} while (rc == 0 && len);
+
+	scatterwalk_done(&in, 0, 0);
+	scatterwalk_done(&out, 1, 0);
+
+	return rc;
+}
+
+/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses
+ * might have been changed by NAT.
+ */
+static void update_chksum(struct sk_buff *skb, int headln)
+{
+	struct tcphdr *th = tcp_hdr(skb);
+	int datalen = skb->len - headln;
+	const struct ipv6hdr *ipv6h;
+	const struct iphdr *iph;
+
+	/* We only changed the payload so if we are using partial we don't
+	 * need to update anything.
+	 */
+	if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
+		return;
+
+	skb->ip_summed = CHECKSUM_PARTIAL;
+	skb->csum_start = skb_transport_header(skb) - skb->head;
+	skb->csum_offset = offsetof(struct tcphdr, check);
+
+	if (skb->sk->sk_family == AF_INET6) {
+		ipv6h = ipv6_hdr(skb);
+		th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
+					     datalen, IPPROTO_TCP, 0);
+	} else {
+		iph = ip_hdr(skb);
+		th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
+					       IPPROTO_TCP, 0);
+	}
+}
+
+static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
+{
+	skb_copy_header(nskb, skb);
+
+	skb_put(nskb, skb->len);
+	memcpy(nskb->data, skb->data, headln);
+	update_chksum(nskb, headln);
+
+	nskb->destructor = skb->destructor;
+	nskb->sk = skb->sk;
+	skb->destructor = NULL;
+	skb->sk = NULL;
+	refcount_add(nskb->truesize - skb->truesize,
+		     &nskb->sk->sk_wmem_alloc);
+}
+
+/* This function may be called after the user socket is already
+ * closed so make sure we don't use anything freed during
+ * tls_sk_proto_close here
+ */
+
+static int fill_sg_in(struct scatterlist *sg_in,
+		      struct sk_buff *skb,
+		      struct tls_offload_context *ctx,
+		      u64 *rcd_sn,
+		      s32 *sync_size,
+		      int *resync_sgs)
+{
+	int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
+	int payload_len = skb->len - tcp_payload_offset;
+	u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
+	struct tls_record_info *record;
+	unsigned long flags;
+	int remaining;
+	int i;
+
+	spin_lock_irqsave(&ctx->lock, flags);
+	record = tls_get_record(ctx, tcp_seq, rcd_sn);
+	if (!record) {
+		spin_unlock_irqrestore(&ctx->lock, flags);
+		WARN(1, "Record not found for seq %u\n", tcp_seq);
+		return -EINVAL;
+	}
+
+	*sync_size = tcp_seq - tls_record_start_seq(record);
+	if (*sync_size < 0) {
+		int is_start_marker = tls_record_is_start_marker(record);
+
+		spin_unlock_irqrestore(&ctx->lock, flags);
+		/* This should only occur if the relevant record was
+		 * already acked. In that case it should be ok
+		 * to drop the packet and avoid retransmission.
+		 *
+		 * There is a corner case where the packet contains
+		 * both an acked and a non-acked record.
+		 * We currently don't handle that case and rely
+		 * on TCP to retranmit a packet that doesn't contain
+		 * already acked payload.
+		 */
+		if (!is_start_marker)
+			*sync_size = 0;
+		return -EINVAL;
+	}
+
+	remaining = *sync_size;
+	for (i = 0; remaining > 0; i++) {
+		skb_frag_t *frag = &record->frags[i];
+
+		__skb_frag_ref(frag);
+		sg_set_page(sg_in + i, skb_frag_page(frag),
+			    skb_frag_size(frag), frag->page_offset);
+
+		remaining -= skb_frag_size(frag);
+
+		if (remaining < 0)
+			sg_in[i].length += remaining;
+	}
+	*resync_sgs = i;
+
+	spin_unlock_irqrestore(&ctx->lock, flags);
+	if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0)
+		return -EINVAL;
+
+	return 0;
+}
+
+static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
+			struct tls_context *tls_ctx,
+			struct sk_buff *nskb,
+			int tcp_payload_offset,
+			int payload_len,
+			int sync_size,
+			void *dummy_buf)
+{
+	sg_set_buf(&sg_out[0], dummy_buf, sync_size);
+	sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);
+	/* Add room for authentication tag produced by crypto */
+	dummy_buf += sync_size;
+	sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+}
+
+static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
+				   struct scatterlist sg_out[3],
+				   struct scatterlist *sg_in,
+				   struct sk_buff *skb,
+				   s32 sync_size, u64 rcd_sn)
+{
+	int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
+	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+	int payload_len = skb->len - tcp_payload_offset;
+	void *buf, *iv, *aad, *dummy_buf;
+	struct aead_request *aead_req;
+	struct sk_buff *nskb = NULL;
+	int buf_len;
+
+	aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
+	if (!aead_req)
+		return NULL;
+
+	buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
+		  TLS_CIPHER_AES_GCM_128_IV_SIZE +
+		  TLS_AAD_SPACE_SIZE +
+		  sync_size +
+		  TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+	buf = kmalloc(buf_len, GFP_ATOMIC);
+	if (!buf)
+		goto free_req;
+
+	iv = buf;
+	memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt,
+	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
+	      TLS_CIPHER_AES_GCM_128_IV_SIZE;
+	dummy_buf = aad + TLS_AAD_SPACE_SIZE;
+
+	nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
+	if (!nskb)
+		goto free_buf;
+
+	skb_reserve(nskb, skb_headroom(skb));
+
+	fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset,
+		    payload_len, sync_size, dummy_buf);
+
+	if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
+			    rcd_sn, sync_size + payload_len) < 0)
+		goto free_nskb;
+
+	complete_skb(nskb, skb, tcp_payload_offset);
+
+	/* validate_xmit_skb_list assumes that if the skb wasn't segmented
+	 * nskb->prev will point to the skb itself
+	 */
+	nskb->prev = nskb;
+
+free_buf:
+	kfree(buf);
+free_req:
+	kfree(aead_req);
+	return nskb;
+free_nskb:
+	kfree_skb(nskb);
+	nskb = NULL;
+	goto free_buf;
+}
+
+static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
+{
+	int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
+	int payload_len = skb->len - tcp_payload_offset;
+	struct scatterlist *sg_in, sg_out[3];
+	struct sk_buff *nskb = NULL;
+	int sg_in_max_elements;
+	int resync_sgs = 0;
+	s32 sync_size = 0;
+	u64 rcd_sn;
+
+	/* worst case is:
+	 * MAX_SKB_FRAGS in tls_record_info
+	 * MAX_SKB_FRAGS + 1 in SKB head and frags.
+	 */
+	sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;
+
+	if (!payload_len)
+		return skb;
+
+	sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
+	if (!sg_in)
+		goto free_orig;
+
+	sg_init_table(sg_in, sg_in_max_elements);
+	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
+
+	if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) {
+		/* bypass packets before kernel TLS socket option was set */
+		if (sync_size < 0 && payload_len <= -sync_size)
+			nskb = skb_get(skb);
+		goto put_sg;
+	}
+
+	nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn);
+
+put_sg:
+	while (resync_sgs)
+		put_page(sg_page(&sg_in[--resync_sgs]));
+	kfree(sg_in);
+free_orig:
+	kfree_skb(skb);
+	return nskb;
+}
+
+struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
+				      struct net_device *dev,
+				      struct sk_buff *skb)
+{
+	if (dev == tls_get_ctx(sk)->netdev)
+		return skb;
+
+	return tls_sw_fallback(sk, skb);
+}
+
+int tls_sw_fallback_init(struct sock *sk,
+			 struct tls_offload_context *offload_ctx,
+			 struct tls_crypto_info *crypto_info)
+{
+	const u8 *key;
+	int rc;
+
+	offload_ctx->aead_send =
+	    crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
+	if (IS_ERR(offload_ctx->aead_send)) {
+		rc = PTR_ERR(offload_ctx->aead_send);
+		pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
+		offload_ctx->aead_send = NULL;
+		goto err_out;
+	}
+
+	key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+
+	rc = crypto_aead_setkey(offload_ctx->aead_send, key,
+				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+	if (rc)
+		goto free_aead;
+
+	rc = crypto_aead_setauthsize(offload_ctx->aead_send,
+				     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+	if (rc)
+		goto free_aead;
+
+	return 0;
+free_aead:
+	crypto_free_aead(offload_ctx->aead_send);
+err_out:
+	return rc;
+}
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 545bf34ed599..3aafb871a0a8 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -54,6 +54,9 @@ enum {
 enum {
 	TLS_BASE,
 	TLS_SW,
+#ifdef CONFIG_TLS_DEVICE
+	TLS_HW,
+#endif
 	TLS_HW_RECORD,
 	TLS_NUM_CONFIG,
 };
@@ -280,6 +283,15 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 		tls_sw_free_resources_rx(sk);
 	}
 
+#ifdef CONFIG_TLS_DEVICE
+	if (ctx->tx_conf != TLS_HW) {
+#else
+	{
+#endif
+		kfree(ctx);
+		ctx = NULL;
+	}
+
 skip_tx_cleanup:
 	release_sock(sk);
 	sk_proto_close(sk, timeout);
@@ -442,8 +454,16 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 	}
 
 	if (tx) {
-		rc = tls_set_sw_offload(sk, ctx, 1);
-		conf = TLS_SW;
+#ifdef CONFIG_TLS_DEVICE
+		rc = tls_set_device_offload(sk, ctx);
+		conf = TLS_HW;
+		if (rc) {
+#else
+		{
+#endif
+			rc = tls_set_sw_offload(sk, ctx, 1);
+			conf = TLS_SW;
+		}
 	} else {
 		rc = tls_set_sw_offload(sk, ctx, 0);
 		conf = TLS_SW;
@@ -596,6 +616,16 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 	prot[TLS_SW][TLS_SW].recvmsg	= tls_sw_recvmsg;
 	prot[TLS_SW][TLS_SW].close	= tls_sk_proto_close;
 
+#ifdef CONFIG_TLS_DEVICE
+	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
+	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;
+
+	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
+	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;
+#endif
+
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_hw_hash;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_hw_unhash;
@@ -630,7 +660,7 @@ static int tls_init(struct sock *sk)
 	ctx->getsockopt = sk->sk_prot->getsockopt;
 	ctx->sk_proto_close = sk->sk_prot->close;
 
-	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
+	/* Build IPv6 TLS whenever the address of tcpv6	_prot changes */
 	if (ip_ver == TLSV6 &&
 	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
 		mutex_lock(&tcpv6_prot_mutex);
@@ -680,6 +710,9 @@ static int __init tls_register(void)
 	tls_sw_proto_ops.poll = tls_sw_poll;
 	tls_sw_proto_ops.splice_read = tls_sw_splice_read;
 
+#ifdef CONFIG_TLS_DEVICE
+	tls_device_init();
+#endif
 	tcp_register_ulp(&tcp_tls_ulp_ops);
 
 	return 0;
@@ -688,6 +721,9 @@ static int __init tls_register(void)
 static void __exit tls_unregister(void)
 {
 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
+#ifdef CONFIG_TLS_DEVICE
+	tls_device_cleanup();
+#endif
 }
 
 module_init(tls_register);