summary refs log tree commit diff
path: root/kernel/bpf
diff options
context:
space:
mode:
authorDavid S. Miller <davem@davemloft.net>2019-06-20 00:06:27 -0400
committerDavid S. Miller <davem@davemloft.net>2019-06-20 00:06:27 -0400
commitdca73a65a68329ee386d3ff473152bac66eaab39 (patch)
tree97c41afb932bdd6cbe67e7ffc38bfe5952c97798 /kernel/bpf
parent497ad9f5b2dc86b733761b9afa44ecfa2f17be65 (diff)
parent94079b64255fe40b9b53fd2e4081f68b9b14f54a (diff)
downloadlinux-dca73a65a68329ee386d3ff473152bac66eaab39.tar.gz
Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
Alexei Starovoitov says:

====================
pull-request: bpf-next 2019-06-19

The following pull-request contains BPF updates for your *net-next* tree.

The main changes are:

1) new SO_REUSEPORT_DETACH_BPF setsocktopt, from Martin.

2) BTF based map definition, from Andrii.

3) support bpf_map_lookup_elem for xskmap, from Jonathan.

4) bounded loops and scalar precision logic in the verifier, from Alexei.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
Diffstat (limited to 'kernel/bpf')
-rw-r--r--kernel/bpf/Makefile1
-rw-r--r--kernel/bpf/devmap.c2
-rw-r--r--kernel/bpf/verifier.c793
-rw-r--r--kernel/bpf/xskmap.c9
4 files changed, 734 insertions, 71 deletions
diff --git a/kernel/bpf/Makefile b/kernel/bpf/Makefile
index 4c2fa3ac56f6..29d781061cd5 100644
--- a/kernel/bpf/Makefile
+++ b/kernel/bpf/Makefile
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: GPL-2.0
 obj-y := core.o
+CFLAGS_core.o += $(call cc-disable-warning, override-init)
 
 obj-$(CONFIG_BPF_SYSCALL) += syscall.o verifier.o inode.o helpers.o tnum.o
 obj-$(CONFIG_BPF_SYSCALL) += hashtab.o arraymap.o percpu_freelist.o bpf_lru_list.o lpm_trie.o map_in_map.o
