summary refs log tree commit diff
path: root/net/l2tp
diff options
context:
space:
mode:
Diffstat (limited to 'net/l2tp')
-rw-r--r--net/l2tp/l2tp_core.c83
-rw-r--r--net/l2tp/l2tp_core.h37
-rw-r--r--net/l2tp/l2tp_debugfs.c4
-rw-r--r--net/l2tp/l2tp_eth.c106
-rw-r--r--net/l2tp/l2tp_ip.c4
-rw-r--r--net/l2tp/l2tp_ip6.c4
-rw-r--r--net/l2tp/l2tp_netlink.c24
-rw-r--r--net/l2tp/l2tp_ppp.c320
8 files changed, 278 insertions, 304 deletions
diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index 02d61101b108..115918ad8eca 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -100,8 +100,6 @@ struct l2tp_skb_cb {
 
 #define L2TP_SKB_CB(skb)	((struct l2tp_skb_cb *) &skb->cb[sizeof(struct inet_skb_parm)])
 
-static atomic_t l2tp_tunnel_count;
-static atomic_t l2tp_session_count;
 static struct workqueue_struct *l2tp_wq;
 
 /* per-net private data for this module */
@@ -216,12 +214,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
 }
 EXPORT_SYMBOL_GPL(l2tp_tunnel_get);
 
-/* Lookup a session. A new reference is held on the returned session.
- * Optionally calls session->ref() too if do_ref is true.
- */
+/* Lookup a session. A new reference is held on the returned session. */
 struct l2tp_session *l2tp_session_get(const struct net *net,
 				      struct l2tp_tunnel *tunnel,
-				      u32 session_id, bool do_ref)
+				      u32 session_id)
 {
 	struct hlist_head *session_list;
 	struct l2tp_session *session;
@@ -235,8 +231,6 @@ struct l2tp_session *l2tp_session_get(const struct net *net,
 		hlist_for_each_entry_rcu(session, session_list, global_hlist) {
 			if (session->session_id == session_id) {
 				l2tp_session_inc_refcount(session);
-				if (do_ref && session->ref)
-					session->ref(session);
 				rcu_read_unlock_bh();
 
 				return session;
@@ -252,8 +246,6 @@ struct l2tp_session *l2tp_session_get(const struct net *net,
 	hlist_for_each_entry(session, session_list, hlist) {
 		if (session->session_id == session_id) {
 			l2tp_session_inc_refcount(session);
-			if (do_ref && session->ref)
-				session->ref(session);
 			read_unlock_bh(&tunnel->hlist_lock);
 
 			return session;
@@ -265,8 +257,7 @@ struct l2tp_session *l2tp_session_get(const struct net *net,
 }
 EXPORT_SYMBOL_GPL(l2tp_session_get);
 
-struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth,
-					  bool do_ref)
+struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth)
 {
 	int hash;
 	struct l2tp_session *session;
@@ -277,8 +268,6 @@ struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth,
 		hlist_for_each_entry(session, &tunnel->session_hlist[hash], hlist) {
 			if (++count > nth) {
 				l2tp_session_inc_refcount(session);
-				if (do_ref && session->ref)
-					session->ref(session);
 				read_unlock_bh(&tunnel->hlist_lock);
 				return session;
 			}
@@ -295,8 +284,7 @@ EXPORT_SYMBOL_GPL(l2tp_session_get_nth);
  * This is very inefficient but is only used by management interfaces.
  */
 struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
-						const char *ifname,
-						bool do_ref)
+						const char *ifname)
 {
 	struct l2tp_net *pn = l2tp_pernet(net);
 	int hash;
@@ -307,8 +295,6 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
 		hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
 			if (!strcmp(session->ifname, ifname)) {
 				l2tp_session_inc_refcount(session);
-				if (do_ref && session->ref)
-					session->ref(session);
 				rcu_read_unlock_bh();
 
 				return session;
@@ -322,8 +308,8 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
 }
 EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);
 
-static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
-				      struct l2tp_session *session)
+int l2tp_session_register(struct l2tp_session *session,
+			  struct l2tp_tunnel *tunnel)
 {
 	struct l2tp_session *session_walk;
 	struct hlist_head *g_head;
@@ -380,6 +366,7 @@ err_tlock:
 
 	return err;
 }
+EXPORT_SYMBOL_GPL(l2tp_session_register);
 
 /* Lookup a tunnel by id
  */
@@ -484,9 +471,6 @@ static void l2tp_recv_dequeue_skb(struct l2tp_session *session, struct sk_buff *
 		(*session->recv_skb)(session, skb, L2TP_SKB_CB(skb)->length);
 	else
 		kfree_skb(skb);
-
-	if (session->deref)
-		(*session->deref)(session);
 }
 
 /* Dequeue skbs from the session's reorder_q, subject to packet order.
@@ -515,8 +499,6 @@ start:
 			session->reorder_skip = 1;
 			__skb_unlink(skb, &session->reorder_q);
 			kfree_skb(skb);
-			if (session->deref)
-				(*session->deref)(session);
 			continue;
 		}
 
@@ -689,9 +671,6 @@ discard:
  * a data (not control) frame before coming here. Fields up to the
  * session-id have already been parsed and ptr points to the data
  * after the session-id.
- *
- * session->ref() must have been called prior to l2tp_recv_common().
- * session->deref() will be called automatically after skb is processed.
  */
 void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
 		      unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -858,9 +837,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
 discard:
 	atomic_long_inc(&session->stats.rx_errors);
 	kfree_skb(skb);
-
-	if (session->deref)
-		(*session->deref)(session);
 }
 EXPORT_SYMBOL(l2tp_recv_common);
 
@@ -874,8 +850,6 @@ int l2tp_session_queue_purge(struct l2tp_session *session)
 	while ((skb = skb_dequeue(&session->reorder_q))) {
 		atomic_long_inc(&session->stats.rx_errors);
 		kfree_skb(skb);
-		if (session->deref)
-			(*session->deref)(session);
 	}
 	return 0;
 }
@@ -967,13 +941,10 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
 	}
 
 	/* Find the session context */
-	session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
+	session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id);
 	if (!session || !session->recv_skb) {
-		if (session) {
-			if (session->deref)
-				session->deref(session);
+		if (session)
 			l2tp_session_dec_refcount(session);
-		}
 
 		/* Not found? Pass to userspace to deal with */
 		l2tp_info(tunnel, L2TP_MSG_DATA,
@@ -1274,9 +1245,6 @@ static void l2tp_tunnel_destruct(struct sock *sk)
 	spin_lock_bh(&pn->l2tp_tunnel_list_lock);
 	list_del_rcu(&tunnel->list);
 	spin_unlock_bh(&pn->l2tp_tunnel_list_lock);
-	atomic_dec(&l2tp_tunnel_count);
-
-	l2tp_tunnel_closeall(tunnel);
 
 	tunnel->sock = NULL;
 	l2tp_tunnel_dec_refcount(tunnel);
@@ -1317,9 +1285,6 @@ again:
 			if (test_and_set_bit(0, &session->dead))
 				goto again;
 
-			if (session->ref != NULL)
-				(*session->ref)(session);
-
 			write_unlock_bh(&tunnel->hlist_lock);
 
 			__l2tp_session_unhash(session);
@@ -1328,9 +1293,6 @@ again:
 			if (session->session_close != NULL)
 				(*session->session_close)(session);
 
-			if (session->deref != NULL)
-				(*session->deref)(session);
-
 			l2tp_session_dec_refcount(session);
 
 			write_lock_bh(&tunnel->hlist_lock);
@@ -1661,7 +1623,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32
 
 	/* Add tunnel to our list */
 	INIT_LIST_HEAD(&tunnel->list);
-	atomic_inc(&l2tp_tunnel_count);
 
 	/* Bump the reference count. The tunnel context is deleted
 	 * only when this drops to zero. Must be done before list insertion
@@ -1707,8 +1668,6 @@ void l2tp_session_free(struct l2tp_session *session)
 
 	if (tunnel) {
 		BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC);
-		if (session->session_id != 0)
-			atomic_dec(&l2tp_session_count);
 		sock_put(tunnel->sock);
 		session->tunnel = NULL;
 		l2tp_tunnel_dec_refcount(tunnel);
@@ -1754,15 +1713,13 @@ int l2tp_session_delete(struct l2tp_session *session)
 	if (test_and_set_bit(0, &session->dead))
 		return 0;
 
-	if (session->ref)
-		(*session->ref)(session);
 	__l2tp_session_unhash(session);
 	l2tp_session_queue_purge(session);
 	if (session->session_close != NULL)
 		(*session->session_close)(session);
-	if (session->deref)
-		(*session->deref)(session);
+
 	l2tp_session_dec_refcount(session);
+
 	return 0;
 }
 EXPORT_SYMBOL_GPL(l2tp_session_delete);
@@ -1788,7 +1745,6 @@ EXPORT_SYMBOL_GPL(l2tp_session_set_header_len);
 struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg)
 {
 	struct l2tp_session *session;
-	int err;
 
 	session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL);
 	if (session != NULL) {
@@ -1846,17 +1802,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
 
 		refcount_set(&session->ref_count, 1);
 
-		err = l2tp_session_add_to_tunnel(tunnel, session);
-		if (err) {
-			kfree(session);
-
-			return ERR_PTR(err);
-		}
-
-		/* Ignore management session in session count value */
-		if (session->session_id != 0)
-			atomic_inc(&l2tp_session_count);
-
 		return session;
 	}
 
@@ -1888,15 +1833,19 @@ static __net_exit void l2tp_exit_net(struct net *net)
 {
 	struct l2tp_net *pn = l2tp_pernet(net);
 	struct l2tp_tunnel *tunnel = NULL;
+	int hash;
 
 	rcu_read_lock_bh();
 	list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
-		(void)l2tp_tunnel_delete(tunnel);
+		l2tp_tunnel_delete(tunnel);
 	}
 	rcu_read_unlock_bh();
 
 	flush_workqueue(l2tp_wq);
 	rcu_barrier();
+
+	for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++)
+		WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash]));
 }
 
 static struct pernet_operations l2tp_net_ops = {
diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h
index 67c79d9b5c6c..9534e16965cc 100644
--- a/net/l2tp/l2tp_core.h
+++ b/net/l2tp/l2tp_core.h
@@ -129,8 +129,6 @@ struct l2tp_session {
 	int (*build_header)(struct l2tp_session *session, void *buf);
 	void (*recv_skb)(struct l2tp_session *session, struct sk_buff *skb, int data_len);
 	void (*session_close)(struct l2tp_session *session);
-	void (*ref)(struct l2tp_session *session);
-	void (*deref)(struct l2tp_session *session);
 #if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
 	void (*show)(struct seq_file *m, void *priv);
 #endif
@@ -245,12 +243,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id);
 
 struct l2tp_session *l2tp_session_get(const struct net *net,
 				      struct l2tp_tunnel *tunnel,
-				      u32 session_id, bool do_ref);
-struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth,
-					  bool do_ref);
+				      u32 session_id);
+struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth);
 struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net,
