summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--include/net/mptcp.h9
-rw-r--r--net/mptcp/options.c35
-rw-r--r--net/mptcp/pm.c5
-rw-r--r--net/mptcp/protocol.h12
4 files changed, 48 insertions, 13 deletions
diff --git a/include/net/mptcp.h b/include/net/mptcp.h
index 5694370be3d4..cea69c801595 100644
--- a/include/net/mptcp.h
+++ b/include/net/mptcp.h
@@ -34,6 +34,13 @@ struct mptcp_ext {
 	/* one byte hole */
 };
 
+#define MPTCP_RM_IDS_MAX	8
+
+struct mptcp_rm_list {
+	u8 ids[MPTCP_RM_IDS_MAX];
+	u8 nr;
+};
+
 struct mptcp_out_options {
 #if IS_ENABLED(CONFIG_MPTCP)
 	u16 suboptions;
@@ -48,7 +55,7 @@ struct mptcp_out_options {
 	u8 addr_id;
 	u16 port;
 	u64 ahmac;
-	u8 rm_id;
+	struct mptcp_rm_list rm_list;
 	u8 join_id;
 	u8 backup;
 	u32 nonce;
diff --git a/net/mptcp/options.c b/net/mptcp/options.c
index 444a38681e93..e74d0513187f 100644
--- a/net/mptcp/options.c
+++ b/net/mptcp/options.c
@@ -674,20 +674,25 @@ static bool mptcp_established_options_rm_addr(struct sock *sk,
 {
 	struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
 	struct mptcp_sock *msk = mptcp_sk(subflow->conn);
-	u8 rm_id;
+	struct mptcp_rm_list rm_list;
+	int i, len;
 
 	if (!mptcp_pm_should_rm_signal(msk) ||
-	    !(mptcp_pm_rm_addr_signal(msk, remaining, &rm_id)))
+	    !(mptcp_pm_rm_addr_signal(msk, remaining, &rm_list)))
 		return false;
 
-	if (remaining < TCPOLEN_MPTCP_RM_ADDR_BASE)
+	len = mptcp_rm_addr_len(&rm_list);
+	if (len < 0)
+		return false;
+	if (remaining < len)
 		return false;
 
-	*size = TCPOLEN_MPTCP_RM_ADDR_BASE;
+	*size = len;
 	opts->suboptions |= OPTION_MPTCP_RM_ADDR;
-	opts->rm_id = rm_id;
+	opts->rm_list = rm_list;
 
-	pr_debug("rm_id=%d", opts->rm_id);
+	for (i = 0; i < opts->rm_list.nr; i++)
+		pr_debug("rm_list_ids[%d]=%d", i, opts->rm_list.ids[i]);
 
 	return true;
 }
@@ -1217,9 +1222,23 @@ mp_capable_done:
 	}
 
 	if (OPTION_MPTCP_RM_ADDR & opts->suboptions) {
+		u8 i = 1;
+
 		*ptr++ = mptcp_option(MPTCPOPT_RM_ADDR,
-				      TCPOLEN_MPTCP_RM_ADDR_BASE,
-				      0, opts->rm_id);
+				      TCPOLEN_MPTCP_RM_ADDR_BASE + opts->rm_list.nr,
+				      0, opts->rm_list.ids[0]);
+
+		while (i < opts->rm_list.nr) {
+			u8 id1, id2, id3, id4;
+
+			id1 = opts->rm_list.ids[i];
+			id2 = i + 1 < opts->rm_list.nr ? opts->rm_list.ids[i + 1] : TCPOPT_NOP;
+			id3 = i + 2 < opts->rm_list.nr ? opts->rm_list.ids[i + 2] : TCPOPT_NOP;
+			id4 = i + 3 < opts->rm_list.nr ? opts->rm_list.ids[i + 3] : TCPOPT_NOP;
+			put_unaligned_be32(id1 << 24 | id2 << 16 | id3 << 8 | id4, ptr);
+			ptr += 1;
+			i += 4;
+		}
 	}
 
 	if (OPTION_MPTCP_PRIO & opts->suboptions) {
diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c
index 6fd4b2c1b076..0654c86cd5ff 100644
--- a/net/mptcp/pm.c
+++ b/net/mptcp/pm.c
@@ -258,7 +258,7 @@ out_unlock:
 }
 
 bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
-			     u8 *rm_id)
+			     struct mptcp_rm_list *rm_list)
 {
 	int ret = false;
 
@@ -271,7 +271,8 @@ bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
 	if (remaining < TCPOLEN_MPTCP_RM_ADDR_BASE)
 		goto out_unlock;
 
-	*rm_id = msk->pm.rm_id;
+	rm_list->ids[0] = msk->pm.rm_id;
+	rm_list->nr = 1;
 	WRITE_ONCE(msk->pm.addr_signal, 0);
 	ret = true;
 
diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h
index e21a5bc36cf0..c896bcf3e70f 100644
--- a/net/mptcp/protocol.h
+++ b/net/mptcp/protocol.h
@@ -61,7 +61,7 @@
 #define TCPOLEN_MPTCP_ADD_ADDR6_BASE_PORT	22
 #define TCPOLEN_MPTCP_PORT_LEN		2
 #define TCPOLEN_MPTCP_PORT_ALIGN	2
-#define TCPOLEN_MPTCP_RM_ADDR_BASE	4
+#define TCPOLEN_MPTCP_RM_ADDR_BASE	3
 #define TCPOLEN_MPTCP_PRIO		3
 #define TCPOLEN_MPTCP_PRIO_ALIGN	4
 #define TCPOLEN_MPTCP_FASTCLOSE		12
@@ -709,10 +709,18 @@ static inline unsigned int mptcp_add_addr_len(int family, bool echo, bool port)
 	return len;
 }
 
+static inline int mptcp_rm_addr_len(const struct mptcp_rm_list *rm_list)
+{
+	if (rm_list->nr == 0 || rm_list->nr > MPTCP_RM_IDS_MAX)
+		return -EINVAL;
+
+	return TCPOLEN_MPTCP_RM_ADDR_BASE + roundup(rm_list->nr - 1, 4) + 1;
+}
+
 bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
 			      struct mptcp_addr_info *saddr, bool *echo, bool *port);
 bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
-			     u8 *rm_id);
+			     struct mptcp_rm_list *rm_list);
 int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc);
 
 void __init mptcp_pm_nl_init(void);