summary refs log tree commit diff
path: root/arch/arm64/crypto/aes-glue.c
diff options
context:
space:
mode:
Diffstat (limited to 'arch/arm64/crypto/aes-glue.c')
-rw-r--r--arch/arm64/crypto/aes-glue.c217
1 files changed, 193 insertions, 24 deletions
diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index adcb83eb683c..1e676625ef33 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -15,6 +15,7 @@
 #include <crypto/internal/hash.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <linux/module.h>
 #include <linux/cpufeature.h>
 #include <crypto/xts.h>
@@ -31,6 +32,8 @@
 #define aes_ecb_decrypt		ce_aes_ecb_decrypt
 #define aes_cbc_encrypt		ce_aes_cbc_encrypt
 #define aes_cbc_decrypt		ce_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		ce_aes_ctr_encrypt
 #define aes_xts_encrypt		ce_aes_xts_encrypt
 #define aes_xts_decrypt		ce_aes_xts_decrypt
@@ -45,6 +48,8 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #define aes_ecb_decrypt		neon_aes_ecb_decrypt
 #define aes_cbc_encrypt		neon_aes_cbc_encrypt
 #define aes_cbc_decrypt		neon_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		neon_aes_ctr_encrypt
 #define aes_xts_encrypt		neon_aes_xts_encrypt
 #define aes_xts_decrypt		neon_aes_xts_decrypt
@@ -63,30 +68,41 @@ MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
 MODULE_LICENSE("GPL v2");
 
 /* defined in aes-modes.S */
-asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks);
-asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks);
 
-asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks, u8 iv[]);
-asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks, u8 iv[]);
 
-asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
+				int rounds, int bytes, u8 const iv[]);
+asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
+				int rounds, int bytes, u8 const iv[]);
+
+asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks, u8 ctr[]);
 
-asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
-				int rounds, int blocks, u8 const rk2[], u8 iv[],
+asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
+				int rounds, int blocks, u32 const rk2[], u8 iv[],
 				int first);
-asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
-				int rounds, int blocks, u8 const rk2[], u8 iv[],
+asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
+				int rounds, int blocks, u32 const rk2[], u8 iv[],
 				int first);
 
 asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
 			       int blocks, u8 dg[], int enc_before,
 			       int enc_after);
 
+struct cts_cbc_req_ctx {
+	struct scatterlist sg_src[2];
+	struct scatterlist sg_dst[2];
+	struct skcipher_request subreq;
+};
+
 struct crypto_aes_xts_ctx {
 	struct crypto_aes_ctx key1;
 	struct crypto_aes_ctx __aligned(8) key2;
@@ -142,7 +158,7 @@ static int ecb_encrypt(struct skcipher_request *req)
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 		kernel_neon_begin();
 		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				(u8 *)ctx->key_enc, rounds, blocks);
+				ctx->key_enc, rounds, blocks);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
@@ -162,7 +178,7 @@ static int ecb_decrypt(struct skcipher_request *req)
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 		kernel_neon_begin();
 		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				(u8 *)ctx->key_dec, rounds, blocks);
+				ctx->key_dec, rounds, blocks);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
@@ -182,7 +198,7 @@ static int cbc_encrypt(struct skcipher_request *req)
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 		kernel_neon_begin();
 		aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				(u8 *)ctx->key_enc, rounds, blocks, walk.iv);
+				ctx->key_enc, rounds, blocks, walk.iv);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
@@ -202,13 +218,149 @@ static int cbc_decrypt(struct skcipher_request *req)
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 		kernel_neon_begin();
 		aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				(u8 *)ctx->key_dec, rounds, blocks, walk.iv);
+				ctx->key_dec, rounds, blocks, walk.iv);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
 	return err;
 }
 
