summary refs log tree commit diff
path: root/kernel
diff options
context:
space:
mode:
Diffstat (limited to 'kernel')
-rw-r--r--kernel/bpf/sockmap.c45
1 files changed, 22 insertions, 23 deletions
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 26d8a3053407..ce63e5801746 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -236,7 +236,7 @@ static int bpf_tcp_init(struct sock *sk)
 }
 
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
 
 static void bpf_tcp_release(struct sock *sk)
 {
@@ -248,7 +248,7 @@ static void bpf_tcp_release(struct sock *sk)
 		goto out;
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 		psock->cork = NULL;
 	}
@@ -330,14 +330,14 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 	close_fun = psock->save_close;
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 		psock->cork = NULL;
 	}
 
 	list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
 		list_del(&md->list);
-		free_start_sg(psock->sock, md);
+		free_start_sg(psock->sock, md, true);
 		kfree(md);
 	}
 
@@ -570,14 +570,16 @@ static void free_bytes_sg(struct sock *sk, int bytes,
 	md->sg_start = i;
 }
 
-static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
+static int free_sg(struct sock *sk, int start,
+		   struct sk_msg_buff *md, bool charge)
 {
 	struct scatterlist *sg = md->sg_data;
 	int i = start, free = 0;
 
 	while (sg[i].length) {
 		free += sg[i].length;
-		sk_mem_uncharge(sk, sg[i].length);
+		if (charge)
+			sk_mem_uncharge(sk, sg[i].length);
 		if (!md->skb)
 			put_page(sg_page(&sg[i]));
 		sg[i].length = 0;
@@ -594,9 +596,9 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
 	return free;
 }
 
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
 {
-	int free = free_sg(sk, md->sg_start, md);
+	int free = free_sg(sk, md->sg_start, md, charge);
 
 	md->sg_start = md->sg_end;
 	return free;
@@ -604,7 +606,7 @@ static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
 
 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
 {
-	return free_sg(sk, md->sg_curr, md);
+	return free_sg(sk, md->sg_curr, md, true);
 }
 
 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
@@ -718,7 +720,7 @@ static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
 		list_add_tail(&r->list, &psock->ingress);
 		sk->sk_data_ready(sk);
 	} else {
-		free_start_sg(sk, r);
+		free_start_sg(sk, r, true);
 		kfree(r);
 	}
 
@@ -752,14 +754,10 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
 		release_sock(sk);
 	}
 	smap_release_sock(psock, sk);
-	if (unlikely(err))
-		goto out;
-	return 0;
+	return err;
 out_rcu:
 	rcu_read_unlock();
-out:
-	free_bytes_sg(NULL, send, md, false);
-	return err;
+	return 0;
 }
 
 static inline void bpf_md_init(struct smap_psock *psock)
@@ -822,7 +820,7 @@ more_data:
 	case __SK_PASS:
 		err = bpf_tcp_push(sk, send, m, flags, true);
 		if (unlikely(err)) {
-			*copied -= free_start_sg(sk, m);
+			*copied -= free_start_sg(sk, m, true);
 			break;
 		}
 
@@ -845,16 +843,17 @@ more_data:
 		lock_sock(sk);
 
 		if (unlikely(err < 0)) {
-			free_start_sg(sk, m);
+			int free = free_start_sg(sk, m, false);
+
 			psock->sg_size = 0;
 			if (!cork)
-				*copied -= send;
+				*copied -= free;
 		} else {
 			psock->sg_size -= send;
 		}
 
 		if (cork) {
-			free_start_sg(sk, m);
+			free_start_sg(sk, m, true);
 			psock->sg_size = 0;
 			kfree(m);
 			m = NULL;
@@ -1121,7 +1120,7 @@ wait_for_memory:
 		err = sk_stream_wait_memory(sk, &timeo);
 		if (err) {
 			if (m && m != psock->cork)
-				free_start_sg(sk, m);
+				free_start_sg(sk, m, true);
 			goto out_err;
 		}
 	}
@@ -1580,13 +1579,13 @@ static void smap_gc_work(struct work_struct *w)
 		bpf_prog_put(psock->bpf_tx_msg);
 
 	if (psock->cork) {
-		free_start_sg(psock->sock, psock->cork);
+		free_start_sg(psock->sock, psock->cork, true);
 		kfree(psock->cork);
 	}
 
 	list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
 		list_del(&md->list);
-		free_start_sg(psock->sock, md);
+		free_start_sg(psock->sock, md, true);
 		kfree(md);
 	}