summary refs log tree commit diff
path: root/lib
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2022-10-05 16:43:38 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2022-10-11 17:42:55 -0600
commit81895a65ec63ee1daec3255dc1a06675d2fbe915 (patch)
tree03703a64413a07bd4e6a171a449a8f18c760e21c /lib
parentd465bff130bf4ca17b6980abe51164ace1e0cba4 (diff)
downloadlinux-81895a65ec63ee1daec3255dc1a06675d2fbe915.tar.gz
treewide: use prandom_u32_max() when possible, part 1
Rather than incurring a division or requesting too many random bytes for
the given range, use the prandom_u32_max() function, which only takes
the minimum required bytes from the RNG and avoids divisions. This was
done mechanically with this coccinelle script:

@basic@
expression E;
type T;
identifier get_random_u32 =~ "get_random_int|prandom_u32|get_random_u32";
typedef u64;
@@
(
- ((T)get_random_u32() % (E))
+ prandom_u32_max(E)
|
- ((T)get_random_u32() & ((E) - 1))
+ prandom_u32_max(E * XXX_MAKE_SURE_E_IS_POW2)
|
- ((u64)(E) * get_random_u32() >> 32)
+ prandom_u32_max(E)
|
- ((T)get_random_u32() & ~PAGE_MASK)
+ prandom_u32_max(PAGE_SIZE)
)

@multi_line@
identifier get_random_u32 =~ "get_random_int|prandom_u32|get_random_u32";
identifier RAND;
expression E;
@@

-       RAND = get_random_u32();
        ... when != RAND
-       RAND %= (E);
+       RAND = prandom_u32_max(E);

// Find a potential literal
@literal_mask@
expression LITERAL;
type T;
identifier get_random_u32 =~ "get_random_int|prandom_u32|get_random_u32";
position p;
@@

        ((T)get_random_u32()@p & (LITERAL))

// Add one to the literal.
@script:python add_one@
literal << literal_mask.LITERAL;
RESULT;
@@

value = None
if literal.startswith('0x'):
        value = int(literal, 16)
elif literal[0] in '123456789':
        value = int(literal, 10)
if value is None:
        print("I don't know how to handle %s" % (literal))
        cocci.include_match(False)
elif value == 2**32 - 1 or value == 2**31 - 1 or value == 2**24 - 1 or value == 2**16 - 1 or value == 2**8 - 1:
        print("Skipping 0x%x for cleanup elsewhere" % (value))
        cocci.include_match(False)
elif value & (value + 1) != 0:
        print("Skipping 0x%x because it's not a power of two minus one" % (value))
        cocci.include_match(False)
elif literal.startswith('0x'):
        coccinelle.RESULT = cocci.make_expr("0x%x" % (value + 1))
else:
        coccinelle.RESULT = cocci.make_expr("%d" % (value + 1))

// Replace the literal mask with the calculated result.
@plus_one@
expression literal_mask.LITERAL;
position literal_mask.p;
expression add_one.RESULT;
identifier FUNC;
@@

-       (FUNC()@p & (LITERAL))
+       prandom_u32_max(RESULT)

@collapse_ret@
type T;
identifier VAR;
expression E;
@@

 {
-       T VAR;
-       VAR = (E);
-       return VAR;
+       return E;
 }

@drop_var@
type T;
identifier VAR;
@@

 {
-       T VAR;
        ... when != VAR
 }