+static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
+{
+	crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
+	return 0;
+}
+
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+
+	skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+	if (req->cryptlen <= AES_BLOCK_SIZE) {
+		if (req->cryptlen < AES_BLOCK_SIZE)
+			return -EINVAL;
+		cbc_blocks = 1;
+	}
+
+	if (cbc_blocks > 0) {
+		unsigned int blocks;
+
+		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					ctx->key_enc, rounds, blocks, walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+					     rctx->subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+					       rctx->subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&rctx->subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+
+	skcipher_request_set_tfm(&rctx->subreq, tfm);
+
+	if (req->cryptlen <= AES_BLOCK_SIZE) {
+		if (req->cryptlen < AES_BLOCK_SIZE)
+			return -EINVAL;
+		cbc_blocks = 1;
+	}
+
+	if (cbc_blocks > 0) {
+		unsigned int blocks;
+
+		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					ctx->key_dec, rounds, blocks, walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
+					     rctx->subreq.cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
+					       rctx->subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&rctx->subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
 static int ctr_encrypt(struct skcipher_request *req)
 {
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -222,7 +374,7 @@ static int ctr_encrypt(struct skcipher_request *req)
 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 		kernel_neon_begin();
 		aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				(u8 *)ctx->key_enc, rounds, blocks, walk.iv);
+				ctx->key_enc, rounds, blocks, walk.iv);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
@@ -238,7 +390,7 @@ static int ctr_encrypt(struct skcipher_request *req)
 		blocks = -1;
 
 		kernel_neon_begin();
-		aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc, rounds,
+		aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
 				blocks, walk.iv);
 		kernel_neon_end();
 		crypto_xor_cpy(tdst, tsrc, tail, nbytes);
@@ -272,8 +424,8 @@ static int xts_encrypt(struct skcipher_request *req)
 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
 		kernel_neon_begin();
 		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				(u8 *)ctx->key1.key_enc, rounds, blocks,
-				(u8 *)ctx->key2.key_enc, walk.iv, first);
+				ctx->key1.key_enc, rounds, blocks,
+				ctx->key2.key_enc, walk.iv, first);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
@@ -294,8 +446,8 @@ static int xts_decrypt(struct skcipher_request *req)
 	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
 		kernel_neon_begin();
 		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				(u8 *)ctx->key1.key_dec, rounds, blocks,
-				(u8 *)ctx->key2.key_enc, walk.iv, first);
+				ctx->key1.key_dec, rounds, blocks,
+				ctx->key2.key_enc, walk.iv, first);
 		kernel_neon_end();
 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 	}
@@ -336,6 +488,24 @@ static struct skcipher_alg aes_algs[] = { {
 	.decrypt	= cbc_decrypt,
 }, {
 	.base = {
+		.cra_name		= "__cts(cbc(aes))",
+		.cra_driver_name	= "__cts-cbc-aes-" MODE,
+		.cra_priority		= PRIO,
+		.cra_flags		= CRYPTO_ALG_INTERNAL,
+		.cra_blocksize		= AES_BLOCK_SIZE,
+		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
+		.cra_module		= THIS_MODULE,
+	},
+	.min_keysize	= AES_MIN_KEY_SIZE,
+	.max_keysize	= AES_MAX_KEY_SIZE,
+	.ivsize		= AES_BLOCK_SIZE,
+	.walksize	= 2 * AES_BLOCK_SIZE,
+	.setkey		= skcipher_aes_setkey,
+	.encrypt	= cts_cbc_encrypt,
+	.decrypt	= cts_cbc_decrypt,
+	.init		= cts_cbc_init_tfm,
+}, {
+	.base = {
 		.cra_name		= "__ctr(aes)",
 		.cra_driver_name	= "__ctr-aes-" MODE,
 		.cra_priority		= PRIO,
@@ -412,7 +582,6 @@ static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 {
 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 	be128 *consts = (be128 *)ctx->consts;
-	u8 *rk = (u8 *)ctx->key.key_enc;
 	int rounds = 6 + key_len / 4;
 	int err;
 
@@ -422,7 +591,8 @@ static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 
 	/* encrypt the zero vector */
 	kernel_neon_begin();
-	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, rk, rounds, 1);
+	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
+			rounds, 1);
 	kernel_neon_end();
 
 	cmac_gf128_mul_by_x(consts, consts);
@@ -441,7 +611,6 @@ static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
 	};
 
 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
-	u8 *rk = (u8 *)ctx->key.key_enc;
 	int rounds = 6 + key_len / 4;
 	u8 key[AES_BLOCK_SIZE];
 	int err;
@@ -451,8 +620,8 @@ static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
 		return err;
 
 	kernel_neon_begin();
-	aes_ecb_encrypt(key, ks[0], rk, rounds, 1);
-	aes_ecb_encrypt(ctx->consts, ks[1], rk, rounds, 2);
+	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
+	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
 	kernel_neon_end();
 
 	return cbcmac_setkey(tfm, key, sizeof(key));