summary refs log tree commit diff
path: root/kernel
diff options
context:
space:
mode:
authorLinus Torvalds <torvalds@linux-foundation.org>2019-09-28 12:39:07 -0700
committerLinus Torvalds <torvalds@linux-foundation.org>2019-09-28 12:39:07 -0700
commit9c5efe9ae7df78600c0ee7bcce27516eb687fa6e (patch)
tree158cfb9720d876e68a14a4cccaffeb58fb7baac5 /kernel
parentaefcf2f4b58155d27340ba5f9ddbe9513da8286d (diff)
parent4892f51ad54ddff2883a60b6ad4323c1f632a9d6 (diff)
downloadlinux-9c5efe9ae7df78600c0ee7bcce27516eb687fa6e.tar.gz
Merge branch 'sched-urgent-for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/tip/tip
Pull scheduler fixes from Ingo Molnar:

 - Apply a number of membarrier related fixes and cleanups, which fixes
   a use-after-free race in the membarrier code

 - Introduce proper RCU protection for tasks on the runqueue - to get
   rid of the subtle task_rcu_dereference() interface that was easy to
   get wrong

 - Misc fixes, but also an EAS speedup

* 'sched-urgent-for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/tip/tip:
  sched/fair: Avoid redundant EAS calculation
  sched/core: Remove double update_max_interval() call on CPU startup
  sched/core: Fix preempt_schedule() interrupt return comment
  sched/fair: Fix -Wunused-but-set-variable warnings
  sched/core: Fix migration to invalid CPU in __set_cpus_allowed_ptr()
  sched/membarrier: Return -ENOMEM to userspace on memory allocation failure
  sched/membarrier: Skip IPIs when mm->mm_users == 1
  selftests, sched/membarrier: Add multi-threaded test
  sched/membarrier: Fix p->mm->membarrier_state racy load
  sched/membarrier: Call sync_core only before usermode for same mm
  sched/membarrier: Remove redundant check
  sched/membarrier: Fix private expedited registration check
  tasks, sched/core: RCUify the assignment of rq->curr
  tasks, sched/core: With a grace period after finish_task_switch(), remove unnecessary code
  tasks, sched/core: Ensure tasks are available for a grace period after leaving the runqueue
  tasks: Add a count of task RCU users
  sched/core: Convert vcpu_is_preempted() from macro to an inline function
  sched/fair: Remove unused cfs_rq_clock_task() function
Diffstat (limited to 'kernel')
-rw-r--r--kernel/exit.c74
-rw-r--r--kernel/fork.c8
-rw-r--r--kernel/sched/core.c28
-rw-r--r--kernel/sched/fair.c39
-rw-r--r--kernel/sched/membarrier.c239
-rw-r--r--kernel/sched/sched.h34
6 files changed, 221 insertions, 201 deletions
diff --git a/kernel/exit.c b/kernel/exit.c
index 22ab6a4bdc51..a46a50d67002 100644
--- a/kernel/exit.c
+++ b/kernel/exit.c
@@ -182,6 +182,11 @@ static void delayed_put_task_struct(struct rcu_head *rhp)
 	put_task_struct(tsk);
 }
 
+void put_task_struct_rcu_user(struct task_struct *task)
+{
+	if (refcount_dec_and_test(&task->rcu_users))
+		call_rcu(&task->rcu, delayed_put_task_struct);
+}
 
 void release_task(struct task_struct *p)
 {
@@ -222,76 +227,13 @@ repeat:
 
 	write_unlock_irq(&tasklist_lock);
 	release_thread(p);
-	call_rcu(&p->rcu, delayed_put_task_struct);
+	put_task_struct_rcu_user(p);
 
 	p = leader;
 	if (unlikely(zap_leader))
 		goto repeat;
 }
 
