summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--include/net/sch_generic.h4
-rw-r--r--net/sched/cls_api.c83
2 files changed, 83 insertions, 4 deletions
diff --git a/include/net/sch_generic.h b/include/net/sch_generic.h
index 637548d54b3e..d80acda231ae 100644
--- a/include/net/sch_generic.h
+++ b/include/net/sch_generic.h
@@ -15,6 +15,7 @@
 #include <linux/mutex.h>
 #include <linux/rwsem.h>
 #include <linux/atomic.h>
+#include <linux/hashtable.h>
 #include <net/gen_stats.h>
 #include <net/rtnetlink.h>
 #include <net/flow_offload.h>
@@ -362,6 +363,7 @@ struct tcf_proto {
 	bool			deleting;
 	refcount_t		refcnt;
 	struct rcu_head		rcu;
+	struct hlist_node	destroy_ht_node;
 };
 
 struct qdisc_skb_cb {
@@ -414,6 +416,8 @@ struct tcf_block {
 		struct list_head filter_chain_list;
 	} chain0;
 	struct rcu_head rcu;
+	DECLARE_HASHTABLE(proto_destroy_ht, 7);
+	struct mutex proto_destroy_lock; /* Lock for proto_destroy hashtable. */
 };
 
 #ifdef CONFIG_PROVE_LOCKING
diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c
index 8717c0b26c90..20d60b8fcb70 100644
--- a/net/sched/cls_api.c
+++ b/net/sched/cls_api.c
@@ -21,6 +21,7 @@
 #include <linux/slab.h>
 #include <linux/idr.h>
 #include <linux/rhashtable.h>
+#include <linux/jhash.h>
 #include <net/net_namespace.h>
 #include <net/sock.h>
 #include <net/netlink.h>
@@ -47,6 +48,62 @@ static LIST_HEAD(tcf_proto_base);
 /* Protects list of registered TC modules. It is pure SMP lock. */
 static DEFINE_RWLOCK(cls_mod_lock);
 
+static u32 destroy_obj_hashfn(const struct tcf_proto *tp)
+{
+	return jhash_3words(tp->chain->index, tp->prio,
+			    (__force __u32)tp->protocol, 0);
+}
+
+static void tcf_proto_signal_destroying(struct tcf_chain *chain,
+					struct tcf_proto *tp)
+{
+	struct tcf_block *block = chain->block;
+
+	mutex_lock(&block->proto_destroy_lock);
+	hash_add_rcu(block->proto_destroy_ht, &tp->destroy_ht_node,
+		     destroy_obj_hashfn(tp));
+	mutex_unlock(&block->proto_destroy_lock);
+}
+
+static bool tcf_proto_cmp(const struct tcf_proto *tp1,
+			  const struct tcf_proto *tp2)
+{
+	return tp1->chain->index == tp2->chain->index &&
+	       tp1->prio == tp2->prio &&
+	       tp1->protocol == tp2->protocol;
+}
+
+static bool tcf_proto_exists_destroying(struct tcf_chain *chain,
+					struct tcf_proto *tp)
+{
+	u32 hash = destroy_obj_hashfn(tp);
+	struct tcf_proto *iter;
+	bool found = false;
+
+	rcu_read_lock();
+	hash_for_each_possible_rcu(chain->block->proto_destroy_ht, iter,
+				   destroy_ht_node, hash) {
+		if (tcf_proto_cmp(tp, iter)) {
+			found = true;
+			break;
+		}
+	}
+	rcu_read_unlock();
+
+	return found;
+}
+
+static void
+tcf_proto_signal_destroyed(struct tcf_chain *chain, struct tcf_proto *tp)
+{
+	struct tcf_block *block = chain->block;
+
+	mutex_lock(&block->proto_destroy_lock);
+	if (hash_hashed(&tp->destroy_ht_node))
+		hash_del_rcu(&tp->destroy_ht_node);
+	mutex_unlock(&block->proto_destroy_lock);
+}
+
 /* Find classifier type by string name */
 
 static const struct tcf_proto_ops *__tcf_proto_lookup_ops(const char *kind)
