summary refs log tree commit diff
path: root/arch
diff options
context:
space:
mode:
authorRussell King <rmk+kernel@armlinux.org.uk>2018-01-13 22:38:18 +0000
committerRussell King <rmk+kernel@armlinux.org.uk>2018-01-17 19:38:07 +0000
commit02088d9b392f605c892894b46aa8c83e3abd0115 (patch)
tree7dd2b1eb8b47f46e6d621a81a745d4158a002182 /arch
parent0005e55a79cfda88199e41a406a829c88d708c67 (diff)
downloadlinux-02088d9b392f605c892894b46aa8c83e3abd0115.tar.gz
ARM: net: bpf: fix register saving
When an eBPF program tail-calls another eBPF program, it enters it after
the prologue to avoid having complex stack manipulations.  This can lead
to kernel oopses, and similar.

Resolve this by always using a fixed stack layout, a CPU register frame
pointer, and using this when reloading registers before returning.

Fixes: 39c13c204bb1 ("arm: eBPF JIT compiler")
Signed-off-by: Russell King <rmk+kernel@armlinux.org.uk>
Diffstat (limited to 'arch')
-rw-r--r--arch/arm/net/bpf_jit_32.c80
1 files changed, 22 insertions, 58 deletions
diff --git a/arch/arm/net/bpf_jit_32.c b/arch/arm/net/bpf_jit_32.c
index dcb3181e85f3..95bb3f896c8f 100644
--- a/arch/arm/net/bpf_jit_32.c
+++ b/arch/arm/net/bpf_jit_32.c
@@ -61,20 +61,24 @@ int bpf_jit_enable __read_mostly;
  *
  *                                high
  * original ARM_SP =>     +------------------+
- *                        |        lr        | (optional)
- *                        |     r4-r8,r10    | callee saved registers
- *                        +------------------+
+ *                        | r4-r8,r10,fp,lr  | callee saved registers
+ * current ARM_FP =>      +------------------+
  *                                low
+ *
+ * When popping registers off the stack at the end of a BPF function, we
+ * reference them via the current ARM_FP register.
  */
+#define CALLEE_MASK	(1 << ARM_R4 | 1 << ARM_R5 | 1 << ARM_R6 | \
+			 1 << ARM_R7 | 1 << ARM_R8 | 1 << ARM_R10 | \
+			 1 << ARM_FP)
+#define CALLEE_PUSH_MASK (CALLEE_MASK | 1 << ARM_LR)
+#define CALLEE_POP_MASK  (CALLEE_MASK | 1 << ARM_PC)
 
 #define STACK_OFFSET(k)	(k)
 #define TMP_REG_1	(MAX_BPF_JIT_REG + 0)	/* TEMP Register 1 */
 #define TMP_REG_2	(MAX_BPF_JIT_REG + 1)	/* TEMP Register 2 */
 #define TCALL_CNT	(MAX_BPF_JIT_REG + 2)	/* Tail Call Count */
 
-/* Flags used for JIT optimization */
-#define SEEN_CALL	(1 << 0)
-
 #define FLAG_IMM_OVERFLOW	(1 << 0)
 
 /*
@@ -135,7 +139,6 @@ static const u8 bpf2a32[][2] = {
  * idx			:	index of current last JITed instruction.
  * prologue_bytes	:	bytes used in prologue.
  * epilogue_offset	:	offset of epilogue starting.
- * seen			:	bit mask used for JIT optimization.
  * offsets		:	array of eBPF instruction offsets in
  *				JITed code.
  * target		:	final JITed code.
@@ -150,7 +153,6 @@ struct jit_ctx {
 	unsigned int idx;
 	unsigned int prologue_bytes;
 	unsigned int epilogue_offset;
-	u32 seen;
 	u32 flags;
 	u32 *offsets;
 	u32 *target;
@@ -340,7 +342,6 @@ static void emit_bx_r(u8 tgt_reg, struct jit_ctx *ctx)
 
 static inline void emit_blx_r(u8 tgt_reg, struct jit_ctx *ctx)
 {
-	ctx->seen |= SEEN_CALL;
 #if __LINUX_ARM_ARCH__ < 5
 	emit(ARM_MOV_R(ARM_LR, ARM_PC), ctx);
 	emit_bx_r(tgt_reg, ctx);
@@ -403,7 +404,6 @@ static inline void emit_udivmod(u8 rd, u8 rm, u8 rn, struct jit_ctx *ctx, u8 op)
 	}
 
 	/* Call appropriate function */