Reviewed-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
Reviewed-by: Kees Cook <keescook@chromium.org>
Reviewed-by: Yury Norov <yury.norov@gmail.com>
Reviewed-by: KP Singh <kpsingh@kernel.org>
Reviewed-by: Jan Kara <jack@suse.cz> # for ext4 and sbitmap
Reviewed-by: Christoph Böhmwalder <christoph.boehmwalder@linbit.com> # for drbd
Acked-by: Jakub Kicinski <kuba@kernel.org>
Acked-by: Heiko Carstens <hca@linux.ibm.com> # for s390
Acked-by: Ulf Hansson <ulf.hansson@linaro.org> # for mmc
Acked-by: Darrick J. Wong <djwong@kernel.org> # for xfs
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'lib')
-rw-r--r--lib/fault-inject.c2
-rw-r--r--lib/find_bit_benchmark.c4
-rw-r--r--lib/kobject.c2
-rw-r--r--lib/reed_solomon/test_rslib.c6
-rw-r--r--lib/sbitmap.c2
-rw-r--r--lib/test-string_helpers.c2
-rw-r--r--lib/test_hexdump.c10
-rw-r--r--lib/test_list_sort.c2
8 files changed, 15 insertions, 15 deletions
diff --git a/lib/fault-inject.c b/lib/fault-inject.c
index 423784d9c058..96e092de5b72 100644
--- a/lib/fault-inject.c
+++ b/lib/fault-inject.c
@@ -139,7 +139,7 @@ bool should_fail(struct fault_attr *attr, ssize_t size)
 			return false;
 	}
 
-	if (attr->probability <= prandom_u32() % 100)
+	if (attr->probability <= prandom_u32_max(100))
 		return false;
 
 	if (!fail_stacktrace(attr))
diff --git a/lib/find_bit_benchmark.c b/lib/find_bit_benchmark.c
index 10754586403b..7c3c011abd29 100644
--- a/lib/find_bit_benchmark.c
+++ b/lib/find_bit_benchmark.c
@@ -174,8 +174,8 @@ static int __init find_bit_test(void)
 	bitmap_zero(bitmap2, BITMAP_LEN);
 
 	while (nbits--) {
-		__set_bit(prandom_u32() % BITMAP_LEN, bitmap);
-		__set_bit(prandom_u32() % BITMAP_LEN, bitmap2);
+		__set_bit(prandom_u32_max(BITMAP_LEN), bitmap);
+		__set_bit(prandom_u32_max(BITMAP_LEN), bitmap2);
 	}
 
 	test_find_next_bit(bitmap, BITMAP_LEN);