@@ -234,9 +291,11 @@ static void tcf_proto_get(struct tcf_proto *tp)
 static void tcf_chain_put(struct tcf_chain *chain);
 
 static void tcf_proto_destroy(struct tcf_proto *tp, bool rtnl_held,
-			      struct netlink_ext_ack *extack)
+			      bool sig_destroy, struct netlink_ext_ack *extack)
 {
 	tp->ops->destroy(tp, rtnl_held, extack);
+	if (sig_destroy)
+		tcf_proto_signal_destroyed(tp->chain, tp);
 	tcf_chain_put(tp->chain);
 	module_put(tp->ops->owner);
 	kfree_rcu(tp, rcu);
@@ -246,7 +305,7 @@ static void tcf_proto_put(struct tcf_proto *tp, bool rtnl_held,
 			  struct netlink_ext_ack *extack)
 {
 	if (refcount_dec_and_test(&tp->refcnt))
-		tcf_proto_destroy(tp, rtnl_held, extack);
+		tcf_proto_destroy(tp, rtnl_held, true, extack);
 }
 
 static int walker_check_empty(struct tcf_proto *tp, void *fh,
@@ -370,6 +429,7 @@ static bool tcf_chain_detach(struct tcf_chain *chain)
 static void tcf_block_destroy(struct tcf_block *block)
 {
 	mutex_destroy(&block->lock);
+	mutex_destroy(&block->proto_destroy_lock);
 	kfree_rcu(block, rcu);
 }
 
@@ -545,6 +605,12 @@ static void tcf_chain_flush(struct tcf_chain *chain, bool rtnl_held)
 
 	mutex_lock(&chain->filter_chain_lock);
 	tp = tcf_chain_dereference(chain->filter_chain, chain);
+	while (tp) {
+		tp_next = rcu_dereference_protected(tp->next, 1);
+		tcf_proto_signal_destroying(chain, tp);
+		tp = tp_next;
+	}
+	tp = tcf_chain_dereference(chain->filter_chain, chain);
 	RCU_INIT_POINTER(chain->filter_chain, NULL);
 	tcf_chain0_head_change(chain, NULL);
 	chain->flushing = true;
@@ -844,6 +910,7 @@ static struct tcf_block *tcf_block_create(struct net *net, struct Qdisc *q,
 		return ERR_PTR(-ENOMEM);
 	}
 	mutex_init(&block->lock);
+	mutex_init(&block->proto_destroy_lock);
 	init_rwsem(&block->cb_lock);
 	flow_block_init(&block->flow_block);
 	INIT_LIST_HEAD(&block->chain_list);
@@ -1621,6 +1688,12 @@ static struct tcf_proto *tcf_chain_tp_insert_unique(struct tcf_chain *chain,
 
 	mutex_lock(&chain->filter_chain_lock);
 
+	if (tcf_proto_exists_destroying(chain, tp_new)) {
+		mutex_unlock(&chain->filter_chain_lock);
+		tcf_proto_destroy(tp_new, rtnl_held, false, NULL);
+		return ERR_PTR(-EAGAIN);
+	}
+
 	tp = tcf_chain_tp_find(chain, &chain_info,
 			       protocol, prio, false);
 	if (!tp)
@@ -1628,10 +1701,10 @@ static struct tcf_proto *tcf_chain_tp_insert_unique(struct tcf_chain *chain,
 	mutex_unlock(&chain->filter_chain_lock);
 
 	if (tp) {
-		tcf_proto_destroy(tp_new, rtnl_held, NULL);
+		tcf_proto_destroy(tp_new, rtnl_held, false, NULL);
 		tp_new = tp;
 	} else if (err) {
-		tcf_proto_destroy(tp_new, rtnl_held, NULL);
+		tcf_proto_destroy(tp_new, rtnl_held, false, NULL);
 		tp_new = ERR_PTR(err);
 	}
 
@@ -1669,6 +1742,7 @@ static void tcf_chain_tp_delete_empty(struct tcf_chain *chain,
 		return;
 	}
 
+	tcf_proto_signal_destroying(chain, tp);
 	next = tcf_chain_dereference(chain_info.next, chain);
 	if (tp == chain->filter_chain)
 		tcf_chain0_head_change(chain, next);
@@ -2188,6 +2262,7 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n,
 		err = -EINVAL;
 		goto errout_locked;
 	} else if (t->tcm_handle == 0) {
+		tcf_proto_signal_destroying(chain, tp);
 		tcf_chain_tp_remove(chain, &chain_info, tp);
 		mutex_unlock(&chain->filter_chain_lock);