// SPDX-License-Identifier: GPL-2.0 /* Multipath TCP * * Copyright (c) 2017 - 2019, Intel Corporation. */ #define pr_fmt(fmt) "MPTCP: " fmt #include #include #include #include #include #include #include #include #include #include "protocol.h" #define MPTCP_SAME_STATE TCP_MAX_STATES /* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not * completed yet or has failed, return the subflow socket. * Otherwise return NULL. */ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk) { if (!msk->subflow || mptcp_subflow_ctx(msk->subflow->sk)->fourth_ack) return NULL; return msk->subflow; } /* if msk has a single subflow, and the mp_capable handshake is failed, * return it. * Otherwise returns NULL */ static struct socket *__mptcp_tcp_fallback(const struct mptcp_sock *msk) { struct socket *ssock = __mptcp_nmpc_socket(msk); sock_owned_by_me((const struct sock *)msk); if (!ssock || sk_is_mptcp(ssock->sk)) return NULL; return ssock; } static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk) { return ((struct sock *)msk)->sk_state == TCP_CLOSE; } static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state) { struct mptcp_subflow_context *subflow; struct sock *sk = (struct sock *)msk; struct socket *ssock; int err; ssock = __mptcp_nmpc_socket(msk); if (ssock) goto set_state; if (!__mptcp_can_create_subflow(msk)) return ERR_PTR(-EINVAL); err = mptcp_subflow_create_socket(sk, &ssock); if (err) return ERR_PTR(err); msk->subflow = ssock; subflow = mptcp_subflow_ctx(ssock->sk); list_add(&subflow->node, &msk->conn_list); subflow->request_mptcp = 1; set_state: if (state != MPTCP_SAME_STATE) inet_sk_state_store(sk, state); return ssock; } static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk) { struct mptcp_subflow_context *subflow; sock_owned_by_me((const struct sock *)msk); mptcp_for_each_subflow(msk, subflow) { return mptcp_subflow_tcp_sock(subflow); } return NULL; } static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) { struct mptcp_sock *msk = mptcp_sk(sk); struct socket *ssock; struct sock *ssk; int ret; if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) return -EOPNOTSUPP; lock_sock(sk); ssock = __mptcp_tcp_fallback(msk); if (ssock) { pr_debug("fallback passthrough"); ret = sock_sendmsg(ssock, msg); release_sock(sk); return ret; } ssk = mptcp_subflow_get(msk); if (!ssk) { release_sock(sk); return -ENOTCONN; } ret = sock_sendmsg(ssk->sk_socket, msg); release_sock(sk); return ret; } static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock, int flags, int *addr_len) { struct mptcp_sock *msk = mptcp_sk(sk); struct socket *ssock; struct sock *ssk; int copied = 0; if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT)) return -EOPNOTSUPP; lock_sock(sk); ssock = __mptcp_tcp_fallback(msk); if (ssock) { pr_debug("fallback-read subflow=%p", mptcp_subflow_ctx(ssock->sk)); copied = sock_recvmsg(ssock, msg, flags); release_sock(sk); return copied; } ssk = mptcp_subflow_get(msk); if (!ssk) { release_sock(sk); return -ENOTCONN; } copied = sock_recvmsg(ssk->sk_socket, msg, flags); release_sock(sk); return copied; } /* subflow sockets can be either outgoing (connect) or incoming * (accept). * * Outgoing subflows use in-kernel sockets. * Incoming subflows do not have their own 'struct socket' allocated, * so we need to use tcp_close() after detaching them from the mptcp * parent socket. */ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, struct mptcp_subflow_context *subflow, long timeout) { struct socket *sock = READ_ONCE(ssk->sk_socket); list_del(&subflow->node); if (sock && sock != sk->sk_socket) { /* outgoing subflow */ sock_release(sock); } else { /* incoming subflow */ tcp_close(ssk, timeout); } } static int mptcp_init_sock(struct sock *sk) { struct mptcp_sock *msk = mptcp_sk(sk); INIT_LIST_HEAD(&msk->conn_list); return 0; } static void mptcp_close(struct sock *sk, long timeout) { struct mptcp_subflow_context *subflow, *tmp; struct mptcp_sock *msk = mptcp_sk(sk); inet_sk_state_store(sk, TCP_CLOSE); lock_sock(sk); list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) { struct sock *ssk = mptcp_subflow_tcp_sock(subflow); __mptcp_close_ssk(sk, ssk, subflow, timeout); } release_sock(sk); sk_common_release(sk); } static int mptcp_get_port(struct sock *sk, unsigned short snum) { struct mptcp_sock *msk = mptcp_sk(sk); struct socket *ssock; ssock = __mptcp_nmpc_socket(msk); pr_debug("msk=%p, subflow=%p", msk, ssock); if (WARN_ON_ONCE(!ssock)) return -EINVAL; return inet_csk_get_port(ssock->sk, snum); } void mptcp_finish_connect(struct sock *ssk) { struct mptcp_subflow_context *subflow; struct mptcp_sock *msk; struct sock *sk; subflow = mptcp_subflow_ctx(ssk); if (!subflow->mp_capable) return; sk = subflow->conn; msk = mptcp_sk(sk); /* the socket is not connected yet, no msk/subflow ops can access/race * accessing the field below */ WRITE_ONCE(msk->remote_key, subflow->remote_key); WRITE_ONCE(msk->local_key, subflow->local_key); } static struct proto mptcp_prot = { .name = "MPTCP", .owner = THIS_MODULE, .init = mptcp_init_sock, .close = mptcp_close, .accept = inet_csk_accept, .shutdown = tcp_shutdown, .sendmsg = mptcp_sendmsg, .recvmsg = mptcp_recvmsg, .hash = inet_hash, .unhash = inet_unhash, .get_port = mptcp_get_port, .obj_size = sizeof(struct mptcp_sock), .no_autobind = true, }; static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) { struct mptcp_sock *msk = mptcp_sk(sock->sk); struct socket *ssock; int err = -ENOTSUPP; if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now return err; lock_sock(sock->sk); ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE); if (IS_ERR(ssock)) { err = PTR_ERR(ssock); goto unlock; } err = ssock->ops->bind(ssock, uaddr, addr_len); unlock: release_sock(sock->sk); return err; } static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags) { struct mptcp_sock *msk = mptcp_sk(sock->sk); struct socket *ssock; int err; lock_sock(sock->sk); ssock = __mptcp_socket_create(msk, TCP_SYN_SENT); if (IS_ERR(ssock)) { err = PTR_ERR(ssock); goto unlock; } err = ssock->ops->connect(ssock, uaddr, addr_len, flags); inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk)); unlock: release_sock(sock->sk); return err; } static __poll_t mptcp_poll(struct file *file, struct socket *sock, struct poll_table_struct *wait) { __poll_t mask = 0; return mask; } static struct proto_ops mptcp_stream_ops; static struct inet_protosw mptcp_protosw = { .type = SOCK_STREAM, .protocol = IPPROTO_MPTCP, .prot = &mptcp_prot, .ops = &mptcp_stream_ops, .flags = INET_PROTOSW_ICSK, }; void __init mptcp_init(void) { mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo; mptcp_stream_ops = inet_stream_ops; mptcp_stream_ops.bind = mptcp_bind; mptcp_stream_ops.connect = mptcp_stream_connect; mptcp_stream_ops.poll = mptcp_poll; mptcp_subflow_init(); if (proto_register(&mptcp_prot, 1) != 0) panic("Failed to register MPTCP proto.\n"); inet_register_protosw(&mptcp_protosw); } #if IS_ENABLED(CONFIG_MPTCP_IPV6) static struct proto_ops mptcp_v6_stream_ops; static struct proto mptcp_v6_prot; static struct inet_protosw mptcp_v6_protosw = { .type = SOCK_STREAM, .protocol = IPPROTO_MPTCP, .prot = &mptcp_v6_prot, .ops = &mptcp_v6_stream_ops, .flags = INET_PROTOSW_ICSK, }; int mptcpv6_init(void) { int err; mptcp_v6_prot = mptcp_prot; strcpy(mptcp_v6_prot.name, "MPTCPv6"); mptcp_v6_prot.slab = NULL; mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) + sizeof(struct ipv6_pinfo); err = proto_register(&mptcp_v6_prot, 1); if (err) return err; mptcp_v6_stream_ops = inet6_stream_ops; mptcp_v6_stream_ops.bind = mptcp_bind; mptcp_v6_stream_ops.connect = mptcp_stream_connect; mptcp_v6_stream_ops.poll = mptcp_poll; err = inet6_register_protosw(&mptcp_v6_protosw); if (err) proto_unregister(&mptcp_v6_prot); return err; } #endif