-						const char *ifname,
-						bool do_ref);
+						const char *ifname);
 struct l2tp_tunnel *l2tp_tunnel_find(const struct net *net, u32 tunnel_id);
 struct l2tp_tunnel *l2tp_tunnel_find_nth(const struct net *net, int nth);
 
@@ -263,6 +259,9 @@ struct l2tp_session *l2tp_session_create(int priv_size,
 					 struct l2tp_tunnel *tunnel,
 					 u32 session_id, u32 peer_session_id,
 					 struct l2tp_session_cfg *cfg);
+int l2tp_session_register(struct l2tp_session *session,
+			  struct l2tp_tunnel *tunnel);
+
 void __l2tp_session_unhash(struct l2tp_session *session);
 int l2tp_session_delete(struct l2tp_session *session);
 void l2tp_session_free(struct l2tp_session *session);
@@ -295,37 +294,17 @@ static inline void l2tp_tunnel_dec_refcount(struct l2tp_tunnel *tunnel)
 /* Session reference counts. Incremented when code obtains a reference
  * to a session.
  */
-static inline void l2tp_session_inc_refcount_1(struct l2tp_session *session)
+static inline void l2tp_session_inc_refcount(struct l2tp_session *session)
 {
 	refcount_inc(&session->ref_count);
 }
 
-static inline void l2tp_session_dec_refcount_1(struct l2tp_session *session)
+static inline void l2tp_session_dec_refcount(struct l2tp_session *session)
 {
 	if (refcount_dec_and_test(&session->ref_count))
 		l2tp_session_free(session);
 }
 