diff --git a/lib/kobject.c b/lib/kobject.c
index 5f0e71ab292c..a0b2dbfcfa23 100644
--- a/lib/kobject.c
+++ b/lib/kobject.c
@@ -694,7 +694,7 @@ static void kobject_release(struct kref *kref)
 {
 	struct kobject *kobj = container_of(kref, struct kobject, kref);
 #ifdef CONFIG_DEBUG_KOBJECT_RELEASE
-	unsigned long delay = HZ + HZ * (get_random_int() & 0x3);
+	unsigned long delay = HZ + HZ * prandom_u32_max(4);
 	pr_info("kobject: '%s' (%p): %s, parent %p (delayed %ld)\n",
 		 kobject_name(kobj), kobj, __func__, kobj->parent, delay);
 	INIT_DELAYED_WORK(&kobj->release, kobject_delayed_cleanup);
diff --git a/lib/reed_solomon/test_rslib.c b/lib/reed_solomon/test_rslib.c
index d9d1c33aebda..4d241bdc88aa 100644
--- a/lib/reed_solomon/test_rslib.c
+++ b/lib/reed_solomon/test_rslib.c
@@ -183,7 +183,7 @@ static int get_rcw_we(struct rs_control *rs, struct wspace *ws,
 
 		do {
 			/* Must not choose the same location twice */
-			errloc = prandom_u32() % len;
+			errloc = prandom_u32_max(len);
 		} while (errlocs[errloc] != 0);
 
 		errlocs[errloc] = 1;
@@ -194,12 +194,12 @@ static int get_rcw_we(struct rs_control *rs, struct wspace *ws,
 	for (i = 0; i < eras; i++) {
 		do {
 			/* Must not choose the same location twice */
-			errloc = prandom_u32() % len;
+			errloc = prandom_u32_max(len);
 		} while (errlocs[errloc] != 0);
 
 		derrlocs[i] = errloc;
 
-		if (ewsc && (prandom_u32() & 1)) {
+		if (ewsc && prandom_u32_max(2)) {
 			/* Erasure with the symbol intact */
 			errlocs[errloc] = 2;
 		} else {
diff --git a/lib/sbitmap.c b/lib/sbitmap.c
index a8108a962dfd..055dac069afb 100644
--- a/lib/sbitmap.c
+++ b/lib/sbitmap.c
@@ -33,7 +33,7 @@ static inline unsigned update_alloc_hint_before_get(struct sbitmap *sb,
 
 	hint = this_cpu_read(*sb->alloc_hint);
 	if (unlikely(hint >= depth)) {
-		hint = depth ? prandom_u32() % depth : 0;
+		hint = depth ? prandom_u32_max(depth) : 0;
 		this_cpu_write(*sb->alloc_hint, hint);
 	}
 
diff --git a/lib/test-string_helpers.c b/lib/test-string_helpers.c
index 437d8e6b7cb1..86fadd3ba08c 100644
--- a/lib/test-string_helpers.c
+++ b/lib/test-string_helpers.c
@@ -587,7 +587,7 @@ static int __init test_string_helpers_init(void)
 	for (i = 0; i < UNESCAPE_ALL_MASK + 1; i++)
 		test_string_unescape("unescape", i, false);
 	test_string_unescape("unescape inplace",
-			     get_random_int() % (UNESCAPE_ANY + 1), true);
+			     prandom_u32_max(UNESCAPE_ANY + 1), true);
 
 	/* Without dictionary */
 	for (i = 0; i < ESCAPE_ALL_MASK + 1; i++)
diff --git a/lib/test_hexdump.c b/lib/test_hexdump.c
index 5144899d3c6b..0927f44cd478 100644
--- a/lib/test_hexdump.c
+++ b/lib/test_hexdump.c
@@ -149,7 +149,7 @@ static void __init test_hexdump(size_t len, int rowsize, int groupsize,
 static void __init test_hexdump_set(int rowsize, bool ascii)
 {
 	size_t d = min_t(size_t, sizeof(data_b), rowsize);
-	size_t len = get_random_int() % d + 1;
+	size_t len = prandom_u32_max(d) + 1;
 
 	test_hexdump(len, rowsize, 4, ascii);
 	test_hexdump(len, rowsize, 2, ascii);
@@ -208,11 +208,11 @@ static void __init test_hexdump_overflow(size_t buflen, size_t len,
 static void __init test_hexdump_overflow_set(size_t buflen, bool ascii)
 {
 	unsigned int i = 0;
-	int rs = (get_random_int() % 2 + 1) * 16;
+	int rs = (prandom_u32_max(2) + 1) * 16;
 
 	do {
 		int gs = 1 << i;
-		size_t len = get_random_int() % rs + gs;
+		size_t len = prandom_u32_max(rs) + gs;
 
 		test_hexdump_overflow(buflen, rounddown(len, gs), rs, gs, ascii);
 	} while (i++ < 3);
@@ -223,11 +223,11 @@ static int __init test_hexdump_init(void)
 	unsigned int i;
 	int rowsize;
 
-	rowsize = (get_random_int() % 2 + 1) * 16;
+	rowsize = (prandom_u32_max(2) + 1) * 16;
 	for (i = 0; i < 16; i++)
 		test_hexdump_set(rowsize, false);
 
-	rowsize = (get_random_int() % 2 + 1) * 16;
+	rowsize = (prandom_u32_max(2) + 1) * 16;
 	for (i = 0; i < 16; i++)
 		test_hexdump_set(rowsize, true);
 
diff --git a/lib/test_list_sort.c b/lib/test_list_sort.c
index ade7a1ea0c8e..19ff229b9c3a 100644
--- a/lib/test_list_sort.c
+++ b/lib/test_list_sort.c
@@ -71,7 +71,7 @@ static void list_sort_test(struct kunit *test)
 		KUNIT_ASSERT_NOT_ERR_OR_NULL(test, el);
 
 		 /* force some equivalencies */
-		el->value = prandom_u32() % (TEST_LIST_LEN / 3);
+		el->value = prandom_u32_max(TEST_LIST_LEN / 3);
 		el->serial = i;
 		el->poison1 = TEST_POISON1;
 		el->poison2 = TEST_POISON2;