-/*
- * Note that if this function returns a valid task_struct pointer (!NULL)
- * task->usage must remain >0 for the duration of the RCU critical section.
- */
-struct task_struct *task_rcu_dereference(struct task_struct **ptask)
-{
-	struct sighand_struct *sighand;
-	struct task_struct *task;
-
-	/*
-	 * We need to verify that release_task() was not called and thus
-	 * delayed_put_task_struct() can't run and drop the last reference
-	 * before rcu_read_unlock(). We check task->sighand != NULL,
-	 * but we can read the already freed and reused memory.
-	 */
-retry:
-	task = rcu_dereference(*ptask);
-	if (!task)
-		return NULL;
-
-	probe_kernel_address(&task->sighand, sighand);
-
-	/*
-	 * Pairs with atomic_dec_and_test() in put_task_struct(). If this task
-	 * was already freed we can not miss the preceding update of this
-	 * pointer.
-	 */
-	smp_rmb();
-	if (unlikely(task != READ_ONCE(*ptask)))
-		goto retry;
-
-	/*
-	 * We've re-checked that "task == *ptask", now we have two different
-	 * cases:
-	 *
-	 * 1. This is actually the same task/task_struct. In this case
-	 *    sighand != NULL tells us it is still alive.
-	 *
-	 * 2. This is another task which got the same memory for task_struct.
-	 *    We can't know this of course, and we can not trust
-	 *    sighand != NULL.
-	 *
-	 *    In this case we actually return a random value, but this is
-	 *    correct.
-	 *
-	 *    If we return NULL - we can pretend that we actually noticed that
-	 *    *ptask was updated when the previous task has exited. Or pretend
-	 *    that probe_slab_address(&sighand) reads NULL.
-	 *
-	 *    If we return the new task (because sighand is not NULL for any
-	 *    reason) - this is fine too. This (new) task can't go away before
-	 *    another gp pass.
-	 *
-	 *    And note: We could even eliminate the false positive if re-read
-	 *    task->sighand once again to avoid the falsely NULL. But this case
-	 *    is very unlikely so we don't care.
-	 */
-	if (!sighand)
-		return NULL;
-
-	return task;
-}
-
 void rcuwait_wake_up(struct rcuwait *w)
 {
 	struct task_struct *task;
@@ -311,10 +253,6 @@ void rcuwait_wake_up(struct rcuwait *w)
 	 */
 	smp_mb(); /* (B) */
 
-	/*
-	 * Avoid using task_rcu_dereference() magic as long as we are careful,
-	 * see comment in rcuwait_wait_event() regarding ->exit_state.
-	 */
 	task = rcu_dereference(w->task);
 	if (task)
 		wake_up_process(task);
diff --git a/kernel/fork.c b/kernel/fork.c
index 60763c043aa3..f9572f416126 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -915,10 +915,12 @@ static struct task_struct *dup_task_struct(struct task_struct *orig, int node)
 		tsk->cpus_ptr = &tsk->cpus_mask;
 
 	/*
-	 * One for us, one for whoever does the "release_task()" (usually
-	 * parent)
+	 * One for the user space visible state that goes away when reaped.
+	 * One for the scheduler.
 	 */
-	refcount_set(&tsk->usage, 2);
+	refcount_set(&tsk->rcu_users, 2);
+	/* One for the rcu users */
+	refcount_set(&tsk->usage, 1);
 #ifdef CONFIG_BLK_DEV_IO_TRACE
 	tsk->btrace_seq = 0;
 #endif
diff --git a/kernel/sched/core.c b/kernel/sched/core.c
index f9a1346a5fa9..7880f4f64d0e 100644
--- a/kernel/sched/core.c
+++ b/kernel/sched/core.c
@@ -1656,7 +1656,8 @@ static int __set_cpus_allowed_ptr(struct task_struct *p,
 	if (cpumask_equal(p->cpus_ptr, new_mask))
 		goto out;
 
-	if (!cpumask_intersects(new_mask, cpu_valid_mask)) {
+	dest_cpu = cpumask_any_and(cpu_valid_mask, new_mask);
+	if (dest_cpu >= nr_cpu_ids) {
 		ret = -EINVAL;
 		goto out;
 	}
@@ -1677,7 +1678,6 @@ static int __set_cpus_allowed_ptr(struct task_struct *p,
 	if (cpumask_test_cpu(task_cpu(p), new_mask))
 		goto out;
 
-	dest_cpu = cpumask_any_and(cpu_valid_mask, new_mask);
 	if (task_running(rq, p) || p->state == TASK_WAKING) {
 		struct migration_arg arg = { p, dest_cpu };
 		/* Need help from migration thread: drop lock and wait. */
@@ -3254,7 +3254,7 @@ static struct rq *finish_task_switch(struct task_struct *prev)
 		/* Task is done with its stack. */
 		put_task_stack(prev);
 
-		put_task_struct(prev);
+		put_task_struct_rcu_user(prev);
 	}
 
 	tick_nohz_task_switch();
@@ -3358,15 +3358,15 @@ context_switch(struct rq *rq, struct task_struct *prev,
 		else
 			prev->active_mm = NULL;
 	} else {                                        // to user
+		membarrier_switch_mm(rq, prev->active_mm, next->mm);
 		/*
 		 * sys_membarrier() requires an smp_mb() between setting
-		 * rq->curr and returning to userspace.
+		 * rq->curr / membarrier_switch_mm() and returning to userspace.
 		 *
 		 * The below provides this either through switch_mm(), or in
 		 * case 'prev->active_mm == next->mm' through
 		 * finish_task_switch()'s mmdrop().
 		 */
-
 		switch_mm_irqs_off(prev->active_mm, next->mm, next);
 
 		if (!prev->mm) {                        // from kernel
@@ -4042,7 +4042,11 @@ static void __sched notrace __schedule(bool preempt)
 
 	if (likely(prev != next)) {
 		rq->nr_switches++;
-		rq->curr = next;
+		/*
+		 * RCU users of rcu_dereference(rq->curr) may not see
+		 * changes to task_struct made by pick_next_task().
+		 */
+		RCU_INIT_POINTER(rq->curr, next);
 		/*
 		 * The membarrier system call requires each architecture
 		 * to have a full memory barrier after updating
@@ -4223,9 +4227,8 @@ static void __sched notrace preempt_schedule_common(void)
 
 #ifdef CONFIG_PREEMPTION
 /*
- * this is the entry point to schedule() from in-kernel preemption
- * off of preempt_enable. Kernel preemptions off return from interrupt
- * occur there and call schedule directly.
+ * This is the entry point to schedule() from in-kernel preemption
+ * off of preempt_enable.
  */
 asmlinkage __visible void __sched notrace preempt_schedule(void)
 {
@@ -4296,7 +4299,7 @@ EXPORT_SYMBOL_GPL(preempt_schedule_notrace);
 #endif /* CONFIG_PREEMPTION */
 
 /*
- * this is the entry point to schedule() from kernel preemption
+ * This is the entry point to schedule() from kernel preemption
  * off of irq context.
  * Note, that this is called and return with irqs disabled. This will
  * protect us against recursive calling from irq.
@@ -6069,7 +6072,8 @@ void init_idle(struct task_struct *idle, int cpu)
 	__set_task_cpu(idle, cpu);
 	rcu_read_unlock();
 
-	rq->curr = rq->idle = idle;
+	rq->idle = idle;
+	rcu_assign_pointer(rq->curr, idle);
 	idle->on_rq = TASK_ON_RQ_QUEUED;
 #ifdef CONFIG_SMP
 	idle->on_cpu = 1;
@@ -6430,8 +6434,6 @@ int sched_cpu_activate(unsigned int cpu)
 	}
 	rq_unlock_irqrestore(rq, &rf);
 
-	update_max_interval();
-
 	return 0;
 }
 
diff --git a/kernel/sched/fair.c b/kernel/sched/fair.c
index d4bbf68c3161..83ab35e2374f 100644
--- a/kernel/sched/fair.c
+++ b/kernel/sched/fair.c
@@ -749,7 +749,6 @@ void init_entity_runnable_average(struct sched_entity *se)
 	/* when this task enqueue'ed, it will contribute to its cfs_rq's load_avg */
 }
 
-static inline u64 cfs_rq_clock_task(struct cfs_rq *cfs_rq);
 static void attach_entity_cfs_rq(struct sched_entity *se);
 
 /*
@@ -1603,7 +1602,7 @@ static void task_numa_compare(struct task_numa_env *env,
 		return;
 
 	rcu_read_lock();
-	cur = task_rcu_dereference(&dst_rq->curr);
+	cur = rcu_dereference(dst_rq->curr);
 	if (cur && ((cur->flags & PF_EXITING) || is_idle_task(cur)))
 		cur = NULL;
 
@@ -4354,21 +4353,16 @@ static inline u64 sched_cfs_bandwidth_slice(void)
 }
 
 /*
- * Replenish runtime according to assigned quota and update expiration time.
- * We use sched_clock_cpu directly instead of rq->clock to avoid adding
- * additional synchronization around rq->lock.
+ * Replenish runtime according to assigned quota. We use sched_clock_cpu
+ * directly instead of rq->clock to avoid adding additional synchronization
+ * around rq->lock.
  *
  * requires cfs_b->lock
  */
 void __refill_cfs_bandwidth_runtime(struct cfs_bandwidth *cfs_b)
 {
-	u64 now;
-
-	if (cfs_b->quota == RUNTIME_INF)
-		return;
-
-	now = sched_clock_cpu(smp_processor_id());
-	cfs_b->runtime = cfs_b->quota;
+	if (cfs_b->quota != RUNTIME_INF)
+		cfs_b->runtime = cfs_b->quota;
 }
 
 static inline struct cfs_bandwidth *tg_cfs_bandwidth(struct task_group *tg)
@@ -4376,15 +4370,6 @@ static inline struct cfs_bandwidth *tg_cfs_bandwidth(struct task_group *tg)
 	return &tg->cfs_bandwidth;
 }
 
-/* rq->task_clock normalized against any time this cfs_rq has spent throttled */
-static inline u64 cfs_rq_clock_task(struct cfs_rq *cfs_rq)
-{
-	if (unlikely(cfs_rq->throttle_count))
-		return cfs_rq->throttled_clock_task - cfs_rq->throttled_clock_task_time;
-
-	return rq_clock_task(rq_of(cfs_rq)) - cfs_rq->throttled_clock_task_time;
-}
-
 /* returns 0 on failure to allocate runtime */
 static int assign_cfs_rq_runtime(struct cfs_rq *cfs_rq)
 {
@@ -4476,7 +4461,6 @@ static int tg_unthrottle_up(struct task_group *tg, void *data)
 
 	cfs_rq->throttle_count--;
 	if (!cfs_rq->throttle_count) {
-		/* adjust cfs_rq_clock_task() */
 		cfs_rq->throttled_clock_task_time += rq_clock_task(rq) -
 					     cfs_rq->throttled_clock_task;
 
@@ -4994,15 +4978,13 @@ static void init_cfs_rq_runtime(struct cfs_rq *cfs_rq)
 
 void start_cfs_bandwidth(struct cfs_bandwidth *cfs_b)
 {
-	u64 overrun;
-
 	lockdep_assert_held(&cfs_b->lock);
 
 	if (cfs_b->period_active)
 		return;
 
 	cfs_b->period_active = 1;
-	overrun = hrtimer_forward_now(&cfs_b->period_timer, cfs_b->period);
+	hrtimer_forward_now(&cfs_b->period_timer, cfs_b->period);
 	hrtimer_start_expires(&cfs_b->period_timer, HRTIMER_MODE_ABS_PINNED);
 }
 
@@ -5080,11 +5062,6 @@ static inline bool cfs_bandwidth_used(void)
 	return false;
 }
 
-static inline u64 cfs_rq_clock_task(struct cfs_rq *cfs_rq)
-{
-	return rq_clock_task(rq_of(cfs_rq));
-}
-
 static void account_cfs_rq_runtime(struct cfs_rq *cfs_rq, u64 delta_exec) {}
 static bool check_cfs_rq_runtime(struct cfs_rq *cfs_rq) { return false; }
 static void check_enqueue_throttle(struct cfs_rq *cfs_rq) {}
@@ -6412,7 +6389,7 @@ static int find_energy_efficient_cpu(struct task_struct *p, int prev_cpu)
 		}
 
 		/* Evaluate the energy impact of using this CPU. */
-		if (max_spare_cap_cpu >= 0) {
+		if (max_spare_cap_cpu >= 0 && max_spare_cap_cpu != prev_cpu) {
 			cur_delta = compute_energy(p, max_spare_cap_cpu, pd);
 			cur_delta -= base_energy_pd;
 			if (cur_delta < best_delta) {
diff --git a/kernel/sched/membarrier.c b/kernel/sched/membarrier.c
index aa8d75804108..a39bed2c784f 100644
--- a/kernel/sched/membarrier.c
+++ b/kernel/sched/membarrier.c
@@ -30,10 +30,42 @@ static void ipi_mb(void *info)
 	smp_mb();	/* IPIs should be serializing but paranoid. */
 }
 
+static void ipi_sync_rq_state(void *info)
+{
+	struct mm_struct *mm = (struct mm_struct *) info;
+
+	if (current->mm != mm)
+		return;
+	this_cpu_write(runqueues.membarrier_state,
+		       atomic_read(&mm->membarrier_state));
+	/*
+	 * Issue a memory barrier after setting
+	 * MEMBARRIER_STATE_GLOBAL_EXPEDITED in the current runqueue to
+	 * guarantee that no memory access following registration is reordered
+	 * before registration.
+	 */
+	smp_mb();
+}
+
+void membarrier_exec_mmap(struct mm_struct *mm)
+{
+	/*
+	 * Issue a memory barrier before clearing membarrier_state to
+	 * guarantee that no memory access prior to exec is reordered after
+	 * clearing this state.
+	 */
+	smp_mb();
+	atomic_set(&mm->membarrier_state, 0);
+	/*
+	 * Keep the runqueue membarrier_state in sync with this mm
+	 * membarrier_state.
+	 */
+	this_cpu_write(runqueues.membarrier_state, 0);
+}
+
 static int membarrier_global_expedited(void)
 {
 	int cpu;
-	bool fallback = false;
 	cpumask_var_t tmpmask;
 
 	if (num_online_cpus() == 1)
@@ -45,17 +77,11 @@ static int membarrier_global_expedited(void)
 	 */
 	smp_mb();	/* system call entry is not a mb. */
 
-	/*
-	 * Expedited membarrier commands guarantee that they won't
-	 * block, hence the GFP_NOWAIT allocation flag and fallback
-	 * implementation.
-	 */
-	if (!zalloc_cpumask_var(&tmpmask, GFP_NOWAIT)) {
-		/* Fallback for OOM. */
-		fallback = true;
-	}
+	if (!zalloc_cpumask_var(&tmpmask, GFP_KERNEL))
+		return -ENOMEM;
 
 	cpus_read_lock();
+	rcu_read_lock();
 	for_each_online_cpu(cpu) {
 		struct task_struct *p;
 
@@ -70,23 +96,28 @@ static int membarrier_global_expedited(void)
 		if (cpu == raw_smp_processor_id())
 			continue;
 
-		rcu_read_lock();
-		p = task_rcu_dereference(&cpu_rq(cpu)->curr);
-		if (p && p->mm && (atomic_read(&p->mm->membarrier_state) &
-				   MEMBARRIER_STATE_GLOBAL_EXPEDITED)) {
-			if (!fallback)
-				__cpumask_set_cpu(cpu, tmpmask);
-			else
-				smp_call_function_single(cpu, ipi_mb, NULL, 1);
-		}
-		rcu_read_unlock();
-	}
-	if (!fallback) {
-		preempt_disable();
-		smp_call_function_many(tmpmask, ipi_mb, NULL, 1);
-		preempt_enable();
-		free_cpumask_var(tmpmask);
+		if (!(READ_ONCE(cpu_rq(cpu)->membarrier_state) &
+		    MEMBARRIER_STATE_GLOBAL_EXPEDITED))
+			continue;
+
+		/*
+		 * Skip the CPU if it runs a kernel thread. The scheduler
+		 * leaves the prior task mm in place as an optimization when
+		 * scheduling a kthread.
+		 */
+		p = rcu_dereference(cpu_rq(cpu)->curr);
+		if (p->flags & PF_KTHREAD)
+			continue;
+
+		__cpumask_set_cpu(cpu, tmpmask);
 	}
+	rcu_read_unlock();
+
+	preempt_disable();
+	smp_call_function_many(tmpmask, ipi_mb, NULL, 1);
+	preempt_enable();
+
+	free_cpumask_var(tmpmask);
 	cpus_read_unlock();
 
 	/*
@@ -101,22 +132,22 @@ static int membarrier_global_expedited(void)
 static int membarrier_private_expedited(int flags)
 {
 	int cpu;
-	bool fallback = false;
 	cpumask_var_t tmpmask;
+	struct mm_struct *mm = current->mm;
 
 	if (flags & MEMBARRIER_FLAG_SYNC_CORE) {
 		if (!IS_ENABLED(CONFIG_ARCH_HAS_MEMBARRIER_SYNC_CORE))
 			return -EINVAL;
-		if (!(atomic_read(&current->mm->membarrier_state) &
+		if (!(atomic_read(&mm->membarrier_state) &
 		      MEMBARRIER_STATE_PRIVATE_EXPEDITED_SYNC_CORE_READY))
 			return -EPERM;
 	} else {
-		if (!(atomic_read(&current->mm->membarrier_state) &
+		if (!(atomic_read(&mm->membarrier_state) &
 		      MEMBARRIER_STATE_PRIVATE_EXPEDITED_READY))
 			return -EPERM;
 	}
 
-	if (num_online_cpus() == 1)
+	if (atomic_read(&mm->mm_users) == 1 || num_online_cpus() == 1)
 		return 0;
 
 	/*
@@ -125,17 +156,11 @@ static int membarrier_private_expedited(int flags)
 	 */
 	smp_mb();	/* system call entry is not a mb. */
 
-	/*
-	 * Expedited membarrier commands guarantee that they won't
-	 * block, hence the GFP_NOWAIT allocation flag and fallback
-	 * implementation.
-	 */
-	if (!zalloc_cpumask_var(&tmpmask, GFP_NOWAIT)) {
-		/* Fallback for OOM. */
-		fallback = true;
-	}
+	if (!zalloc_cpumask_var(&tmpmask, GFP_KERNEL))
+		return -ENOMEM;
 
 	cpus_read_lock();
+	rcu_read_lock();
 	for_each_online_cpu(cpu) {
 		struct task_struct *p;
 
@@ -150,21 +175,17 @@ static int membarrier_private_expedited(int flags)
 		if (cpu == raw_smp_processor_id())
 			continue;
 		rcu_read_lock();
-		p = task_rcu_dereference(&cpu_rq(cpu)->curr);
-		if (p && p->mm == current->mm) {
-			if (!fallback)
-				__cpumask_set_cpu(cpu, tmpmask);
-			else
-				smp_call_function_single(cpu, ipi_mb, NULL, 1);
-		}
-		rcu_read_unlock();
-	}
-	if (!fallback) {
-		preempt_disable();
-		smp_call_function_many(tmpmask, ipi_mb, NULL, 1);
-		preempt_enable();
-		free_cpumask_var(tmpmask);
+		p = rcu_dereference(cpu_rq(cpu)->curr);
+		if (p && p->mm == mm)
+			__cpumask_set_cpu(cpu, tmpmask);
 	}
+	rcu_read_unlock();
+
+	preempt_disable();
+	smp_call_function_many(tmpmask, ipi_mb, NULL, 1);
+	preempt_enable();
+
+	free_cpumask_var(tmpmask);
 	cpus_read_unlock();
 
 	/*
@@ -177,32 +198,78 @@ static int membarrier_private_expedited(int flags)
 	return 0;
 }
 
+static int sync_runqueues_membarrier_state(struct mm_struct *mm)
+{
+	int membarrier_state = atomic_read(&mm->membarrier_state);
+	cpumask_var_t tmpmask;
+	int cpu;
+
+	if (atomic_read(&mm->mm_users) == 1 || num_online_cpus() == 1) {
+		this_cpu_write(runqueues.membarrier_state, membarrier_state);
+
+		/*
+		 * For single mm user, we can simply issue a memory barrier
+		 * after setting MEMBARRIER_STATE_GLOBAL_EXPEDITED in the
+		 * mm and in the current runqueue to guarantee that no memory
+		 * access following registration is reordered before
+		 * registration.
+		 */
+		smp_mb();
+		return 0;
+	}
+
+	if (!zalloc_cpumask_var(&tmpmask, GFP_KERNEL))
+		return -ENOMEM;
+
+	/*
+	 * For mm with multiple users, we need to ensure all future
+	 * scheduler executions will observe @mm's new membarrier
+	 * state.
+	 */
+	synchronize_rcu();
+
+	/*
+	 * For each cpu runqueue, if the task's mm match @mm, ensure that all
+	 * @mm's membarrier state set bits are also set in in the runqueue's
+	 * membarrier state. This ensures that a runqueue scheduling
+	 * between threads which are users of @mm has its membarrier state
+	 * updated.
+	 */
+	cpus_read_lock();
+	rcu_read_lock();
+	for_each_online_cpu(cpu) {
+		struct rq *rq = cpu_rq(cpu);
+		struct task_struct *p;
+
+		p = rcu_dereference(rq->curr);
+		if (p && p->mm == mm)
+			__cpumask_set_cpu(cpu, tmpmask);
+	}
+	rcu_read_unlock();
+
+	preempt_disable();
+	smp_call_function_many(tmpmask, ipi_sync_rq_state, mm, 1);
+	preempt_enable();
+
+	free_cpumask_var(tmpmask);
+	cpus_read_unlock();
+
+	return 0;
+}
+
 static int membarrier_register_global_expedited(void)
 {
 	struct task_struct *p = current;
 	struct mm_struct *mm = p->mm;
+	int ret;
 
 	if (atomic_read(&mm->membarrier_state) &
 	    MEMBARRIER_STATE_GLOBAL_EXPEDITED_READY)
 		return 0;
 	atomic_or(MEMBARRIER_STATE_GLOBAL_EXPEDITED, &mm->membarrier_state);
-	if (atomic_read(&mm->mm_users) == 1 && get_nr_threads(p) == 1) {
-		/*
-		 * For single mm user, single threaded process, we can
-		 * simply issue a memory barrier after setting
-		 * MEMBARRIER_STATE_GLOBAL_EXPEDITED to guarantee that
-		 * no memory access following registration is reordered
-		 * before registration.
-		 */
-		smp_mb();
-	} else {
-		/*
-		 * For multi-mm user threads, we need to ensure all
-		 * future scheduler executions will observe the new
-		 * thread flag state for this mm.
-		 */
-		synchronize_rcu();
-	}
+	ret = sync_runqueues_membarrier_state(mm);
+	if (ret)
+		return ret;
 	atomic_or(MEMBARRIER_STATE_GLOBAL_EXPEDITED_READY,
 		  &mm->membarrier_state);
 
@@ -213,12 +280,15 @@ static int membarrier_register_private_expedited(int flags)
 {
 	struct task_struct *p = current;
 	struct mm_struct *mm = p->mm;
-	int state = MEMBARRIER_STATE_PRIVATE_EXPEDITED_READY;
+	int ready_state = MEMBARRIER_STATE_PRIVATE_EXPEDITED_READY,
+	    set_state = MEMBARRIER_STATE_PRIVATE_EXPEDITED,
+	    ret;
 
 	if (flags & MEMBARRIER_FLAG_SYNC_CORE) {
 		if (!IS_ENABLED(CONFIG_ARCH_HAS_MEMBARRIER_SYNC_CORE))
 			return -EINVAL;
-		state = MEMBARRIER_STATE_PRIVATE_EXPEDITED_SYNC_CORE_READY;
+		ready_state =
+			MEMBARRIER_STATE_PRIVATE_EXPEDITED_SYNC_CORE_READY;
 	}
 
 	/*
@@ -226,20 +296,15 @@ static int membarrier_register_private_expedited(int flags)
 	 * groups, which use the same mm. (CLONE_VM but not
 	 * CLONE_THREAD).
 	 */
-	if (atomic_read(&mm->membarrier_state) & state)
+	if ((atomic_read(&mm->membarrier_state) & ready_state) == ready_state)
 		return 0;
-	atomic_or(MEMBARRIER_STATE_PRIVATE_EXPEDITED, &mm->membarrier_state);
 	if (flags & MEMBARRIER_FLAG_SYNC_CORE)
-		atomic_or(MEMBARRIER_STATE_PRIVATE_EXPEDITED_SYNC_CORE,
-			  &mm->membarrier_state);
-	if (!(atomic_read(&mm->mm_users) == 1 && get_nr_threads(p) == 1)) {
-		/*
-		 * Ensure all future scheduler executions will observe the
-		 * new thread flag state for this process.
-		 */
-		synchronize_rcu();
-	}
-	atomic_or(state, &mm->membarrier_state);
+		set_state |= MEMBARRIER_STATE_PRIVATE_EXPEDITED_SYNC_CORE;
+	atomic_or(set_state, &mm->membarrier_state);
+	ret = sync_runqueues_membarrier_state(mm);
+	if (ret)
+		return ret;
+	atomic_or(ready_state, &mm->membarrier_state);
 
 	return 0;
 }
@@ -253,8 +318,10 @@ static int membarrier_register_private_expedited(int flags)
  * command specified does not exist, not available on the running
  * kernel, or if the command argument is invalid, this system call
  * returns -EINVAL. For a given command, with flags argument set to 0,
- * this system call is guaranteed to always return the same value until
- * reboot.
+ * if this system call returns -ENOSYS or -EINVAL, it is guaranteed to
+ * always return the same value until reboot. In addition, it can return
+ * -ENOMEM if there is not enough memory available to perform the system
+ * call.
  *
  * All memory accesses performed in program order from each targeted thread
  * is guaranteed to be ordered with respect to sys_membarrier(). If we use
diff --git a/kernel/sched/sched.h b/kernel/sched/sched.h
index b3cb895d14a2..0db2c1b3361e 100644
--- a/kernel/sched/sched.h
+++ b/kernel/sched/sched.h
@@ -911,6 +911,10 @@ struct rq {
 
 	atomic_t		nr_iowait;
 
+#ifdef CONFIG_MEMBARRIER
+	int membarrier_state;
+#endif
+
 #ifdef CONFIG_SMP
 	struct root_domain		*rd;
 	struct sched_domain __rcu	*sd;
@@ -2438,3 +2442,33 @@ static inline bool sched_energy_enabled(void)
 static inline bool sched_energy_enabled(void) { return false; }
 
 #endif /* CONFIG_ENERGY_MODEL && CONFIG_CPU_FREQ_GOV_SCHEDUTIL */
+
+#ifdef CONFIG_MEMBARRIER
+/*
+ * The scheduler provides memory barriers required by membarrier between:
+ * - prior user-space memory accesses and store to rq->membarrier_state,
+ * - store to rq->membarrier_state and following user-space memory accesses.
+ * In the same way it provides those guarantees around store to rq->curr.
+ */
+static inline void membarrier_switch_mm(struct rq *rq,
+					struct mm_struct *prev_mm,
+					struct mm_struct *next_mm)
+{
+	int membarrier_state;
+
+	if (prev_mm == next_mm)
+		return;
+
+	membarrier_state = atomic_read(&next_mm->membarrier_state);
+	if (READ_ONCE(rq->membarrier_state) == membarrier_state)
+		return;
+
+	WRITE_ONCE(rq->membarrier_state, membarrier_state);
+}
+#else
+static inline void membarrier_switch_mm(struct rq *rq,
+					struct mm_struct *prev_mm,
+					struct mm_struct *next_mm)
+{
+}
+#endif