-#ifdef L2TP_REFCNT_DEBUG
-#define l2tp_session_inc_refcount(_s)					\
-do {									\
-	pr_debug("l2tp_session_inc_refcount: %s:%d %s: cnt=%d\n",	\
-		 __func__, __LINE__, (_s)->name,			\
-		 refcount_read(&_s->ref_count));			\
-	l2tp_session_inc_refcount_1(_s);				\
-} while (0)
-#define l2tp_session_dec_refcount(_s)					\
-do {									\
-	pr_debug("l2tp_session_dec_refcount: %s:%d %s: cnt=%d\n",	\
-		 __func__, __LINE__, (_s)->name,			\
-		 refcount_read(&_s->ref_count));			\
-	l2tp_session_dec_refcount_1(_s);				\
-} while (0)
-#else
-#define l2tp_session_inc_refcount(s) l2tp_session_inc_refcount_1(s)
-#define l2tp_session_dec_refcount(s) l2tp_session_dec_refcount_1(s)
-#endif
-
 #define l2tp_printk(ptr, type, func, fmt, ...)				\
 do {									\
 	if (((ptr)->debug) & (type))					\
diff --git a/net/l2tp/l2tp_debugfs.c b/net/l2tp/l2tp_debugfs.c
index 53bae54c4d6e..eb69411bcb47 100644
--- a/net/l2tp/l2tp_debugfs.c
+++ b/net/l2tp/l2tp_debugfs.c
@@ -53,7 +53,7 @@ static void l2tp_dfs_next_tunnel(struct l2tp_dfs_seq_data *pd)
 
 static void l2tp_dfs_next_session(struct l2tp_dfs_seq_data *pd)
 {
-	pd->session = l2tp_session_get_nth(pd->tunnel, pd->session_idx, true);
+	pd->session = l2tp_session_get_nth(pd->tunnel, pd->session_idx);
 	pd->session_idx++;
 
 	if (pd->session == NULL) {
@@ -241,8 +241,6 @@ static int l2tp_dfs_seq_show(struct seq_file *m, void *v)
 		l2tp_dfs_seq_tunnel_show(m, pd->tunnel);
 	} else {
 		l2tp_dfs_seq_session_show(m, pd->session);
-		if (pd->session->deref)
-			pd->session->deref(pd->session);
 		l2tp_session_dec_refcount(pd->session);
 	}
 
diff --git a/net/l2tp/l2tp_eth.c b/net/l2tp/l2tp_eth.c
index 014a7bc2a872..5c366ecfa1cb 100644
--- a/net/l2tp/l2tp_eth.c
+++ b/net/l2tp/l2tp_eth.c
@@ -41,8 +41,6 @@
 
 /* via netdev_priv() */
 struct l2tp_eth {
-	struct net_device	*dev;
-	struct sock		*tunnel_sock;
 	struct l2tp_session	*session;
 	atomic_long_t		tx_bytes;
 	atomic_long_t		tx_packets;
@@ -54,15 +52,12 @@ struct l2tp_eth {
 
 /* via l2tp_session_priv() */
 struct l2tp_eth_sess {
-	struct net_device	*dev;
+	struct net_device __rcu *dev;
 };
 
 
 static int l2tp_eth_dev_init(struct net_device *dev)
 {
-	struct l2tp_eth *priv = netdev_priv(dev);
-
-	priv->dev = dev;
 	eth_hw_addr_random(dev);
 	eth_broadcast_addr(dev->broadcast);
 	netdev_lockdep_set_classes(dev);
@@ -72,7 +67,14 @@ static int l2tp_eth_dev_init(struct net_device *dev)
 
 static void l2tp_eth_dev_uninit(struct net_device *dev)
 {
-	dev_put(dev);
+	struct l2tp_eth *priv = netdev_priv(dev);
+	struct l2tp_eth_sess *spriv;
+
+	spriv = l2tp_session_priv(priv->session);
+	RCU_INIT_POINTER(spriv->dev, NULL);
+	/* No need for synchronize_net() here. We're called by
+	 * unregister_netdev*(), which does the synchronisation for us.
+	 */
 }
 
 static int l2tp_eth_dev_xmit(struct sk_buff *skb, struct net_device *dev)
@@ -130,8 +132,8 @@ static void l2tp_eth_dev_setup(struct net_device *dev)
 static void l2tp_eth_dev_recv(struct l2tp_session *session, struct sk_buff *skb, int data_len)
 {
 	struct l2tp_eth_sess *spriv = l2tp_session_priv(session);
-	struct net_device *dev = spriv->dev;
-	struct l2tp_eth *priv = netdev_priv(dev);
+	struct net_device *dev;
+	struct l2tp_eth *priv;
 
 	if (session->debug & L2TP_MSG_DATA) {
 		unsigned int length;
@@ -155,16 +157,25 @@ static void l2tp_eth_dev_recv(struct l2tp_session *session, struct sk_buff *skb,
 	skb_dst_drop(skb);
 	nf_reset(skb);
 
+	rcu_read_lock();
+	dev = rcu_dereference(spriv->dev);
+	if (!dev)
+		goto error_rcu;
+
+	priv = netdev_priv(dev);
 	if (dev_forward_skb(dev, skb) == NET_RX_SUCCESS) {
 		atomic_long_inc(&priv->rx_packets);
 		atomic_long_add(data_len, &priv->rx_bytes);
 	} else {
 		atomic_long_inc(&priv->rx_errors);
 	}
+	rcu_read_unlock();
+
 	return;
 
+error_rcu:
+	rcu_read_unlock();
 error:
-	atomic_long_inc(&priv->rx_errors);
 	kfree_skb(skb);
 }
 
@@ -175,11 +186,15 @@ static void l2tp_eth_delete(struct l2tp_session *session)
 
 	if (session) {
 		spriv = l2tp_session_priv(session);
-		dev = spriv->dev;
+
+		rtnl_lock();
+		dev = rtnl_dereference(spriv->dev);
 		if (dev) {
-			unregister_netdev(dev);
-			spriv->dev = NULL;
+			unregister_netdevice(dev);
+			rtnl_unlock();
 			module_put(THIS_MODULE);
+		} else {
+			rtnl_unlock();
 		}
 	}
 }
@@ -189,9 +204,20 @@ static void l2tp_eth_show(struct seq_file *m, void *arg)
 {
 	struct l2tp_session *session = arg;
 	struct l2tp_eth_sess *spriv = l2tp_session_priv(session);
-	struct net_device *dev = spriv->dev;
+	struct net_device *dev;
+
+	rcu_read_lock();
+	dev = rcu_dereference(spriv->dev);
+	if (!dev) {
+		rcu_read_unlock();
+		return;
+	}
+	dev_hold(dev);
+	rcu_read_unlock();
 
 	seq_printf(m, "   interface %s\n", dev->name);
+
+	dev_put(dev);
 }
 #endif
 
@@ -268,14 +294,14 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel,
 				      peer_session_id, cfg);
 	if (IS_ERR(session)) {
 		rc = PTR_ERR(session);
-		goto out;
+		goto err;
 	}
 
 	dev = alloc_netdev(sizeof(*priv), name, name_assign_type,
 			   l2tp_eth_dev_setup);
 	if (!dev) {
 		rc = -ENOMEM;
-		goto out_del_session;
+		goto err_sess;
 	}
 
 	dev_net_set(dev, net);
@@ -284,10 +310,8 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel,
 	l2tp_eth_adjust_mtu(tunnel, session, dev);
 
 	priv = netdev_priv(dev);
-	priv->dev = dev;
 	priv->session = session;
 
-	priv->tunnel_sock = tunnel->sock;
 	session->recv_skb = l2tp_eth_dev_recv;
 	session->session_close = l2tp_eth_delete;
 #if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
@@ -295,26 +319,48 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel,
 #endif
 
 	spriv = l2tp_session_priv(session);
-	spriv->dev = dev;
 
-	rc = register_netdev(dev);
-	if (rc < 0)
-		goto out_del_dev;
+	l2tp_session_inc_refcount(session);
+
+	rtnl_lock();
+
+	/* Register both device and session while holding the rtnl lock. This
+	 * ensures that l2tp_eth_delete() will see that there's a device to
+	 * unregister, even if it happened to run before we assign spriv->dev.
+	 */
+	rc = l2tp_session_register(session, tunnel);
+	if (rc < 0) {
+		rtnl_unlock();
+		goto err_sess_dev;
+	}
+
+	rc = register_netdevice(dev);
+	if (rc < 0) {
+		rtnl_unlock();
+		l2tp_session_delete(session);
+		l2tp_session_dec_refcount(session);
+		free_netdev(dev);
+
+		return rc;
+	}
 
-	__module_get(THIS_MODULE);
-	/* Must be done after register_netdev() */
 	strlcpy(session->ifname, dev->name, IFNAMSIZ);
+	rcu_assign_pointer(spriv->dev, dev);
 
-	dev_hold(dev);
+	rtnl_unlock();
+
+	l2tp_session_dec_refcount(session);
+
+	__module_get(THIS_MODULE);
 
 	return 0;
 
-out_del_dev:
+err_sess_dev:
+	l2tp_session_dec_refcount(session);
 	free_netdev(dev);
-	spriv->dev = NULL;
-out_del_session:
-	l2tp_session_delete(session);
-out:
+err_sess:
+	kfree(session);
+err:
 	return rc;
 }
 
diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c
index e4280b6568b4..ff61124fdf59 100644
--- a/net/l2tp/l2tp_ip.c
+++ b/net/l2tp/l2tp_ip.c
@@ -144,7 +144,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
 	}
 
 	/* Ok, this is a data packet. Lookup the session. */
-	session = l2tp_session_get(net, NULL, session_id, true);
+	session = l2tp_session_get(net, NULL, session_id);
 	if (!session)
 		goto discard;
 
@@ -199,8 +199,6 @@ pass_up:
 	return sk_receive_skb(sk, skb, 1);
 
 discard_sess:
-	if (session->deref)
-		session->deref(session);
 	l2tp_session_dec_refcount(session);
 	goto discard;
 
diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c
index 8bcaa975b432..192344688c06 100644
--- a/net/l2tp/l2tp_ip6.c
+++ b/net/l2tp/l2tp_ip6.c
@@ -157,7 +157,7 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
 	}
 
 	/* Ok, this is a data packet. Lookup the session. */
-	session = l2tp_session_get(net, NULL, session_id, true);
+	session = l2tp_session_get(net, NULL, session_id);
 	if (!session)
 		goto discard;
 
@@ -213,8 +213,6 @@ pass_up:
 	return sk_receive_skb(sk, skb, 1);
 
 discard_sess:
-	if (session->deref)
-		session->deref(session);
 	l2tp_session_dec_refcount(session);
 	goto discard;
 
diff --git a/net/l2tp/l2tp_netlink.c b/net/l2tp/l2tp_netlink.c
index 7135f4645d3a..a1f24fb2be98 100644
--- a/net/l2tp/l2tp_netlink.c
+++ b/net/l2tp/l2tp_netlink.c
@@ -48,8 +48,7 @@ static int l2tp_nl_session_send(struct sk_buff *skb, u32 portid, u32 seq,
 /* Accessed under genl lock */
 static const struct l2tp_nl_cmd_ops *l2tp_nl_cmd_ops[__L2TP_PWTYPE_MAX];
 
-static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info,
-						bool do_ref)
+static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info)
 {
 	u32 tunnel_id;
 	u32 session_id;
@@ -60,15 +59,14 @@ static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info,
 
 	if (info->attrs[L2TP_ATTR_IFNAME]) {
 		ifname = nla_data(info->attrs[L2TP_ATTR_IFNAME]);
-		session = l2tp_session_get_by_ifname(net, ifname, do_ref);
+		session = l2tp_session_get_by_ifname(net, ifname);
 	} else if ((info->attrs[L2TP_ATTR_SESSION_ID]) &&
 		   (info->attrs[L2TP_ATTR_CONN_ID])) {
 		tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
 		session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
 		tunnel = l2tp_tunnel_get(net, tunnel_id);
 		if (tunnel) {
-			session = l2tp_session_get(net, tunnel, session_id,
-						   do_ref);
+			session = l2tp_session_get(net, tunnel, session_id);
 			l2tp_tunnel_dec_refcount(tunnel);
 		}
 	}
@@ -282,7 +280,7 @@ static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info
 	l2tp_tunnel_notify(&l2tp_nl_family, info,
 			   tunnel, L2TP_CMD_TUNNEL_DELETE);
 
-	(void) l2tp_tunnel_delete(tunnel);
+	l2tp_tunnel_delete(tunnel);
 
 	l2tp_tunnel_dec_refcount(tunnel);
 
@@ -406,7 +404,7 @@ static int l2tp_nl_tunnel_send(struct sk_buff *skb, u32 portid, u32 seq, int fla
 		if (nla_put_u16(skb, L2TP_ATTR_UDP_SPORT, ntohs(inet->inet_sport)) ||
 		    nla_put_u16(skb, L2TP_ATTR_UDP_DPORT, ntohs(inet->inet_dport)))
 			goto nla_put_failure;
-		/* NOBREAK */
+		/* fall through  */
 	case L2TP_ENCAPTYPE_IP:
 #if IS_ENABLED(CONFIG_IPV6)
 		if (np) {
@@ -649,7 +647,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
 							   &cfg);
 
 	if (ret >= 0) {
-		session = l2tp_session_get(net, tunnel, session_id, false);
+		session = l2tp_session_get(net, tunnel, session_id);
 		if (session) {
 			ret = l2tp_session_notify(&l2tp_nl_family, info, session,
 						  L2TP_CMD_SESSION_CREATE);
@@ -669,7 +667,7 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
 	struct l2tp_session *session;
 	u16 pw_type;
 
-	session = l2tp_nl_session_get(info, true);
+	session = l2tp_nl_session_get(info);
 	if (session == NULL) {
 		ret = -ENODEV;
 		goto out;
@@ -683,8 +681,6 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
 		if (l2tp_nl_cmd_ops[pw_type] && l2tp_nl_cmd_ops[pw_type]->session_delete)
 			ret = (*l2tp_nl_cmd_ops[pw_type]->session_delete)(session);
 
-	if (session->deref)
-		session->deref(session);
 	l2tp_session_dec_refcount(session);
 
 out:
@@ -696,7 +692,7 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
 	int ret = 0;
 	struct l2tp_session *session;
 
-	session = l2tp_nl_session_get(info, false);
+	session = l2tp_nl_session_get(info);
 	if (session == NULL) {
 		ret = -ENODEV;
 		goto out;
@@ -828,7 +824,7 @@ static int l2tp_nl_cmd_session_get(struct sk_buff *skb, struct genl_info *info)
 	struct sk_buff *msg;
 	int ret;
 
-	session = l2tp_nl_session_get(info, false);
+	session = l2tp_nl_session_get(info);
 	if (session == NULL) {
 		ret = -ENODEV;
 		goto err;
@@ -874,7 +870,7 @@ static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback
 				goto out;
 		}
 
-		session = l2tp_session_get_nth(tunnel, si, false);
+		session = l2tp_session_get_nth(tunnel, si);
 		if (session == NULL) {
 			ti++;
 			tunnel = NULL;
diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c
index 0c2738349442..b412fc3351dc 100644
--- a/net/l2tp/l2tp_ppp.c
+++ b/net/l2tp/l2tp_ppp.c
@@ -122,10 +122,11 @@
 struct pppol2tp_session {
 	int			owner;		/* pid that opened the socket */
 
-	struct sock		*sock;		/* Pointer to the session
+	struct mutex		sk_lock;	/* Protects .sk */
+	struct sock __rcu	*sk;		/* Pointer to the session
 						 * PPPoX socket */
-	struct sock		*tunnel_sock;	/* Pointer to the tunnel UDP
-						 * socket */
+	struct sock		*__sk;		/* Copy of .sk, for cleanup */
+	struct rcu_head		rcu;		/* For asynchronous release */
 	int			flags;		/* accessed by PPPIOCGFLAGS.
 						 * Unused. */
 };
@@ -138,6 +139,24 @@ static const struct ppp_channel_ops pppol2tp_chan_ops = {
 
 static const struct proto_ops pppol2tp_ops;
 
+/* Retrieves the pppol2tp socket associated to a session.
+ * A reference is held on the returned socket, so this function must be paired
+ * with sock_put().
+ */
+static struct sock *pppol2tp_session_get_sock(struct l2tp_session *session)
+{
+	struct pppol2tp_session *ps = l2tp_session_priv(session);
+	struct sock *sk;
+
+	rcu_read_lock();
+	sk = rcu_dereference(ps->sk);
+	if (sk)
+		sock_hold(sk);
+	rcu_read_unlock();
+
+	return sk;
+}
+
 /* Helpers to obtain tunnel/session contexts from sockets.
  */
 static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
@@ -224,7 +243,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
 	/* If the socket is bound, send it in to PPP's input queue. Otherwise
 	 * queue it on the session socket.
 	 */
-	sk = ps->sock;
+	rcu_read_lock();
+	sk = rcu_dereference(ps->sk);
 	if (sk == NULL)
 		goto no_sock;
 
@@ -247,30 +267,16 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int
 			kfree_skb(skb);
 		}
 	}
+	rcu_read_unlock();
 
 	return;
 
 no_sock:
+	rcu_read_unlock();
 	l2tp_info(session, L2TP_MSG_DATA, "%s: no socket\n", session->name);
 	kfree_skb(skb);
 }
 
-static void pppol2tp_session_sock_hold(struct l2tp_session *session)
-{
-	struct pppol2tp_session *ps = l2tp_session_priv(session);
-
-	if (ps->sock)
-		sock_hold(ps->sock);
-}
-
-static void pppol2tp_session_sock_put(struct l2tp_session *session)
-{
-	struct pppol2tp_session *ps = l2tp_session_priv(session);
-
-	if (ps->sock)
-		sock_put(ps->sock);
-}
-
 /************************************************************************
  * Transmit handling
  ***********************************************************************/
@@ -287,7 +293,6 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
 	int error;
 	struct l2tp_session *session;
 	struct l2tp_tunnel *tunnel;
-	struct pppol2tp_session *ps;
 	int uhlen;
 
 	error = -ENOTCONN;
@@ -300,10 +305,7 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
 	if (session == NULL)
 		goto error;
 
-	ps = l2tp_session_priv(session);
-	tunnel = l2tp_sock_to_tunnel(ps->tunnel_sock);
-	if (tunnel == NULL)
-		goto error_put_sess;
+	tunnel = session->tunnel;
 
 	uhlen = (tunnel->encap == L2TP_ENCAPTYPE_UDP) ? sizeof(struct udphdr) : 0;
 
@@ -314,7 +316,7 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
 			   2 + total_len, /* 2 bytes for PPP_ALLSTATIONS & PPP_UI */
 			   0, GFP_KERNEL);
 	if (!skb)
-		goto error_put_sess_tun;
+		goto error_put_sess;
 
 	/* Reserve space for headers. */
 	skb_reserve(skb, NET_SKB_PAD);
@@ -332,20 +334,17 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
 	error = memcpy_from_msg(skb_put(skb, total_len), m, total_len);
 	if (error < 0) {
 		kfree_skb(skb);
-		goto error_put_sess_tun;
+		goto error_put_sess;
 	}
 
 	local_bh_disable();
 	l2tp_xmit_skb(session, skb, session->hdr_len);
 	local_bh_enable();
 
-	sock_put(ps->tunnel_sock);
 	sock_put(sk);
 
 	return total_len;
 
-error_put_sess_tun:
-	sock_put(ps->tunnel_sock);
 error_put_sess:
 	sock_put(sk);
 error:
@@ -369,10 +368,8 @@ error:
 static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
 {
 	struct sock *sk = (struct sock *) chan->private;
-	struct sock *sk_tun;
 	struct l2tp_session *session;
 	struct l2tp_tunnel *tunnel;
-	struct pppol2tp_session *ps;
 	int uhlen, headroom;
 
 	if (sock_flag(sk, SOCK_DEAD) || !(sk->sk_state & PPPOX_CONNECTED))
@@ -383,13 +380,7 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
 	if (session == NULL)
 		goto abort;
 
-	ps = l2tp_session_priv(session);
-	sk_tun = ps->tunnel_sock;
-	if (sk_tun == NULL)
-		goto abort_put_sess;
-	tunnel = l2tp_sock_to_tunnel(sk_tun);
-	if (tunnel == NULL)
-		goto abort_put_sess;
+	tunnel = session->tunnel;
 
 	uhlen = (tunnel->encap == L2TP_ENCAPTYPE_UDP) ? sizeof(struct udphdr) : 0;
 	headroom = NET_SKB_PAD +
@@ -398,7 +389,7 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
 		   session->hdr_len +	/* L2TP header */
 		   2;			/* 2 bytes for PPP_ALLSTATIONS & PPP_UI */
 	if (skb_cow_head(skb, headroom))
-		goto abort_put_sess_tun;
+		goto abort_put_sess;
 
 	/* Setup PPP header */
 	__skb_push(skb, 2);
@@ -409,12 +400,10 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
 	l2tp_xmit_skb(session, skb, session->hdr_len);
 	local_bh_enable();
 
-	sock_put(sk_tun);
 	sock_put(sk);
+
 	return 1;
 
-abort_put_sess_tun:
-	sock_put(sk_tun);
 abort_put_sess:
 	sock_put(sk);
 abort:
@@ -431,17 +420,16 @@ abort:
  */
 static void pppol2tp_session_close(struct l2tp_session *session)
 {
-	struct pppol2tp_session *ps = l2tp_session_priv(session);
-	struct sock *sk = ps->sock;
-	struct socket *sock = sk->sk_socket;
+	struct sock *sk;
 
 	BUG_ON(session->magic != L2TP_SESSION_MAGIC);
 
-	if (sock)
-		inet_shutdown(sock, SEND_SHUTDOWN);
-
-	/* Don't let the session go away before our socket does */
-	l2tp_session_inc_refcount(session);
+	sk = pppol2tp_session_get_sock(session);
+	if (sk) {
+		if (sk->sk_socket)
+			inet_shutdown(sk->sk_socket, SEND_SHUTDOWN);
+		sock_put(sk);
+	}
 }
 
 /* Really kill the session socket. (Called from sock_put() if
@@ -461,6 +449,14 @@ static void pppol2tp_session_destruct(struct sock *sk)
 	}
 }
 
+static void pppol2tp_put_sk(struct rcu_head *head)
+{
+	struct pppol2tp_session *ps;
+
+	ps = container_of(head, typeof(*ps), rcu);
+	sock_put(ps->__sk);
+}
+
 /* Called when the PPPoX socket (session) is closed.
  */
 static int pppol2tp_release(struct socket *sock)
@@ -486,11 +482,23 @@ static int pppol2tp_release(struct socket *sock)
 
 	session = pppol2tp_sock_to_session(sk);
 
-	/* Purge any queued data */
 	if (session != NULL) {
-		__l2tp_session_unhash(session);
-		l2tp_session_queue_purge(session);
-		sock_put(sk);
+		struct pppol2tp_session *ps;
+
+		l2tp_session_delete(session);
+
+		ps = l2tp_session_priv(session);
+		mutex_lock(&ps->sk_lock);
+		ps->__sk = rcu_dereference_protected(ps->sk,
+						     lockdep_is_held(&ps->sk_lock));
+		RCU_INIT_POINTER(ps->sk, NULL);
+		mutex_unlock(&ps->sk_lock);
+		call_rcu(&ps->rcu, pppol2tp_put_sk);
+
+		/* Rely on the sock_put() call at the end of the function for
+		 * dropping the reference held by pppol2tp_sock_to_session().
+		 * The last reference will be dropped by pppol2tp_put_sk().
+		 */
 	}
 	release_sock(sk);
 
@@ -557,16 +565,46 @@ out:
 static void pppol2tp_show(struct seq_file *m, void *arg)
 {
 	struct l2tp_session *session = arg;
-	struct pppol2tp_session *ps = l2tp_session_priv(session);
+	struct sock *sk;
 
-	if (ps) {
-		struct pppox_sock *po = pppox_sk(ps->sock);
-		if (po)
-			seq_printf(m, "   interface %s\n", ppp_dev_name(&po->chan));
+	sk = pppol2tp_session_get_sock(session);
+	if (sk) {
+		struct pppox_sock *po = pppox_sk(sk);
+
+		seq_printf(m, "   interface %s\n", ppp_dev_name(&po->chan));
+		sock_put(sk);
 	}
 }
 #endif
 
+static void pppol2tp_session_init(struct l2tp_session *session)
+{
+	struct pppol2tp_session *ps;
+	struct dst_entry *dst;
+
+	session->recv_skb = pppol2tp_recv;
+	session->session_close = pppol2tp_session_close;
+#if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
+	session->show = pppol2tp_show;
+#endif
+
+	ps = l2tp_session_priv(session);
+	mutex_init(&ps->sk_lock);
+	ps->owner = current->pid;
+
+	/* If PMTU discovery was enabled, use the MTU that was discovered */
+	dst = sk_dst_get(session->tunnel->sock);
+	if (dst) {
+		u32 pmtu = dst_mtu(dst);
+
+		if (pmtu) {
+			session->mtu = pmtu - PPPOL2TP_HEADER_OVERHEAD;
+			session->mru = pmtu - PPPOL2TP_HEADER_OVERHEAD;
+		}
+		dst_release(dst);
+	}
+}
+
 /* connect() handler. Attach a PPPoX socket to a tunnel UDP socket
  */
 static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
@@ -578,7 +616,6 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 	struct l2tp_session *session = NULL;
 	struct l2tp_tunnel *tunnel;
 	struct pppol2tp_session *ps;
-	struct dst_entry *dst;
 	struct l2tp_session_cfg cfg = { 0, };
 	int error = 0;
 	u32 tunnel_id, peer_tunnel_id;
@@ -688,7 +725,7 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 	if (tunnel->peer_tunnel_id == 0)
 		tunnel->peer_tunnel_id = peer_tunnel_id;
 
-	session = l2tp_session_get(sock_net(sk), tunnel, session_id, false);
+	session = l2tp_session_get(sock_net(sk), tunnel, session_id);
 	if (session) {
 		drop_refcnt = true;
 		ps = l2tp_session_priv(session);
@@ -696,13 +733,10 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 		/* Using a pre-existing session is fine as long as it hasn't
 		 * been connected yet.
 		 */
-		if (ps->sock) {
-			error = -EEXIST;
-			goto end;
-		}
-
-		/* consistency checks */
-		if (ps->tunnel_sock != tunnel->sock) {
+		mutex_lock(&ps->sk_lock);
+		if (rcu_dereference_protected(ps->sk,
+					      lockdep_is_held(&ps->sk_lock))) {
+			mutex_unlock(&ps->sk_lock);
 			error = -EEXIST;
 			goto end;
 		}
@@ -718,35 +752,19 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 			error = PTR_ERR(session);
 			goto end;
 		}
-	}
-
-	/* Associate session with its PPPoL2TP socket */
-	ps = l2tp_session_priv(session);
-	ps->owner	     = current->pid;
-	ps->sock	     = sk;
-	ps->tunnel_sock = tunnel->sock;
 
-	session->recv_skb	= pppol2tp_recv;
-	session->session_close	= pppol2tp_session_close;
-#if IS_ENABLED(CONFIG_L2TP_DEBUGFS)
-	session->show		= pppol2tp_show;
-#endif
-
-	/* We need to know each time a skb is dropped from the reorder
-	 * queue.
-	 */
-	session->ref = pppol2tp_session_sock_hold;
-	session->deref = pppol2tp_session_sock_put;
-
-	/* If PMTU discovery was enabled, use the MTU that was discovered */
-	dst = sk_dst_get(tunnel->sock);
-	if (dst != NULL) {
-		u32 pmtu = dst_mtu(dst);
+		pppol2tp_session_init(session);
+		ps = l2tp_session_priv(session);
+		l2tp_session_inc_refcount(session);
 
-		if (pmtu != 0)
-			session->mtu = session->mru = pmtu -
-				PPPOL2TP_HEADER_OVERHEAD;
-		dst_release(dst);
+		mutex_lock(&ps->sk_lock);
+		error = l2tp_session_register(session, tunnel);
+		if (error < 0) {
+			mutex_unlock(&ps->sk_lock);
+			kfree(session);
+			goto end;
+		}
+		drop_refcnt = true;
 	}
 
 	/* Special case: if source & dest session_id == 0x0000, this
@@ -771,12 +789,23 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 	po->chan.mtu	 = session->mtu;
 
 	error = ppp_register_net_channel(sock_net(sk), &po->chan);
-	if (error)
+	if (error) {
+		mutex_unlock(&ps->sk_lock);
 		goto end;
+	}
 
 out_no_ppp:
 	/* This is how we get the session context from the socket. */
 	sk->sk_user_data = session;
+	rcu_assign_pointer(ps->sk, sk);
+	mutex_unlock(&ps->sk_lock);
+
+	/* Keep the reference we've grabbed on the session: sk doesn't expect
+	 * the session to disappear. pppol2tp_session_destruct() is responsible
+	 * for dropping it.
+	 */
+	drop_refcnt = false;
+
 	sk->sk_state = PPPOX_CONNECTED;
 	l2tp_info(session, L2TP_MSG_CONTROL, "%s: created\n",
 		  session->name);
@@ -800,12 +829,11 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
 {
 	int error;
 	struct l2tp_session *session;
-	struct pppol2tp_session *ps;
 
 	/* Error if tunnel socket is not prepped */
 	if (!tunnel->sock) {
 		error = -ENOENT;
-		goto out;
+		goto err;
 	}
 
 	/* Default MTU values. */
@@ -820,18 +848,20 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
 				      peer_session_id, cfg);
 	if (IS_ERR(session)) {
 		error = PTR_ERR(session);
-		goto out;
+		goto err;
 	}
 
-	ps = l2tp_session_priv(session);
-	ps->tunnel_sock = tunnel->sock;
+	pppol2tp_session_init(session);
 
-	l2tp_info(session, L2TP_MSG_CONTROL, "%s: created\n",
-		  session->name);
+	error = l2tp_session_register(session, tunnel);
+	if (error < 0)
+		goto err_sess;
 
-	error = 0;
+	return 0;
 
-out:
+err_sess:
+	kfree(session);
+err:
 	return error;
 }
 
@@ -862,9 +892,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
 		goto end;
 
 	pls = l2tp_session_priv(session);
-	tunnel = l2tp_sock_to_tunnel(pls->tunnel_sock);
-	if (tunnel == NULL)
-		goto end_put_sess;
+	tunnel = session->tunnel;
 
 	inet = inet_sk(tunnel->sock);
 	if ((tunnel->version == 2) && (tunnel->sock->sk_family == AF_INET)) {
@@ -944,8 +972,6 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
 	*usockaddr_len = len;
 	error = 0;
 
-	sock_put(pls->tunnel_sock);
-end_put_sess:
 	sock_put(sk);
 end:
 	return error;
@@ -992,12 +1018,10 @@ static int pppol2tp_session_ioctl(struct l2tp_session *session,
 		 "%s: pppol2tp_session_ioctl(cmd=%#x, arg=%#lx)\n",
 		 session->name, cmd, arg);
 
-	sk = ps->sock;
+	sk = pppol2tp_session_get_sock(session);
 	if (!sk)
 		return -EBADR;
 
-	sock_hold(sk);
-
 	switch (cmd) {
 	case SIOCGIFMTU:
 		err = -ENXIO;
@@ -1143,13 +1167,11 @@ static int pppol2tp_tunnel_ioctl(struct l2tp_tunnel *tunnel,
 			/* resend to session ioctl handler */
 			struct l2tp_session *session =
 				l2tp_session_get(sock_net(sk), tunnel,
-						 stats.session_id, true);
+						 stats.session_id);
 
 			if (session) {
 				err = pppol2tp_session_ioctl(session, cmd,
 							     arg);
-				if (session->deref)
-					session->deref(session);
 				l2tp_session_dec_refcount(session);
 			} else {
 				err = -EBADR;
@@ -1188,7 +1210,6 @@ static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd,
 	struct sock *sk = sock->sk;
 	struct l2tp_session *session;
 	struct l2tp_tunnel *tunnel;
-	struct pppol2tp_session *ps;
 	int err;
 
 	if (!sk)
@@ -1212,16 +1233,10 @@ static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd,
 	/* Special case: if session's session_id is zero, treat ioctl as a
 	 * tunnel ioctl
 	 */
-	ps = l2tp_session_priv(session);
 	if ((session->session_id == 0) &&
 	    (session->peer_session_id == 0)) {
-		err = -EBADF;
-		tunnel = l2tp_sock_to_tunnel(ps->tunnel_sock);
-		if (tunnel == NULL)
-			goto end_put_sess;
-
+		tunnel = session->tunnel;
 		err = pppol2tp_tunnel_ioctl(tunnel, cmd, arg);
-		sock_put(ps->tunnel_sock);
 		goto end_put_sess;
 	}
 
@@ -1273,7 +1288,6 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
 				       int optname, int val)
 {
 	int err = 0;
-	struct pppol2tp_session *ps = l2tp_session_priv(session);
 
 	switch (optname) {
 	case PPPOL2TP_SO_RECVSEQ:
@@ -1294,8 +1308,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
 		}
 		session->send_seq = !!val;
 		{
-			struct sock *ssk      = ps->sock;
-			struct pppox_sock *po = pppox_sk(ssk);
+			struct pppox_sock *po = pppox_sk(sk);
+
 			po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ :
 				PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
 		}
@@ -1348,7 +1362,6 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname,
 	struct sock *sk = sock->sk;
 	struct l2tp_session *session;
 	struct l2tp_tunnel *tunnel;
-	struct pppol2tp_session *ps;
 	int val;
 	int err;
 
@@ -1373,20 +1386,14 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname,
 
 	/* Special case: if session_id == 0x0000, treat as operation on tunnel
 	 */
-	ps = l2tp_session_priv(session);
 	if ((session->session_id == 0) &&
 	    (session->peer_session_id == 0)) {
-		err = -EBADF;
-		tunnel = l2tp_sock_to_tunnel(ps->tunnel_sock);
-		if (tunnel == NULL)
-			goto end_put_sess;
-
+		tunnel = session->tunnel;
 		err = pppol2tp_tunnel_setsockopt(sk, tunnel, optname, val);
-		sock_put(ps->tunnel_sock);
-	} else
+	} else {
 		err = pppol2tp_session_setsockopt(sk, session, optname, val);
+	}
 
-end_put_sess:
 	sock_put(sk);
 end:
 	return err;
@@ -1474,7 +1481,6 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname,
 	struct l2tp_tunnel *tunnel;
 	int val, len;
 	int err;
-	struct pppol2tp_session *ps;
 
 	if (level != SOL_PPPOL2TP)
 		return -EINVAL;
@@ -1498,16 +1504,10 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname,
 		goto end;
 
 	/* Special case: if session_id == 0x0000, treat as operation on tunnel */
-	ps = l2tp_session_priv(session);
 	if ((session->session_id == 0) &&
 	    (session->peer_session_id == 0)) {
-		err = -EBADF;
-		tunnel = l2tp_sock_to_tunnel(ps->tunnel_sock);
-		if (tunnel == NULL)
-			goto end_put_sess;
-
+		tunnel = session->tunnel;
 		err = pppol2tp_tunnel_getsockopt(sk, tunnel, optname, &val);
-		sock_put(ps->tunnel_sock);
 		if (err)
 			goto end_put_sess;
 	} else {
@@ -1566,7 +1566,7 @@ static void pppol2tp_next_tunnel(struct net *net, struct pppol2tp_seq_data *pd)
 
 static void pppol2tp_next_session(struct net *net, struct pppol2tp_seq_data *pd)
 {
-	pd->session = l2tp_session_get_nth(pd->tunnel, pd->session_idx, true);
+	pd->session = l2tp_session_get_nth(pd->tunnel, pd->session_idx);
 	pd->session_idx++;
 
 	if (pd->session == NULL) {
@@ -1634,8 +1634,9 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
 {
 	struct l2tp_session *session = v;
 	struct l2tp_tunnel *tunnel = session->tunnel;
-	struct pppol2tp_session *ps = l2tp_session_priv(session);
-	struct pppox_sock *po = pppox_sk(ps->sock);
+	unsigned char state;
+	char user_data_ok;
+	struct sock *sk;
 	u32 ip = 0;
 	u16 port = 0;
 
@@ -1645,6 +1646,15 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
 		port = ntohs(inet->inet_sport);
 	}
 
+	sk = pppol2tp_session_get_sock(session);
+	if (sk) {
+		state = sk->sk_state;
+		user_data_ok = (session == sk->sk_user_data) ? 'Y' : 'N';
+	} else {
+		state = 0;
+		user_data_ok = 'N';
+	}
+
 	seq_printf(m, "  SESSION '%s' %08X/%d %04X/%04X -> "
 		   "%04X/%04X %d %c\n",
 		   session->name, ip, port,
@@ -1652,9 +1662,7 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
 		   session->session_id,
 		   tunnel->peer_tunnel_id,
 		   session->peer_session_id,
-		   ps->sock->sk_state,
-		   (session == ps->sock->sk_user_data) ?
-		   'Y' : 'N');
+		   state, user_data_ok);
 	seq_printf(m, "   %d/%d/%c/%c/%s %08x %u\n",
 		   session->mtu, session->mru,
 		   session->recv_seq ? 'R' : '-',
@@ -1671,8 +1679,12 @@ static void pppol2tp_seq_session_show(struct seq_file *m, void *v)
 		   atomic_long_read(&session->stats.rx_bytes),
 		   atomic_long_read(&session->stats.rx_errors));
 
-	if (po)
+	if (sk) {
+		struct pppox_sock *po = pppox_sk(sk);
+
 		seq_printf(m, "   interface %s\n", ppp_dev_name(&po->chan));
+		sock_put(sk);
+	}
 }
 
 static int pppol2tp_seq_show(struct seq_file *m, void *v)
@@ -1697,8 +1709,6 @@ static int pppol2tp_seq_show(struct seq_file *m, void *v)
 		pppol2tp_seq_tunnel_show(m, pd->tunnel);
 	} else {
 		pppol2tp_seq_session_show(m, pd->session);
-		if (pd->session->deref)
-			pd->session->deref(pd->session);
 		l2tp_session_dec_refcount(pd->session);
 	}