-rw-r--r--  arch/arm64/include/asm/fpsimd.h | 4 ++--
-rw-r--r--  arch/arm64/kernel/fpsimd.c      | 6 +++---
-rw-r--r--  arch/arm64/kernel/ptrace.c      | 9 ++++++++-
-rw-r--r--  arch/arm64/kernel/signal.c      | 2 +-
4 files changed, 14 insertions, 7 deletions
diff --git a/arch/arm64/include/asm/fpsimd.h b/arch/arm64/include/asm/fpsimd.h
index 6f86b7ab6c28..d720b6f7e5f9 100644
--- a/arch/arm64/include/asm/fpsimd.h
+++ b/arch/arm64/include/asm/fpsimd.h
@@ -339,7 +339,7 @@ static inline int sme_max_virtualisable_vl(void)
 	return vec_max_virtualisable_vl(ARM64_VEC_SME);
 }
 
-extern void sme_alloc(struct task_struct *task);
+extern void sme_alloc(struct task_struct *task, bool flush);
 extern unsigned int sme_get_vl(void);
 extern int sme_set_current_vl(unsigned long arg);
 extern int sme_get_current_vl(void);
@@ -365,7 +365,7 @@ static inline void sme_smstart_sm(void) { }
 static inline void sme_smstop_sm(void) { }
 static inline void sme_smstop(void) { }
 
-static inline void sme_alloc(struct task_struct *task) { }
+static inline void sme_alloc(struct task_struct *task, bool flush) { }
 static inline void sme_setup(void) { }
 static inline unsigned int sme_get_vl(void) { return 0; }
 static inline int sme_max_vl(void) { return 0; }
diff --git a/arch/arm64/kernel/fpsimd.c b/arch/arm64/kernel/fpsimd.c
index 356036babd09..8cd59d387b90 100644
--- a/arch/arm64/kernel/fpsimd.c
+++ b/arch/arm64/kernel/fpsimd.c
@@ -1239,9 +1239,9 @@ void fpsimd_release_task(struct task_struct *dead_task)
  * the interest of testability and predictability, the architecture
  * guarantees that when ZA is enabled it will be zeroed.
  */
-void sme_alloc(struct task_struct *task)
+void sme_alloc(struct task_struct *task, bool flush)
 {
-	if (task->thread.za_state) {
+	if (task->thread.za_state && flush) {
 		memset(task->thread.za_state, 0, za_state_size(task));
 		return;
 	}
@@ -1460,7 +1460,7 @@ void do_sme_acc(unsigned long esr, struct pt_regs *regs)
 	}
 
 	sve_alloc(current, false);
-	sme_alloc(current);
+	sme_alloc(current, true);
 	if (!current->thread.sve_state || !current->thread.za_state) {
 		force_sig(SIGKILL);
 		return;
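
A note on the sme_alloc() change above: the hunk only shows the head of the function, so the sketch below is a small standalone userspace model of the behaviour the new flush argument selects: make sure ZA backing storage exists, and zero it only when the caller asks for a flush. The struct and names (task, za_state, za_bytes, model_sme_alloc) are illustrative rather than the kernel's, and the allocation tail is assumed from context, not taken from the hunk.

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

struct task {
	void *za_state;   /* lazily allocated ZA backing storage */
	size_t za_bytes;  /* size derived from the SME vector length */
};

/*
 * Model of sme_alloc(task, flush).  The first branch mirrors the hunk
 * above; the allocation below it is assumed (not visible in the hunk)
 * and is guarded so that flush=false never touches existing data.
 */
static void model_sme_alloc(struct task *t, bool flush)
{
	if (t->za_state && flush) {
		memset(t->za_state, 0, t->za_bytes);
		return;
	}

	if (!t->za_state)
		t->za_state = calloc(1, t->za_bytes);
}

int main(void)
{
	struct task t = { .za_state = NULL, .za_bytes = 64 };
	unsigned char *za;

	model_sme_alloc(&t, true);        /* first use: allocate, zeroed */
	za = t.za_state;
	memset(za, 0xaa, t.za_bytes);     /* pretend the task populated ZA */

	model_sme_alloc(&t, false);       /* ptrace-style call: preserved */
	printf("after flush=false: %#x\n", za[0]);

	model_sme_alloc(&t, true);        /* fault/signal-style call: zeroed */
	printf("after flush=true:  %#x\n", za[0]);

	free(t.za_state);
	return 0;
}

Run as-is this prints 0xaa after the flush=false call and 0 after the flush=true call, which is exactly the distinction the ptrace change below depends on.
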
diff --git a/arch/arm64/kernel/ptrace.c b/arch/arm64/kernel/ptrace.c
index f19f020ccff9..f606c942f514 100644
--- a/arch/arm64/kernel/ptrace.c
+++ b/arch/arm64/kernel/ptrace.c
@@ -886,6 +886,13 @@ static int sve_set_common(struct task_struct *target,
 			break;
 		case ARM64_VEC_SME:
 			target->thread.svcr |= SVCR_SM_MASK;
+
+			/*
+			 * Disable traps and ensure there is SME storage but
+			 * preserve any currently set values in ZA/ZT.
+			 */
+			sme_alloc(target, false);
+			set_tsk_thread_flag(target, TIF_SME);
 			break;
 		default:
 			WARN_ON_ONCE(1);
@@ -1107,7 +1114,7 @@ static int za_set(struct task_struct *target,
 	}
 
 	/* Allocate/reinit ZA storage */
-	sme_alloc(target);
+	sme_alloc(target, true);
 	if (!target->thread.za_state) {
 		ret = -ENOMEM;
 		goto out;
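
The new ARM64_VEC_SME branch above matters for the ordering a debugger typically uses: write ZA first (the NT_ARM_ZA regset), then switch the tracee into streaming mode via the streaming-SVE regset. Passing flush=false means the second step still guarantees ZA/ZT storage exists and SME traps are disabled (TIF_SME) without discarding what the first step wrote. The sketch below models that ordering in plain userspace C; model_za_set() and model_ssve_set() are illustrative stand-ins for za_set() and the SME case in sve_set_common(), not real interfaces.

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

struct task {
	void *za_state;       /* ZA/ZT backing storage */
	size_t za_bytes;
	bool tif_sme;         /* models TIF_SME (SME traps disabled) */
	bool streaming_mode;  /* models SVCR.SM */
};

static void model_sme_alloc(struct task *t, bool flush)
{
	if (t->za_state && flush) {
		memset(t->za_state, 0, t->za_bytes);
		return;
	}
	if (!t->za_state)
		t->za_state = calloc(1, t->za_bytes);
}

/* Stand-in for za_set(): reinit ZA storage, then copy in tracer data. */
static void model_za_set(struct task *t, const void *buf, size_t len)
{
	model_sme_alloc(t, true);
	memcpy(t->za_state, buf, len);
	t->tif_sme = true;
}

/* Stand-in for the new ARM64_VEC_SME case in sve_set_common(). */
static void model_ssve_set(struct task *t)
{
	t->streaming_mode = true;
	model_sme_alloc(t, false);   /* keep whatever ZA already holds */
	t->tif_sme = true;
}

int main(void)
{
	struct task t = { .za_bytes = 64 };
	unsigned char za[64];

	memset(za, 0x5a, sizeof(za));
	model_za_set(&t, za, sizeof(za));  /* tracer writes ZA first */
	model_ssve_set(&t);                /* then enables streaming mode */

	printf("ZA byte 0 after the streaming-SVE write: %#x\n",
	       ((unsigned char *)t.za_state)[0]);

	free(t.za_state);
	return 0;
}

With flush=true in model_ssve_set() the printed byte would be 0 rather than 0x5a, i.e. the tracer's earlier ZA write would be silently lost.
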
diff --git a/arch/arm64/kernel/signal.c b/arch/arm64/kernel/signal.c
index 43adbfa5ead7..82f4572c8ddf 100644
--- a/arch/arm64/kernel/signal.c
+++ b/arch/arm64/kernel/signal.c
@@ -430,7 +430,7 @@ static int restore_za_context(struct user_ctxs *user)
 	fpsimd_flush_task_state(current);
 	/* From now, fpsimd_thread_switch() won't touch thread.sve_state */
 
-	sme_alloc(current);
+	sme_alloc(current, true);
 	if (!current->thread.za_state) {
 		current->thread.svcr &= ~SVCR_ZA_MASK;
 		clear_thread_flag(TIF_SME);