diff --git a/kernel/bpf/devmap.c b/kernel/bpf/devmap.c
index b84c44505e06..40e86a7e0ef0 100644
--- a/kernel/bpf/devmap.c
+++ b/kernel/bpf/devmap.c
@@ -80,8 +80,8 @@ static u64 dev_map_bitmap_size(const union bpf_attr *attr)
 static struct bpf_map *dev_map_alloc(union bpf_attr *attr)
 {
 	struct bpf_dtab *dtab;
-	int err = -EINVAL;
 	u64 cost;
+	int err;
 
 	if (!capable(CAP_NET_ADMIN))
 		return ERR_PTR(-EPERM);
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 1e9d10b32984..0e079b2298f8 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -326,7 +326,8 @@ static bool type_is_sk_pointer(enum bpf_reg_type type)
 {
 	return type == PTR_TO_SOCKET ||
 		type == PTR_TO_SOCK_COMMON ||
-		type == PTR_TO_TCP_SOCK;
+		type == PTR_TO_TCP_SOCK ||
+		type == PTR_TO_XDP_SOCK;
 }
 
 static bool reg_type_may_be_null(enum bpf_reg_type type)
@@ -398,6 +399,7 @@ static const char * const reg_type_str[] = {
 	[PTR_TO_TCP_SOCK]	= "tcp_sock",
 	[PTR_TO_TCP_SOCK_OR_NULL] = "tcp_sock_or_null",
 	[PTR_TO_TP_BUFFER]	= "tp_buffer",
+	[PTR_TO_XDP_SOCK]	= "xdp_sock",
 };
 
 static char slot_type_char[] = {
@@ -445,12 +447,12 @@ static void print_verifier_state(struct bpf_verifier_env *env,
 		verbose(env, " R%d", i);
 		print_liveness(env, reg->live);
 		verbose(env, "=%s", reg_type_str[t]);
+		if (t == SCALAR_VALUE && reg->precise)
+			verbose(env, "P");
 		if ((t == SCALAR_VALUE || t == PTR_TO_STACK) &&
 		    tnum_is_const(reg->var_off)) {
 			/* reg->off should be 0 for SCALAR_VALUE */
 			verbose(env, "%lld", reg->var_off.value + reg->off);
-			if (t == PTR_TO_STACK)
-				verbose(env, ",call_%d", func(env, reg)->callsite);
 		} else {
 			verbose(env, "(id=%d", reg->id);
 			if (reg_type_may_be_refcounted_or_null(t))
@@ -512,11 +514,17 @@ static void print_verifier_state(struct bpf_verifier_env *env,
 			continue;
 		verbose(env, " fp%d", (-i - 1) * BPF_REG_SIZE);
 		print_liveness(env, state->stack[i].spilled_ptr.live);
-		if (state->stack[i].slot_type[0] == STACK_SPILL)
-			verbose(env, "=%s",
-				reg_type_str[state->stack[i].spilled_ptr.type]);
-		else
+		if (state->stack[i].slot_type[0] == STACK_SPILL) {
+			reg = &state->stack[i].spilled_ptr;
+			t = reg->type;
+			verbose(env, "=%s", reg_type_str[t]);
+			if (t == SCALAR_VALUE && reg->precise)
+				verbose(env, "P");
+			if (t == SCALAR_VALUE && tnum_is_const(reg->var_off))
+				verbose(env, "%lld", reg->var_off.value + reg->off);
+		} else {
 			verbose(env, "=%s", types_buf);
+		}
 	}
 	if (state->acquired_refs && state->refs[0].id) {
 		verbose(env, " refs=%d", state->refs[0].id);
@@ -665,6 +673,13 @@ static void free_func_state(struct bpf_func_state *state)
 	kfree(state);
 }
 
+static void clear_jmp_history(struct bpf_verifier_state *state)
+{
+	kfree(state->jmp_history);
+	state->jmp_history = NULL;
+	state->jmp_history_cnt = 0;
+}
+
 static void free_verifier_state(struct bpf_verifier_state *state,
 				bool free_self)
 {
@@ -674,6 +689,7 @@ static void free_verifier_state(struct bpf_verifier_state *state,
 		free_func_state(state->frame[i]);
 		state->frame[i] = NULL;
 	}
+	clear_jmp_history(state);
 	if (free_self)
 		kfree(state);
 }
@@ -701,8 +717,18 @@ static int copy_verifier_state(struct bpf_verifier_state *dst_state,
 			       const struct bpf_verifier_state *src)
 {
 	struct bpf_func_state *dst;
+	u32 jmp_sz = sizeof(struct bpf_idx_pair) * src->jmp_history_cnt;
 	int i, err;
 
+	if (dst_state->jmp_history_cnt < src->jmp_history_cnt) {
+		kfree(dst_state->jmp_history);
+		dst_state->jmp_history = kmalloc(jmp_sz, GFP_USER);
+		if (!dst_state->jmp_history)
+			return -ENOMEM;
+	}
+	memcpy(dst_state->jmp_history, src->jmp_history, jmp_sz);
+	dst_state->jmp_history_cnt = src->jmp_history_cnt;
+
 	/* if dst has more stack frames then src frame, free them */
 	for (i = src->curframe + 1; i <= dst_state->curframe; i++) {
 		free_func_state(dst_state->frame[i]);
@@ -711,6 +737,10 @@ static int copy_verifier_state(struct bpf_verifier_state *dst_state,
 	dst_state->speculative = src->speculative;
 	dst_state->curframe = src->curframe;
 	dst_state->active_spin_lock = src->active_spin_lock;
+	dst_state->branches = src->branches;
+	dst_state->parent = src->parent;
+	dst_state->first_insn_idx = src->first_insn_idx;
+	dst_state->last_insn_idx = src->last_insn_idx;
 	for (i = 0; i <= src->curframe; i++) {
 		dst = dst_state->frame[i];
 		if (!dst) {
@@ -726,6 +756,23 @@ static int copy_verifier_state(struct bpf_verifier_state *dst_state,
 	return 0;
 }
 
+static void update_branch_counts(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
+{
+	while (st) {
+		u32 br = --st->branches;
+
+		/* WARN_ON(br > 1) technically makes sense here,
+		 * but see comment in push_stack(), hence:
+		 */
+		WARN_ONCE((int)br < 0,
+			  "BUG update_branch_counts:branches_to_explore=%d\n",
+			  br);
+		if (br)
+			break;
+		st = st->parent;
+	}
+}
+
 static int pop_stack(struct bpf_verifier_env *env, int *prev_insn_idx,
 		     int *insn_idx)
 {
@@ -779,6 +826,18 @@ static struct bpf_verifier_state *push_stack(struct bpf_verifier_env *env,
 			env->stack_size);
 		goto err;
 	}
+	if (elem->st.parent) {
+		++elem->st.parent->branches;
+		/* WARN_ON(branches > 2) technically makes sense here,
+		 * but
+		 * 1. speculative states will bump 'branches' for non-branch
+		 * instructions
+		 * 2. is_state_visited() heuristics may decide not to create
+		 * a new state for a sequence of branches and all such current
+		 * and cloned states will be pointing to a single parent state
+		 * which might have large 'branches' count.
+		 */
+	}
 	return &elem->st;
 err:
 	free_verifier_state(env->cur_state, true);
@@ -926,6 +985,9 @@ static void __mark_reg_unbounded(struct bpf_reg_state *reg)
 	reg->smax_value = S64_MAX;
 	reg->umin_value = 0;
 	reg->umax_value = U64_MAX;
+
+	/* constant backtracking is enabled for root only for now */
+	reg->precise = capable(CAP_SYS_ADMIN) ? false : true;
 }
 
 /* Mark a register as having a completely unknown (scalar) value. */
@@ -1337,6 +1399,389 @@ static int check_reg_arg(struct bpf_verifier_env *env, u32 regno,
 	return 0;
 }
 
+/* for any branch, call, exit record the history of jmps in the given state */
+static int push_jmp_history(struct bpf_verifier_env *env,
+			    struct bpf_verifier_state *cur)
+{
+	u32 cnt = cur->jmp_history_cnt;
+	struct bpf_idx_pair *p;
+
+	cnt++;
+	p = krealloc(cur->jmp_history, cnt * sizeof(*p), GFP_USER);
+	if (!p)
+		return -ENOMEM;
+	p[cnt - 1].idx = env->insn_idx;
+	p[cnt - 1].prev_idx = env->prev_insn_idx;
+	cur->jmp_history = p;
+	cur->jmp_history_cnt = cnt;
+	return 0;
+}
+
+/* Backtrack one insn at a time. If idx is not at the top of recorded
+ * history then previous instruction came from straight line execution.
+ */
+static int get_prev_insn_idx(struct bpf_verifier_state *st, int i,
+			     u32 *history)
+{
+	u32 cnt = *history;
+
+	if (cnt && st->jmp_history[cnt - 1].idx == i) {
+		i = st->jmp_history[cnt - 1].prev_idx;
+		(*history)--;
+	} else {
+		i--;
+	}
+	return i;
+}
+
+/* For given verifier state backtrack_insn() is called from the last insn to
+ * the first insn. Its purpose is to compute a bitmask of registers and
+ * stack slots that needs precision in the parent verifier state.
+ */
+static int backtrack_insn(struct bpf_verifier_env *env, int idx,
+			  u32 *reg_mask, u64 *stack_mask)
+{
+	const struct bpf_insn_cbs cbs = {
+		.cb_print	= verbose,
+		.private_data	= env,
+	};
+	struct bpf_insn *insn = env->prog->insnsi + idx;
+	u8 class = BPF_CLASS(insn->code);
+	u8 opcode = BPF_OP(insn->code);
+	u8 mode = BPF_MODE(insn->code);
+	u32 dreg = 1u << insn->dst_reg;
+	u32 sreg = 1u << insn->src_reg;
+	u32 spi;
+
+	if (insn->code == 0)
+		return 0;
+	if (env->log.level & BPF_LOG_LEVEL) {
+		verbose(env, "regs=%x stack=%llx before ", *reg_mask, *stack_mask);
+		verbose(env, "%d: ", idx);
+		print_bpf_insn(&cbs, insn, env->allow_ptr_leaks);
+	}
+
+	if (class == BPF_ALU || class == BPF_ALU64) {
+		if (!(*reg_mask & dreg))
+			return 0;
+		if (opcode == BPF_MOV) {
+			if (BPF_SRC(insn->code) == BPF_X) {
+				/* dreg = sreg
+				 * dreg needs precision after this insn
+				 * sreg needs precision before this insn
+				 */
+				*reg_mask &= ~dreg;
+				*reg_mask |= sreg;
+			} else {
+				/* dreg = K
+				 * dreg needs precision after this insn.
+				 * Corresponding register is already marked
+				 * as precise=true in this verifier state.
+				 * No further markings in parent are necessary
+				 */
+				*reg_mask &= ~dreg;
+			}
+		} else {
+			if (BPF_SRC(insn->code) == BPF_X) {
+				/* dreg += sreg
+				 * both dreg and sreg need precision
+				 * before this insn
+				 */
+				*reg_mask |= sreg;
+			} /* else dreg += K
+			   * dreg still needs precision before this insn
+			   */
+		}
+	} else if (class == BPF_LDX) {
+		if (!(*reg_mask & dreg))
+			return 0;
+		*reg_mask &= ~dreg;
+
+		/* scalars can only be spilled into stack w/o losing precision.
+		 * Load from any other memory can be zero extended.
+		 * The desire to keep that precision is already indicated
+		 * by 'precise' mark in corresponding register of this state.
+		 * No further tracking necessary.
+		 */
+		if (insn->src_reg != BPF_REG_FP)
+			return 0;
+		if (BPF_SIZE(insn->code) != BPF_DW)
+			return 0;
+
+		/* dreg = *(u64 *)[fp - off] was a fill from the stack.
+		 * that [fp - off] slot contains scalar that needs to be
+		 * tracked with precision
+		 */
+		spi = (-insn->off - 1) / BPF_REG_SIZE;
+		if (spi >= 64) {
+			verbose(env, "BUG spi %d\n", spi);
+			WARN_ONCE(1, "verifier backtracking bug");
+			return -EFAULT;
+		}
+		*stack_mask |= 1ull << spi;
+	} else if (class == BPF_STX) {
+		if (*reg_mask & dreg)
+			/* stx shouldn't be using _scalar_ dst_reg
+			 * to access memory. It means backtracking
+			 * encountered a case of pointer subtraction.
+			 */
+			return -ENOTSUPP;
+		/* scalars can only be spilled into stack */
+		if (insn->dst_reg != BPF_REG_FP)
+			return 0;
+		if (BPF_SIZE(insn->code) != BPF_DW)
+			return 0;
+		spi = (-insn->off - 1) / BPF_REG_SIZE;
+		if (spi >= 64) {
+			verbose(env, "BUG spi %d\n", spi);
+			WARN_ONCE(1, "verifier backtracking bug");
+			return -EFAULT;
+		}
+		if (!(*stack_mask & (1ull << spi)))
+			return 0;
+		*stack_mask &= ~(1ull << spi);
+		*reg_mask |= sreg;
+	} else if (class == BPF_JMP || class == BPF_JMP32) {
+		if (opcode == BPF_CALL) {
+			if (insn->src_reg == BPF_PSEUDO_CALL)
+				return -ENOTSUPP;
+			/* regular helper call sets R0 */
+			*reg_mask &= ~1;
+			if (*reg_mask & 0x3f) {
+				/* if backtracing was looking for registers R1-R5
+				 * they should have been found already.
+				 */
+				verbose(env, "BUG regs %x\n", *reg_mask);
+				WARN_ONCE(1, "verifier backtracking bug");
+				return -EFAULT;
+			}
+		} else if (opcode == BPF_EXIT) {
+			return -ENOTSUPP;
+		}
+	} else if (class == BPF_LD) {
+		if (!(*reg_mask & dreg))
+			return 0;
+		*reg_mask &= ~dreg;
+		/* It's ld_imm64 or ld_abs or ld_ind.
+		 * For ld_imm64 no further tracking of precision
+		 * into parent is necessary
+		 */
+		if (mode == BPF_IND || mode == BPF_ABS)
+			/* to be analyzed */
+			return -ENOTSUPP;
+	} else if (class == BPF_ST) {
+		if (*reg_mask & dreg)
+			/* likely pointer subtraction */
+			return -ENOTSUPP;
+	}
+	return 0;
+}
+
+/* the scalar precision tracking algorithm:
+ * . at the start all registers have precise=false.
+ * . scalar ranges are tracked as normal through alu and jmp insns.
+ * . once precise value of the scalar register is used in:
+ *   .  ptr + scalar alu
+ *   . if (scalar cond K|scalar)
+ *   .  helper_call(.., scalar, ...) where ARG_CONST is expected
+ *   backtrack through the verifier states and mark all registers and
+ *   stack slots with spilled constants that these scalar regisers
+ *   should be precise.
+ * . during state pruning two registers (or spilled stack slots)
+ *   are equivalent if both are not precise.
+ *
+ * Note the verifier cannot simply walk register parentage chain,
+ * since many different registers and stack slots could have been
+ * used to compute single precise scalar.
+ *
+ * The approach of starting with precise=true for all registers and then
+ * backtrack to mark a register as not precise when the verifier detects
+ * that program doesn't care about specific value (e.g., when helper
+ * takes register as ARG_ANYTHING parameter) is not safe.
+ *
+ * It's ok to walk single parentage chain of the verifier states.
+ * It's possible that this backtracking will go all the way till 1st insn.
+ * All other branches will be explored for needing precision later.
+ *
+ * The backtracking needs to deal with cases like:
+ *   R8=map_value(id=0,off=0,ks=4,vs=1952,imm=0) R9_w=map_value(id=0,off=40,ks=4,vs=1952,imm=0)
+ * r9 -= r8
+ * r5 = r9
+ * if r5 > 0x79f goto pc+7
+ *    R5_w=inv(id=0,umax_value=1951,var_off=(0x0; 0x7ff))
+ * r5 += 1
+ * ...
+ * call bpf_perf_event_output#25
+ *   where .arg5_type = ARG_CONST_SIZE_OR_ZERO
+ *
+ * and this case:
+ * r6 = 1
+ * call foo // uses callee's r6 inside to compute r0
+ * r0 += r6
+ * if r0 == 0 goto
+ *
+ * to track above reg_mask/stack_mask needs to be independent for each frame.
+ *
+ * Also if parent's curframe > frame where backtracking started,
+ * the verifier need to mark registers in both frames, otherwise callees
+ * may incorrectly prune callers. This is similar to
+ * commit 7640ead93924 ("bpf: verifier: make sure callees don't prune with caller differences")
+ *
+ * For now backtracking falls back into conservative marking.
+ */
+static void mark_all_scalars_precise(struct bpf_verifier_env *env,
+				     struct bpf_verifier_state *st)
+{
+	struct bpf_func_state *func;
+	struct bpf_reg_state *reg;
+	int i, j;
+
+	/* big hammer: mark all scalars precise in this path.
+	 * pop_stack may still get !precise scalars.
+	 */
+	for (; st; st = st->parent)
+		for (i = 0; i <= st->curframe; i++) {
+			func = st->frame[i];
+			for (j = 0; j < BPF_REG_FP; j++) {
+				reg = &func->regs[j];
+				if (reg->type != SCALAR_VALUE)
+					continue;
+				reg->precise = true;
+			}
+			for (j = 0; j < func->allocated_stack / BPF_REG_SIZE; j++) {
+				if (func->stack[j].slot_type[0] != STACK_SPILL)
+					continue;
+				reg = &func->stack[j].spilled_ptr;
+				if (reg->type != SCALAR_VALUE)
+					continue;
+				reg->precise = true;
+			}
+		}
+}
+
+static int mark_chain_precision(struct bpf_verifier_env *env, int regno)
+{
+	struct bpf_verifier_state *st = env->cur_state;
+	int first_idx = st->first_insn_idx;
+	int last_idx = env->insn_idx;
+	struct bpf_func_state *func;
+	struct bpf_reg_state *reg;
+	u32 reg_mask = 1u << regno;
+	u64 stack_mask = 0;
+	bool skip_first = true;
+	int i, err;
+
+	if (!env->allow_ptr_leaks)
+		/* backtracking is root only for now */
+		return 0;
+
+	func = st->frame[st->curframe];
+	reg = &func->regs[regno];
+	if (reg->type != SCALAR_VALUE) {
+		WARN_ONCE(1, "backtracing misuse");
+		return -EFAULT;
+	}
+	if (reg->precise)
+		return 0;
+	func->regs[regno].precise = true;
+
+	for (;;) {
+		DECLARE_BITMAP(mask, 64);
+		bool new_marks = false;
+		u32 history = st->jmp_history_cnt;
+
+		if (env->log.level & BPF_LOG_LEVEL)
+			verbose(env, "last_idx %d first_idx %d\n", last_idx, first_idx);
+		for (i = last_idx;;) {
+			if (skip_first) {
+				err = 0;
+				skip_first = false;
+			} else {
+				err = backtrack_insn(env, i, &reg_mask, &stack_mask);
+			}
+			if (err == -ENOTSUPP) {
+				mark_all_scalars_precise(env, st);
+				return 0;
+			} else if (err) {
+				return err;
+			}
+			if (!reg_mask && !stack_mask)
+				/* Found assignment(s) into tracked register in this state.
+				 * Since this state is already marked, just return.
+				 * Nothing to be tracked further in the parent state.
+				 */
+				return 0;
+			if (i == first_idx)
+				break;
+			i = get_prev_insn_idx(st, i, &history);
+			if (i >= env->prog->len) {
+				/* This can happen if backtracking reached insn 0
+				 * and there are still reg_mask or stack_mask
+				 * to backtrack.
+				 * It means the backtracking missed the spot where
+				 * particular register was initialized with a constant.
+				 */
+				verbose(env, "BUG backtracking idx %d\n", i);
+				WARN_ONCE(1, "verifier backtracking bug");
+				return -EFAULT;
+			}
+		}
+		st = st->parent;
+		if (!st)
+			break;
+
+		func = st->frame[st->curframe];
+		bitmap_from_u64(mask, reg_mask);
+		for_each_set_bit(i, mask, 32) {
+			reg = &func->regs[i];
+			if (reg->type != SCALAR_VALUE)
+				continue;
+			if (!reg->precise)
+				new_marks = true;
+			reg->precise = true;
+		}
+
+		bitmap_from_u64(mask, stack_mask);
+		for_each_set_bit(i, mask, 64) {
+			if (i >= func->allocated_stack / BPF_REG_SIZE) {
+				/* This can happen if backtracking
+				 * is propagating stack precision where
+				 * caller has larger stack frame
+				 * than callee, but backtrack_insn() should
+				 * have returned -ENOTSUPP.
+				 */
+				verbose(env, "BUG spi %d stack_size %d\n",
+					i, func->allocated_stack);
+				WARN_ONCE(1, "verifier backtracking bug");
+				return -EFAULT;
+			}
+
+			if (func->stack[i].slot_type[0] != STACK_SPILL)
+				continue;
+			reg = &func->stack[i].spilled_ptr;
+			if (reg->type != SCALAR_VALUE)
+				continue;
+			if (!reg->precise)
+				new_marks = true;
+			reg->precise = true;
+		}
+		if (env->log.level & BPF_LOG_LEVEL) {
+			print_verifier_state(env, func);
+			verbose(env, "parent %s regs=%x stack=%llx marks\n",
+				new_marks ? "didn't have" : "already had",
+				reg_mask, stack_mask);
+		}
+
+		if (!new_marks)
+			break;
+
+		last_idx = st->last_insn_idx;
+		first_idx = st->first_insn_idx;
+	}
+	return 0;
+}
+
+
 static bool is_spillable_regtype(enum bpf_reg_type type)
 {
 	switch (type) {
@@ -1355,6 +1800,7 @@ static bool is_spillable_regtype(enum bpf_reg_type type)
 	case PTR_TO_SOCK_COMMON_OR_NULL:
 	case PTR_TO_TCP_SOCK:
 	case PTR_TO_TCP_SOCK_OR_NULL:
+	case PTR_TO_XDP_SOCK:
 		return true;
 	default:
 		return false;
@@ -1367,6 +1813,23 @@ static bool register_is_null(struct bpf_reg_state *reg)
 	return reg->type == SCALAR_VALUE && tnum_equals_const(reg->var_off, 0);
 }
 
+static bool register_is_const(struct bpf_reg_state *reg)
+{
+	return reg->type == SCALAR_VALUE && tnum_is_const(reg->var_off);
+}
+
+static void save_register_state(struct bpf_func_state *state,
+				int spi, struct bpf_reg_state *reg)
+{
+	int i;
+
+	state->stack[spi].spilled_ptr = *reg;
+	state->stack[spi].spilled_ptr.live |= REG_LIVE_WRITTEN;
+
+	for (i = 0; i < BPF_REG_SIZE; i++)
+		state->stack[spi].slot_type[i] = STACK_SPILL;
+}
+
 /* check_stack_read/write functions track spill/fill of registers,
  * stack boundary and alignment are checked in check_mem_access()
  */
@@ -1376,7 +1839,8 @@ static int check_stack_write(struct bpf_verifier_env *env,
 {
 	struct bpf_func_state *cur; /* state of the current function */
 	int i, slot = -off - 1, spi = slot / BPF_REG_SIZE, err;
-	enum bpf_reg_type type;
+	u32 dst_reg = env->prog->insnsi[insn_idx].dst_reg;
+	struct bpf_reg_state *reg = NULL;
 
 	err = realloc_func_state(state, round_up(slot + 1, BPF_REG_SIZE),
 				 state->acquired_refs, true);
@@ -1393,27 +1857,48 @@ static int check_stack_write(struct bpf_verifier_env *env,
 	}
 
 	cur = env->cur_state->frame[env->cur_state->curframe];
-	if (value_regno >= 0 &&
-	    is_spillable_regtype((type = cur->regs[value_regno].type))) {
-
+	if (value_regno >= 0)
+		reg = &cur->regs[value_regno];
+
+	if (reg && size == BPF_REG_SIZE && register_is_const(reg) &&
+	    !register_is_null(reg) && env->allow_ptr_leaks) {
+		if (dst_reg != BPF_REG_FP) {
+			/* The backtracking logic can only recognize explicit
+			 * stack slot address like [fp - 8]. Other spill of
+			 * scalar via different register has to be conervative.
+			 * Backtrack from here and mark all registers as precise
+			 * that contributed into 'reg' being a constant.
+			 */
+			err = mark_chain_precision(env, value_regno);
+			if (err)
+				return err;
+		}
+		save_register_state(state, spi, reg);
+	} else if (reg && is_spillable_regtype(reg->type)) {
 		/* register containing pointer is being spilled into stack */
 		if (size != BPF_REG_SIZE) {
+			verbose_linfo(env, insn_idx, "; ");
 			verbose(env, "invalid size of register spill\n");
 			return -EACCES;
 		}
 
-		if (state != cur && type == PTR_TO_STACK) {
+		if (state != cur && reg->type == PTR_TO_STACK) {
 			verbose(env, "cannot spill pointers to stack into stack frame of the caller\n");
 			return -EINVAL;
 		}
 
-		/* save register state */
-		state->stack[spi].spilled_ptr = cur->regs[value_regno];
-		state->stack[spi].spilled_ptr.live |= REG_LIVE_WRITTEN;
+		if (!env->allow_ptr_leaks) {
+			bool sanitize = false;
 
-		for (i = 0; i < BPF_REG_SIZE; i++) {
-			if (state->stack[spi].slot_type[i] == STACK_MISC &&
-			    !env->allow_ptr_leaks) {
+			if (state->stack[spi].slot_type[0] == STACK_SPILL &&
+			    register_is_const(&state->stack[spi].spilled_ptr))
+				sanitize = true;
+			for (i = 0; i < BPF_REG_SIZE; i++)
+				if (state->stack[spi].slot_type[i] == STACK_MISC) {
+					sanitize = true;
+					break;
+				}
+			if (sanitize) {
 				int *poff = &env->insn_aux_data[insn_idx].sanitize_stack_off;
 				int soff = (-spi - 1) * BPF_REG_SIZE;
 
@@ -1436,8 +1921,8 @@ static int check_stack_write(struct bpf_verifier_env *env,
 				}
 				*poff = soff;
 			}
-			state->stack[spi].slot_type[i] = STACK_SPILL;
 		}
+		save_register_state(state, spi, reg);
 	} else {
 		u8 type = STACK_MISC;
 
@@ -1460,9 +1945,13 @@ static int check_stack_write(struct bpf_verifier_env *env,
 			state->stack[spi].spilled_ptr.live |= REG_LIVE_WRITTEN;
 
 		/* when we zero initialize stack slots mark them as such */
-		if (value_regno >= 0 &&
-		    register_is_null(&cur->regs[value_regno]))
+		if (reg && register_is_null(reg)) {
+			/* backtracking doesn't work for STACK_ZERO yet. */
+			err = mark_chain_precision(env, value_regno);
+			if (err)
+				return err;
 			type = STACK_ZERO;
+		}
 
 		/* Mark slots affected by this stack write. */
 		for (i = 0; i < size; i++)
@@ -1479,6 +1968,7 @@ static int check_stack_read(struct bpf_verifier_env *env,
 	struct bpf_verifier_state *vstate = env->cur_state;
 	struct bpf_func_state *state = vstate->frame[vstate->curframe];
 	int i, slot = -off - 1, spi = slot / BPF_REG_SIZE;
+	struct bpf_reg_state *reg;
 	u8 *stype;
 
 	if (reg_state->allocated_stack <= slot) {
@@ -1487,11 +1977,21 @@ static int check_stack_read(struct bpf_verifier_env *env,
 		return -EACCES;
 	}
 	stype = reg_state->stack[spi].slot_type;
+	reg = &reg_state->stack[spi].spilled_ptr;
 
 	if (stype[0] == STACK_SPILL) {
 		if (size != BPF_REG_SIZE) {
-			verbose(env, "invalid size of register spill\n");
-			return -EACCES;
+			if (reg->type != SCALAR_VALUE) {
+				verbose_linfo(env, env->insn_idx, "; ");
+				verbose(env, "invalid size of register fill\n");
+				return -EACCES;
+			}
+			if (value_regno >= 0) {
+				mark_reg_unknown(env, state->regs, value_regno);
+				state->regs[value_regno].live |= REG_LIVE_WRITTEN;
+			}
+			mark_reg_read(env, reg, reg->parent, REG_LIVE_READ64);
+			return 0;
 		}
 		for (i = 1; i < BPF_REG_SIZE; i++) {
 			if (stype[(slot - i) % BPF_REG_SIZE] != STACK_SPILL) {
@@ -1502,17 +2002,14 @@ static int check_stack_read(struct bpf_verifier_env *env,
 
 		if (value_regno >= 0) {
 			/* restore register state from stack */
-			state->regs[value_regno] = reg_state->stack[spi].spilled_ptr;
+			state->regs[value_regno] = *reg;
 			/* mark reg as written since spilled pointer state likely
 			 * has its liveness marks cleared by is_state_visited()
 			 * which resets stack/reg liveness for state transitions
 			 */
 			state->regs[value_regno].live |= REG_LIVE_WRITTEN;
 		}
-		mark_reg_read(env, &reg_state->stack[spi].spilled_ptr,
-			      reg_state->stack[spi].spilled_ptr.parent,
-			      REG_LIVE_READ64);
-		return 0;
+		mark_reg_read(env, reg, reg->parent, REG_LIVE_READ64);
 	} else {
 		int zeros = 0;
 
@@ -1527,23 +2024,32 @@ static int check_stack_read(struct bpf_verifier_env *env,
 				off, i, size);
 			return -EACCES;
 		}
-		mark_reg_read(env, &reg_state->stack[spi].spilled_ptr,
-			      reg_state->stack[spi].spilled_ptr.parent,
-			      REG_LIVE_READ64);
+		mark_reg_read(env, reg, reg->parent, REG_LIVE_READ64);
 		if (value_regno >= 0) {
 			if (zeros == size) {
 				/* any size read into register is zero extended,
 				 * so the whole register == const_zero
 				 */
 				__mark_reg_const_zero(&state->regs[value_regno]);
+				/* backtracking doesn't support STACK_ZERO yet,
+				 * so mark it precise here, so that later
+				 * backtracking can stop here.
+				 * Backtracking may not need this if this register
+				 * doesn't participate in pointer adjustment.
+				 * Forward propagation of precise flag is not
+				 * necessary either. This mark is only to stop
+				 * backtracking. Any register that contributed
+				 * to const 0 was marked precise before spill.
+				 */
+				state->regs[value_regno].precise = true;
 			} else {
 				/* have read misc data from the stack */
 				mark_reg_unknown(env, state->regs, value_regno);
 			}
 			state->regs[value_regno].live |= REG_LIVE_WRITTEN;
 		}
-		return 0;
 	}
+	return 0;
 }
 
 static int check_stack_access(struct bpf_verifier_env *env,
@@ -1835,6 +2341,9 @@ static int check_sock_access(struct bpf_verifier_env *env, int insn_idx,
 	case PTR_TO_TCP_SOCK:
 		valid = bpf_tcp_sock_is_valid_access(off, size, t, &info);
 		break;
+	case PTR_TO_XDP_SOCK:
+		valid = bpf_xdp_sock_is_valid_access(off, size, t, &info);
+		break;
 	default:
 		valid = false;
 	}
@@ -1999,6 +2508,9 @@ static int check_ptr_alignment(struct bpf_verifier_env *env,
 	case PTR_TO_TCP_SOCK:
 		pointer_desc = "tcp_sock ";
 		break;
+	case PTR_TO_XDP_SOCK:
+		pointer_desc = "xdp_sock ";
+		break;
 	default:
 		break;
 	}
@@ -2398,7 +2910,7 @@ static int check_stack_boundary(struct bpf_verifier_env *env, int regno,
 {
 	struct bpf_reg_state *reg = reg_state(env, regno);
 	struct bpf_func_state *state = func(env, reg);
-	int err, min_off, max_off, i, slot, spi;
+	int err, min_off, max_off, i, j, slot, spi;
 
 	if (reg->type != PTR_TO_STACK) {
 		/* Allow zero-byte read from NULL, regardless of pointer type */
@@ -2486,6 +2998,14 @@ static int check_stack_boundary(struct bpf_verifier_env *env, int regno,
 			*stype = STACK_MISC;
 			goto mark;
 		}
+		if (state->stack[spi].slot_type[0] == STACK_SPILL &&
+		    state->stack[spi].spilled_ptr.type == SCALAR_VALUE) {
+			__mark_reg_unknown(&state->stack[spi].spilled_ptr);
+			for (j = 0; j < BPF_REG_SIZE; j++)
+				state->stack[spi].slot_type[j] = STACK_MISC;
+			goto mark;
+		}
+
 err:
 		if (tnum_is_const(reg->var_off)) {
 			verbose(env, "invalid indirect read from stack off %d+%d size %d\n",
@@ -2837,6 +3357,8 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
 		err = check_helper_mem_access(env, regno - 1,
 					      reg->umax_value,
 					      zero_size_allowed, meta);
+		if (!err)
+			err = mark_chain_precision(env, regno);
 	} else if (arg_type_is_int_ptr(arg_type)) {
 		int size = int_ptr_type_to_size(arg_type);
 
@@ -2897,10 +3419,14 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
 	 * appear.
 	 */
 	case BPF_MAP_TYPE_CPUMAP:
-	case BPF_MAP_TYPE_XSKMAP:
 		if (func_id != BPF_FUNC_redirect_map)
 			goto error;
 		break;
+	case BPF_MAP_TYPE_XSKMAP:
+		if (func_id != BPF_FUNC_redirect_map &&
+		    func_id != BPF_FUNC_map_lookup_elem)
+			goto error;
+		break;
 	case BPF_MAP_TYPE_ARRAY_OF_MAPS:
 	case BPF_MAP_TYPE_HASH_OF_MAPS:
 		if (func_id != BPF_FUNC_map_lookup_elem)
@@ -3791,6 +4317,7 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
 	case PTR_TO_SOCK_COMMON_OR_NULL:
 	case PTR_TO_TCP_SOCK:
 	case PTR_TO_TCP_SOCK_OR_NULL:
+	case PTR_TO_XDP_SOCK:
 		verbose(env, "R%d pointer arithmetic on %s prohibited\n",
 			dst, reg_type_str[ptr_reg->type]);
 		return -EACCES;
@@ -4268,6 +4795,7 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 	struct bpf_reg_state *regs = state->regs, *dst_reg, *src_reg;
 	struct bpf_reg_state *ptr_reg = NULL, off_reg = {0};
 	u8 opcode = BPF_OP(insn->code);
+	int err;
 
 	dst_reg = &regs[insn->dst_reg];
 	src_reg = NULL;
@@ -4294,11 +4822,17 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 				 * This is legal, but we have to reverse our
 				 * src/dest handling in computing the range
 				 */
+				err = mark_chain_precision(env, insn->dst_reg);
+				if (err)
+					return err;
 				return adjust_ptr_min_max_vals(env, insn,
 							       src_reg, dst_reg);
 			}
 		} else if (ptr_reg) {
 			/* pointer += scalar */
+			err = mark_chain_precision(env, insn->src_reg);
+			if (err)
+				return err;
 			return adjust_ptr_min_max_vals(env, insn,
 						       dst_reg, src_reg);
 		}
@@ -5030,6 +5564,9 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
 			if (reg->map_ptr->inner_map_meta) {
 				reg->type = CONST_PTR_TO_MAP;
 				reg->map_ptr = reg->map_ptr->inner_map_meta;
+			} else if (reg->map_ptr->map_type ==
+				   BPF_MAP_TYPE_XSKMAP) {
+				reg->type = PTR_TO_XDP_SOCK;
 			} else {
 				reg->type = PTR_TO_MAP_VALUE;
 			}
@@ -5201,9 +5738,10 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	struct bpf_verifier_state *this_branch = env->cur_state;
 	struct bpf_verifier_state *other_branch;
 	struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs;
-	struct bpf_reg_state *dst_reg, *other_branch_regs;
+	struct bpf_reg_state *dst_reg, *other_branch_regs, *src_reg = NULL;
 	u8 opcode = BPF_OP(insn->code);
 	bool is_jmp32;
+	int pred = -1;
 	int err;
 
 	/* Only conditional jumps are expected to reach here. */
@@ -5228,6 +5766,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 				insn->src_reg);
 			return -EACCES;
 		}
+		src_reg = &regs[insn->src_reg];
 	} else {
 		if (insn->src_reg != BPF_REG_0) {
 			verbose(env, "BPF_JMP/JMP32 uses reserved fields\n");
@@ -5243,20 +5782,29 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	dst_reg = &regs[insn->dst_reg];
 	is_jmp32 = BPF_CLASS(insn->code) == BPF_JMP32;
 
-	if (BPF_SRC(insn->code) == BPF_K) {
-		int pred = is_branch_taken(dst_reg, insn->imm, opcode,
-					   is_jmp32);
-
-		if (pred == 1) {
-			 /* only follow the goto, ignore fall-through */
-			*insn_idx += insn->off;
-			return 0;
-		} else if (pred == 0) {
-			/* only follow fall-through branch, since
-			 * that's where the program will go
-			 */
-			return 0;
-		}
+	if (BPF_SRC(insn->code) == BPF_K)
+		pred = is_branch_taken(dst_reg, insn->imm,
+				       opcode, is_jmp32);
+	else if (src_reg->type == SCALAR_VALUE &&
+		 tnum_is_const(src_reg->var_off))
+		pred = is_branch_taken(dst_reg, src_reg->var_off.value,
+				       opcode, is_jmp32);
+	if (pred >= 0) {
+		err = mark_chain_precision(env, insn->dst_reg);
+		if (BPF_SRC(insn->code) == BPF_X && !err)
+			err = mark_chain_precision(env, insn->src_reg);
+		if (err)
+			return err;
+	}
+	if (pred == 1) {
+		/* only follow the goto, ignore fall-through */
+		*insn_idx += insn->off;
+		return 0;
+	} else if (pred == 0) {
+		/* only follow fall-through branch, since
+		 * that's where the program will go
+		 */
+		return 0;
 	}
 
 	other_branch = push_stack(env, *insn_idx + insn->off + 1, *insn_idx,
@@ -5616,7 +6164,8 @@ static void init_explored_state(struct bpf_verifier_env *env, int idx)
  * w - next instruction
  * e - edge
  */
-static int push_insn(int t, int w, int e, struct bpf_verifier_env *env)
+static int push_insn(int t, int w, int e, struct bpf_verifier_env *env,
+		     bool loop_ok)
 {
 	int *insn_stack = env->cfg.insn_stack;
 	int *insn_state = env->cfg.insn_state;
@@ -5646,6 +6195,8 @@ static int push_insn(int t, int w, int e, struct bpf_verifier_env *env)
 		insn_stack[env->cfg.cur_stack++] = w;
 		return 1;
 	} else if ((insn_state[w] & 0xF0) == DISCOVERED) {
+		if (loop_ok && env->allow_ptr_leaks)
+			return 0;
 		verbose_linfo(env, t, "%d: ", t);
 		verbose_linfo(env, w, "%d: ", w);
 		verbose(env, "back-edge from insn %d to %d\n", t, w);
@@ -5697,7 +6248,7 @@ peek_stack:
 		if (opcode == BPF_EXIT) {
 			goto mark_explored;
 		} else if (opcode == BPF_CALL) {
-			ret = push_insn(t, t + 1, FALLTHROUGH, env);
+			ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
 			if (ret == 1)
 				goto peek_stack;
 			else if (ret < 0)
@@ -5706,7 +6257,8 @@ peek_stack:
 				init_explored_state(env, t + 1);
 			if (insns[t].src_reg == BPF_PSEUDO_CALL) {
 				init_explored_state(env, t);
-				ret = push_insn(t, t + insns[t].imm + 1, BRANCH, env);
+				ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
+						env, false);
 				if (ret == 1)
 					goto peek_stack;
 				else if (ret < 0)
@@ -5719,11 +6271,16 @@ peek_stack:
 			}
 			/* unconditional jump with single edge */
 			ret = push_insn(t, t + insns[t].off + 1,
-					FALLTHROUGH, env);
+					FALLTHROUGH, env, true);
 			if (ret == 1)
 				goto peek_stack;
 			else if (ret < 0)
 				goto err_free;
+			/* unconditional jmp is not a good pruning point,
+			 * but it's marked, since backtracking needs
+			 * to record jmp history in is_state_visited().
+			 */
+			init_explored_state(env, t + insns[t].off + 1);
 			/* tell verifier to check for equivalent states
 			 * after every call and jump
 			 */
@@ -5732,13 +6289,13 @@ peek_stack:
 		} else {
 			/* conditional jump with two edges */
 			init_explored_state(env, t);
-			ret = push_insn(t, t + 1, FALLTHROUGH, env);
+			ret = push_insn(t, t + 1, FALLTHROUGH, env, true);
 			if (ret == 1)
 				goto peek_stack;
 			else if (ret < 0)
 				goto err_free;
 
-			ret = push_insn(t, t + insns[t].off + 1, BRANCH, env);
+			ret = push_insn(t, t + insns[t].off + 1, BRANCH, env, true);
 			if (ret == 1)
 				goto peek_stack;
 			else if (ret < 0)
@@ -5748,7 +6305,7 @@ peek_stack:
 		/* all other non-branch instructions with single
 		 * fall-through edge
 		 */
-		ret = push_insn(t, t + 1, FALLTHROUGH, env);
+		ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
 		if (ret == 1)
 			goto peek_stack;
 		else if (ret < 0)
@@ -6181,6 +6738,8 @@ static void clean_live_states(struct bpf_verifier_env *env, int insn,
 
 	sl = *explored_state(env, insn);
 	while (sl) {
+		if (sl->state.branches)
+			goto next;
 		if (sl->state.insn_idx != insn ||
 		    sl->state.curframe != cur->curframe)
 			goto next;
@@ -6222,6 +6781,8 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
 	switch (rold->type) {
 	case SCALAR_VALUE:
 		if (rcur->type == SCALAR_VALUE) {
+			if (!rold->precise && !rcur->precise)
+				return true;
 			/* new val must satisfy old val knowledge */
 			return range_within(rold, rcur) &&
 			       tnum_in(rold->var_off, rcur->var_off);
@@ -6294,6 +6855,7 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
 	case PTR_TO_SOCK_COMMON_OR_NULL:
 	case PTR_TO_TCP_SOCK:
 	case PTR_TO_TCP_SOCK_OR_NULL:
+	case PTR_TO_XDP_SOCK:
 		/* Only valid matches are exact, which memcmp() above
 		 * would have accepted
 		 */
@@ -6544,19 +7106,52 @@ static int propagate_liveness(struct bpf_verifier_env *env,
 	return 0;
 }
 
+static bool states_maybe_looping(struct bpf_verifier_state *old,
+				 struct bpf_verifier_state *cur)
+{
+	struct bpf_func_state *fold, *fcur;
+	int i, fr = cur->curframe;
+
+	if (old->curframe != fr)
+		return false;
+
+	fold = old->frame[fr];
+	fcur = cur->frame[fr];
+	for (i = 0; i < MAX_BPF_REG; i++)
+		if (memcmp(&fold->regs[i], &fcur->regs[i],
+			   offsetof(struct bpf_reg_state, parent)))
+			return false;
+	return true;
+}
+
+
 static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 {
 	struct bpf_verifier_state_list *new_sl;
 	struct bpf_verifier_state_list *sl, **pprev;
 	struct bpf_verifier_state *cur = env->cur_state, *new;
 	int i, j, err, states_cnt = 0;
+	bool add_new_state = false;
 
+	cur->last_insn_idx = env->prev_insn_idx;
 	if (!env->insn_aux_data[insn_idx].prune_point)
 		/* this 'insn_idx' instruction wasn't marked, so we will not
 		 * be doing state search here
 		 */
 		return 0;
 
+	/* bpf progs typically have pruning point every 4 instructions
+	 * http://vger.kernel.org/bpfconf2019.html#session-1
+	 * Do not add new state for future pruning if the verifier hasn't seen
+	 * at least 2 jumps and at least 8 instructions.
+	 * This heuristics helps decrease 'total_states' and 'peak_states' metric.
+	 * In tests that amounts to up to 50% reduction into total verifier
+	 * memory consumption and 20% verifier time speedup.
+	 */
+	if (env->jmps_processed - env->prev_jmps_processed >= 2 &&
+	    env->insn_processed - env->prev_insn_processed >= 8)
+		add_new_state = true;
+
 	pprev = explored_state(env, insn_idx);
 	sl = *pprev;
 
@@ -6566,6 +7161,30 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 		states_cnt++;
 		if (sl->state.insn_idx != insn_idx)
 			goto next;
+		if (sl->state.branches) {
+			if (states_maybe_looping(&sl->state, cur) &&
+			    states_equal(env, &sl->state, cur)) {
+				verbose_linfo(env, insn_idx, "; ");
+				verbose(env, "infinite loop detected at insn %d\n", insn_idx);
+				return -EINVAL;
+			}
+			/* if the verifier is processing a loop, avoid adding new state
+			 * too often, since different loop iterations have distinct
+			 * states and may not help future pruning.
+			 * This threshold shouldn't be too low to make sure that
+			 * a loop with large bound will be rejected quickly.
+			 * The most abusive loop will be:
+			 * r1 += 1
+			 * if r1 < 1000000 goto pc-2
+			 * 1M insn_procssed limit / 100 == 10k peak states.
+			 * This threshold shouldn't be too high either, since states
+			 * at the end of the loop are likely to be useful in pruning.
+			 */
+			if (env->jmps_processed - env->prev_jmps_processed < 20 &&
+			    env->insn_processed - env->prev_insn_processed < 100)
+				add_new_state = false;
+			goto miss;
+		}
 		if (states_equal(env, &sl->state, cur)) {
 			sl->hit_cnt++;
 			/* reached equivalent register/stack state,
@@ -6583,7 +7202,15 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 				return err;
 			return 1;
 		}
-		sl->miss_cnt++;
+miss:
+		/* when new state is not going to be added do not increase miss count.
+		 * Otherwise several loop iterations will remove the state
+		 * recorded earlier. The goal of these heuristics is to have
+		 * states from some iterations of the loop (some in the beginning
+		 * and some at the end) to help pruning.
+		 */
+		if (add_new_state)
+			sl->miss_cnt++;
 		/* heuristic to determine whether this state is beneficial
 		 * to keep checking from state equivalence point of view.
 		 * Higher numbers increase max_states_per_insn and verification time,
@@ -6595,6 +7222,11 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 			 */
 			*pprev = sl->next;
 			if (sl->state.frame[0]->regs[0].live & REG_LIVE_DONE) {
+				u32 br = sl->state.branches;
+
+				WARN_ONCE(br,
+					  "BUG live_done but branches_to_explore %d\n",
+					  br);
 				free_verifier_state(&sl->state, false);
 				kfree(sl);
 				env->peak_states--;
@@ -6618,20 +7250,27 @@ next:
 		env->max_states_per_insn = states_cnt;
 
 	if (!env->allow_ptr_leaks && states_cnt > BPF_COMPLEXITY_LIMIT_STATES)
-		return 0;
+		return push_jmp_history(env, cur);
+
+	if (!add_new_state)
+		return push_jmp_history(env, cur);
 
-	/* there were no equivalent states, remember current one.
-	 * technically the current state is not proven to be safe yet,
+	/* There were no equivalent states, remember the current one.
+	 * Technically the current state is not proven to be safe yet,
 	 * but it will either reach outer most bpf_exit (which means it's safe)
-	 * or it will be rejected. Since there are no loops, we won't be
+	 * or it will be rejected. When there are no loops the verifier won't be
 	 * seeing this tuple (frame[0].callsite, frame[1].callsite, .. insn_idx)
-	 * again on the way to bpf_exit
+	 * again on the way to bpf_exit.
+	 * When looping the sl->state.branches will be > 0 and this state
+	 * will not be considered for equivalence until branches == 0.
 	 */
 	new_sl = kzalloc(sizeof(struct bpf_verifier_state_list), GFP_KERNEL);
 	if (!new_sl)
 		return -ENOMEM;
 	env->total_states++;
 	env->peak_states++;
+	env->prev_jmps_processed = env->jmps_processed;
+	env->prev_insn_processed = env->insn_processed;
 
 	/* add new state to the head of linked list */
 	new = &new_sl->state;
@@ -6642,6 +7281,12 @@ next:
 		return err;
 	}
 	new->insn_idx = insn_idx;
+	WARN_ONCE(new->branches != 1,
+		  "BUG is_state_visited:branches_to_explore=%d insn %d\n", new->branches, insn_idx);
+
+	cur->parent = new;
+	cur->first_insn_idx = insn_idx;
+	clear_jmp_history(cur);
 	new_sl->next = *explored_state(env, insn_idx);
 	*explored_state(env, insn_idx) = new_sl;
 	/* connect new state to parentage chain. Current frame needs all
@@ -6651,17 +7296,18 @@ next:
 	 * the state of the call instruction (with WRITTEN set), and r0 comes
 	 * from callee with its full parentage chain, anyway.
 	 */
-	for (j = 0; j <= cur->curframe; j++)
-		for (i = j < cur->curframe ? BPF_REG_6 : 0; i < BPF_REG_FP; i++)
-			cur->frame[j]->regs[i].parent = &new->frame[j]->regs[i];
 	/* clear write marks in current state: the writes we did are not writes
 	 * our child did, so they don't screen off its reads from us.
 	 * (There are no read marks in current state, because reads always mark
 	 * their parent and current state never has children yet.  Only
 	 * explored_states can get read marks.)
 	 */
-	for (i = 0; i < BPF_REG_FP; i++)
-		cur->frame[cur->curframe]->regs[i].live = REG_LIVE_NONE;
+	for (j = 0; j <= cur->curframe; j++) {
+		for (i = j < cur->curframe ? BPF_REG_6 : 0; i < BPF_REG_FP; i++)
+			cur->frame[j]->regs[i].parent = &new->frame[j]->regs[i];
+		for (i = 0; i < BPF_REG_FP; i++)
+			cur->frame[j]->regs[i].live = REG_LIVE_NONE;
+	}
 
 	/* all stack frames are accessible from callee, clear them all */
 	for (j = 0; j <= cur->curframe; j++) {
@@ -6688,6 +7334,7 @@ static bool reg_type_mismatch_ok(enum bpf_reg_type type)
 	case PTR_TO_SOCK_COMMON_OR_NULL:
 	case PTR_TO_TCP_SOCK:
 	case PTR_TO_TCP_SOCK_OR_NULL:
+	case PTR_TO_XDP_SOCK:
 		return false;
 	default:
 		return true;
@@ -6719,6 +7366,7 @@ static int do_check(struct bpf_verifier_env *env)
 	struct bpf_reg_state *regs;
 	int insn_cnt = env->prog->len;
 	bool do_print_state = false;
+	int prev_insn_idx = -1;
 
 	env->prev_linfo = NULL;
 
@@ -6727,6 +7375,7 @@ static int do_check(struct bpf_verifier_env *env)
 		return -ENOMEM;
 	state->curframe = 0;
 	state->speculative = false;
+	state->branches = 1;
 	state->frame[0] = kzalloc(sizeof(struct bpf_func_state), GFP_KERNEL);
 	if (!state->frame[0]) {
 		kfree(state);
@@ -6743,6 +7392,7 @@ static int do_check(struct bpf_verifier_env *env)
 		u8 class;
 		int err;
 
+		env->prev_insn_idx = prev_insn_idx;
 		if (env->insn_idx >= insn_cnt) {
 			verbose(env, "invalid insn idx %d insn_cnt %d\n",
 				env->insn_idx, insn_cnt);
@@ -6815,6 +7465,7 @@ static int do_check(struct bpf_verifier_env *env)
 
 		regs = cur_regs(env);
 		env->insn_aux_data[env->insn_idx].seen = true;
+		prev_insn_idx = env->insn_idx;
 
 		if (class == BPF_ALU || class == BPF_ALU64) {
 			err = check_alu_op(env, insn);
@@ -6933,6 +7584,7 @@ static int do_check(struct bpf_verifier_env *env)
 		} else if (class == BPF_JMP || class == BPF_JMP32) {
 			u8 opcode = BPF_OP(insn->code);
 
+			env->jmps_processed++;
 			if (opcode == BPF_CALL) {
 				if (BPF_SRC(insn->code) != BPF_K ||
 				    insn->off != 0 ||
@@ -6987,7 +7639,6 @@ static int do_check(struct bpf_verifier_env *env)
 
 				if (state->curframe) {
 					/* exit from nested function */
-					env->prev_insn_idx = env->insn_idx;
 					err = prepare_func_exit(env, &env->insn_idx);
 					if (err)
 						return err;
@@ -7018,7 +7669,8 @@ static int do_check(struct bpf_verifier_env *env)
 				if (err)
 					return err;
 process_bpf_exit:
-				err = pop_stack(env, &env->prev_insn_idx,
+				update_branch_counts(env, env->cur_state);
+				err = pop_stack(env, &prev_insn_idx,
 						&env->insn_idx);
 				if (err < 0) {
 					if (err != -ENOENT)
@@ -7821,6 +8473,9 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
 		case PTR_TO_TCP_SOCK:
 			convert_ctx_access = bpf_tcp_sock_convert_ctx_access;
 			break;
+		case PTR_TO_XDP_SOCK:
+			convert_ctx_access = bpf_xdp_sock_convert_ctx_access;
+			break;
 		default:
 			continue;
 		}
diff --git a/kernel/bpf/xskmap.c b/kernel/bpf/xskmap.c
index 22066c28ba61..ef7338cebd18 100644
--- a/kernel/bpf/xskmap.c
+++ b/kernel/bpf/xskmap.c
@@ -17,8 +17,8 @@ struct xsk_map {
 
 static struct bpf_map *xsk_map_alloc(union bpf_attr *attr)
 {
-	int cpu, err = -EINVAL;
 	struct xsk_map *m;
+	int cpu, err;
 	u64 cost;
 
 	if (!capable(CAP_NET_ADMIN))
@@ -152,6 +152,12 @@ void __xsk_map_flush(struct bpf_map *map)
 
 static void *xsk_map_lookup_elem(struct bpf_map *map, void *key)
 {
+	WARN_ON_ONCE(!rcu_read_lock_held());
+	return __xsk_map_lookup_elem(map, *(u32 *)key);
+}
+
+static void *xsk_map_lookup_elem_sys_only(struct bpf_map *map, void *key)
+{
 	return ERR_PTR(-EOPNOTSUPP);
 }
 
@@ -218,6 +224,7 @@ const struct bpf_map_ops xsk_map_ops = {
 	.map_free = xsk_map_free,
 	.map_get_next_key = xsk_map_get_next_key,
 	.map_lookup_elem = xsk_map_lookup_elem,
+	.map_lookup_elem_sys_only = xsk_map_lookup_elem_sys_only,
 	.map_update_elem = xsk_map_update_elem,
 	.map_delete_elem = xsk_map_delete_elem,
 	.map_check_btf = map_check_no_btf,