-	ctx->seen |= SEEN_CALL;
 	emit_mov_i(ARM_IP, op == BPF_DIV ?
 		   (u32)jit_udiv32 : (u32)jit_mod32, ctx);
 	emit_blx_r(ARM_IP, ctx);
@@ -669,8 +669,6 @@ static inline void emit_a32_lsh_r64(const u8 dst[], const u8 src[], bool dstk,
 	/* Do LSH operation */
 	emit(ARM_SUB_I(ARM_IP, rt, 32), ctx);
 	emit(ARM_RSB_I(tmp2[0], rt, 32), ctx);
-	/* As we are using ARM_LR */
-	ctx->seen |= SEEN_CALL;
 	emit(ARM_MOV_SR(ARM_LR, rm, SRTYPE_ASL, rt), ctx);
 	emit(ARM_ORR_SR(ARM_LR, ARM_LR, rd, SRTYPE_ASL, ARM_IP), ctx);
 	emit(ARM_ORR_SR(ARM_IP, ARM_LR, rd, SRTYPE_LSR, tmp2[0]), ctx);
@@ -705,8 +703,6 @@ static inline void emit_a32_arsh_r64(const u8 dst[], const u8 src[], bool dstk,
 	/* Do the ARSH operation */
 	emit(ARM_RSB_I(ARM_IP, rt, 32), ctx);
 	emit(ARM_SUBS_I(tmp2[0], rt, 32), ctx);
-	/* As we are using ARM_LR */
-	ctx->seen |= SEEN_CALL;
 	emit(ARM_MOV_SR(ARM_LR, rd, SRTYPE_LSR, rt), ctx);
 	emit(ARM_ORR_SR(ARM_LR, ARM_LR, rm, SRTYPE_ASL, ARM_IP), ctx);
 	_emit(ARM_COND_MI, ARM_B(0), ctx);
@@ -741,8 +737,6 @@ static inline void emit_a32_lsr_r64(const u8 dst[], const u8 src[], bool dstk,
 	/* Do LSH operation */
 	emit(ARM_RSB_I(ARM_IP, rt, 32), ctx);
 	emit(ARM_SUBS_I(tmp2[0], rt, 32), ctx);
-	/* As we are using ARM_LR */
-	ctx->seen |= SEEN_CALL;
 	emit(ARM_MOV_SR(ARM_LR, rd, SRTYPE_LSR, rt), ctx);
 	emit(ARM_ORR_SR(ARM_LR, ARM_LR, rm, SRTYPE_ASL, ARM_IP), ctx);
 	emit(ARM_ORR_SR(ARM_LR, ARM_LR, rm, SRTYPE_LSR, tmp2[0]), ctx);
@@ -877,8 +871,6 @@ static inline void emit_a32_mul_r64(const u8 dst[], const u8 src[], bool dstk,
 	/* Do Multiplication */
 	emit(ARM_MUL(ARM_IP, rd, rn), ctx);
 	emit(ARM_MUL(ARM_LR, rm, rt), ctx);
-	/* As we are using ARM_LR */
-	ctx->seen |= SEEN_CALL;
 	emit(ARM_ADD_R(ARM_LR, ARM_IP, ARM_LR), ctx);
 
 	emit(ARM_UMULL(ARM_IP, rm, rd, rt), ctx);
@@ -955,7 +947,6 @@ static inline void emit_ar_r(const u8 rd, const u8 rt, const u8 rm,
 			     const u8 rn, struct jit_ctx *ctx, u8 op) {
 	switch (op) {
 	case BPF_JSET:
-		ctx->seen |= SEEN_CALL;
 		emit(ARM_AND_R(ARM_IP, rt, rn), ctx);
 		emit(ARM_AND_R(ARM_LR, rd, rm), ctx);
 		emit(ARM_ORRS_R(ARM_IP, ARM_LR, ARM_IP), ctx);
@@ -1119,33 +1110,22 @@ static void build_prologue(struct jit_ctx *ctx)
 	const u8 r2 = bpf2a32[BPF_REG_1][1];
 	const u8 r3 = bpf2a32[BPF_REG_1][0];
 	const u8 r4 = bpf2a32[BPF_REG_6][1];
-	const u8 r5 = bpf2a32[BPF_REG_6][0];
-	const u8 r6 = bpf2a32[TMP_REG_1][1];
-	const u8 r7 = bpf2a32[TMP_REG_1][0];
-	const u8 r8 = bpf2a32[TMP_REG_2][1];
-	const u8 r10 = bpf2a32[TMP_REG_2][0];
 	const u8 fplo = bpf2a32[BPF_REG_FP][1];
 	const u8 fphi = bpf2a32[BPF_REG_FP][0];
-	const u8 sp = ARM_SP;
 	const u8 *tcc = bpf2a32[TCALL_CNT];
 
-	u16 reg_set = 0;
-
 	/* Save callee saved registers. */
-	reg_set |= (1<<r4) | (1<<r5) | (1<<r6) | (1<<r7) | (1<<r8) | (1<<r10);
 #ifdef CONFIG_FRAME_POINTER
-	reg_set |= (1<<ARM_FP) | (1<<ARM_IP) | (1<<ARM_LR) | (1<<ARM_PC);
-	emit(ARM_MOV_R(ARM_IP, sp), ctx);
+	u16 reg_set = CALLEE_PUSH_MASK | 1 << ARM_IP | 1 << ARM_PC;
+	emit(ARM_MOV_R(ARM_IP, ARM_SP), ctx);
 	emit(ARM_PUSH(reg_set), ctx);
 	emit(ARM_SUB_I(ARM_FP, ARM_IP, 4), ctx);
 #else
-	/* Check if call instruction exists in BPF body */
-	if (ctx->seen & SEEN_CALL)
-		reg_set |= (1<<ARM_LR);
-	emit(ARM_PUSH(reg_set), ctx);
+	emit(ARM_PUSH(CALLEE_PUSH_MASK), ctx);
+	emit(ARM_MOV_R(ARM_FP, ARM_SP), ctx);
 #endif
 	/* Save frame pointer for later */
-	emit(ARM_SUB_I(ARM_IP, sp, SCRATCH_SIZE), ctx);
+	emit(ARM_SUB_I(ARM_IP, ARM_SP, SCRATCH_SIZE), ctx);
 
 	ctx->stack_size = imm8m(STACK_SIZE);
 
@@ -1168,33 +1148,19 @@ static void build_prologue(struct jit_ctx *ctx)
 	/* end of prologue */
 }
 
+/* restore callee saved registers. */
 static void build_epilogue(struct jit_ctx *ctx)
 {
-	const u8 r4 = bpf2a32[BPF_REG_6][1];
-	const u8 r5 = bpf2a32[BPF_REG_6][0];
-	const u8 r6 = bpf2a32[TMP_REG_1][1];
-	const u8 r7 = bpf2a32[TMP_REG_1][0];
-	const u8 r8 = bpf2a32[TMP_REG_2][1];
-	const u8 r10 = bpf2a32[TMP_REG_2][0];
-	u16 reg_set = 0;
-
-	/* unwind function call stack */
-	emit(ARM_ADD_I(ARM_SP, ARM_SP, ctx->stack_size), ctx);
-
-	/* restore callee saved registers. */
-	reg_set |= (1<<r4) | (1<<r5) | (1<<r6) | (1<<r7) | (1<<r8) | (1<<r10);
 #ifdef CONFIG_FRAME_POINTER
-	/* the first instruction of the prologue was: mov ip, sp */
-	reg_set |= (1<<ARM_FP) | (1<<ARM_SP) | (1<<ARM_PC);
+	/* When using frame pointers, some additional registers need to
+	 * be loaded. */
+	u16 reg_set = CALLEE_POP_MASK | 1 << ARM_SP;
+	emit(ARM_SUB_I(ARM_SP, ARM_FP, hweight16(reg_set) * 4), ctx);
 	emit(ARM_LDM(ARM_SP, reg_set), ctx);
 #else
-	if (ctx->seen & SEEN_CALL)
-		reg_set |= (1<<ARM_PC);
 	/* Restore callee saved registers. */
-	emit(ARM_POP(reg_set), ctx);
-	/* Return back to the callee function */
-	if (!(ctx->seen & SEEN_CALL))
-		emit_bx_r(ARM_LR, ctx);
+	emit(ARM_MOV_R(ARM_SP, ARM_FP), ctx);
+	emit(ARM_POP(CALLEE_POP_MASK), ctx);
 #endif
 }
 
@@ -1422,8 +1388,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
 			emit_rev32(rt, rt, ctx);
 			goto emit_bswap_uxt;
 		case 64:
-			/* Because of the usage of ARM_LR */
-			ctx->seen |= SEEN_CALL;
 			emit_rev32(ARM_LR, rt, ctx);
 			emit_rev32(rt, rd, ctx);
 			emit(ARM_MOV_R(rd, ARM_LR), ctx);