Skip to content

Commit 1a54a15

Browse files
committed
std.rand: Refactor Random interface
These changes have been made to resolve issue ziglang#10037. The `Random` interface was implemented in such a way that causes significant slowdown when calling the `fill` function of the rng used. The `Random` interface is no longer stored in a field of the rng, and is instead returned by the child function `random()` of the rng. This avoids the performance issues caused by the interface.
1 parent 3af9731 commit 1a54a15

18 files changed

+291
-244
lines changed

lib/std/atomic/queue.zig

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,11 @@ test "std.atomic.Queue" {
242242

243243
fn startPuts(ctx: *Context) u8 {
244244
var put_count: usize = puts_per_thread;
245-
var r = std.rand.DefaultPrng.init(0xdeadbeef);
245+
var prng = std.rand.DefaultPrng.init(0xdeadbeef);
246+
const random = prng.random();
246247
while (put_count != 0) : (put_count -= 1) {
247248
std.time.sleep(1); // let the os scheduler be our fuzz
248-
const x = @bitCast(i32, r.random.int(u32));
249+
const x = @bitCast(i32, random.int(u32));
249250
const node = ctx.allocator.create(Queue(i32).Node) catch unreachable;
250251
node.* = .{
251252
.prev = undefined,

lib/std/atomic/stack.zig

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,11 @@ test "std.atomic.stack" {
147147

148148
fn startPuts(ctx: *Context) u8 {
149149
var put_count: usize = puts_per_thread;
150-
var r = std.rand.DefaultPrng.init(0xdeadbeef);
150+
var prng = std.rand.DefaultPrng.init(0xdeadbeef);
151+
const random = prng.random();
151152
while (put_count != 0) : (put_count -= 1) {
152153
std.time.sleep(1); // let the os scheduler be our fuzz
153-
const x = @bitCast(i32, r.random.int(u32));
154+
const x = @bitCast(i32, random.int(u32));
154155
const node = ctx.allocator.create(Stack(i32).Node) catch unreachable;
155156
node.* = Stack(i32).Node{
156157
.next = undefined,

lib/std/crypto/benchmark.zig

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const KiB = 1024;
1111
const MiB = 1024 * KiB;
1212

1313
var prng = std.rand.DefaultPrng.init(0);
14+
const random = prng.random();
1415

1516
const Crypto = struct {
1617
ty: type,
@@ -34,7 +35,7 @@ pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64
3435
var h = Hash.init(.{});
3536

3637
var block: [Hash.digest_length]u8 = undefined;
37-
prng.random.bytes(block[0..]);
38+
random.bytes(block[0..]);
3839

3940
var offset: usize = 0;
4041
var timer = try Timer.start();
@@ -66,11 +67,11 @@ const macs = [_]Crypto{
6667

6768
pub fn benchmarkMac(comptime Mac: anytype, comptime bytes: comptime_int) !u64 {
6869
var in: [512 * KiB]u8 = undefined;
69-
prng.random.bytes(in[0..]);
70+
random.bytes(in[0..]);
7071

7172
const key_length = if (Mac.key_length == 0) 32 else Mac.key_length;
7273
var key: [key_length]u8 = undefined;
73-
prng.random.bytes(key[0..]);
74+
random.bytes(key[0..]);
7475

7576
var mac: [Mac.mac_length]u8 = undefined;
7677
var offset: usize = 0;
@@ -94,10 +95,10 @@ pub fn benchmarkKeyExchange(comptime DhKeyExchange: anytype, comptime exchange_c
9495
std.debug.assert(DhKeyExchange.shared_length >= DhKeyExchange.secret_length);
9596

9697
var secret: [DhKeyExchange.shared_length]u8 = undefined;
97-
prng.random.bytes(secret[0..]);
98+
random.bytes(secret[0..]);
9899

99100
var public: [DhKeyExchange.shared_length]u8 = undefined;
100-
prng.random.bytes(public[0..]);
101+
random.bytes(public[0..]);
101102

102103
var timer = try Timer.start();
103104
const start = timer.lap();
@@ -211,15 +212,15 @@ const aeads = [_]Crypto{
211212

212213
pub fn benchmarkAead(comptime Aead: anytype, comptime bytes: comptime_int) !u64 {
213214
var in: [512 * KiB]u8 = undefined;
214-
prng.random.bytes(in[0..]);
215+
random.bytes(in[0..]);
215216

216217
var tag: [Aead.tag_length]u8 = undefined;
217218

218219
var key: [Aead.key_length]u8 = undefined;
219-
prng.random.bytes(key[0..]);
220+
random.bytes(key[0..]);
220221

221222
var nonce: [Aead.nonce_length]u8 = undefined;
222-
prng.random.bytes(nonce[0..]);
223+
random.bytes(nonce[0..]);
223224

224225
var offset: usize = 0;
225226
var timer = try Timer.start();
@@ -244,7 +245,7 @@ const aes = [_]Crypto{
244245

245246
pub fn benchmarkAes(comptime Aes: anytype, comptime count: comptime_int) !u64 {
246247
var key: [Aes.key_bits / 8]u8 = undefined;
247-
prng.random.bytes(key[0..]);
248+
random.bytes(key[0..]);
248249
const ctx = Aes.initEnc(key);
249250

250251
var in = [_]u8{0} ** 16;
@@ -273,7 +274,7 @@ const aes8 = [_]Crypto{
273274

274275
pub fn benchmarkAes8(comptime Aes: anytype, comptime count: comptime_int) !u64 {
275276
var key: [Aes.key_bits / 8]u8 = undefined;
276-
prng.random.bytes(key[0..]);
277+
random.bytes(key[0..]);
277278
const ctx = Aes.initEnc(key);
278279

279280
var in = [_]u8{0} ** (8 * 16);

lib/std/crypto/tlcsprng.zig

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ const os = std.os;
1111

1212
/// We use this as a layer of indirection because global const pointers cannot
1313
/// point to thread-local variables.
14-
pub var interface = std.rand.Random{ .fillFn = tlsCsprngFill };
14+
pub const interface = std.rand.Random{
15+
.ptr = undefined,
16+
.fillFn = tlsCsprngFill,
17+
};
1518

1619
const os_has_fork = switch (builtin.os.tag) {
1720
.dragonfly,
@@ -55,7 +58,7 @@ var install_atfork_handler = std.once(struct {
5558

5659
threadlocal var wipe_mem: []align(mem.page_size) u8 = &[_]u8{};
5760

58-
fn tlsCsprngFill(_: *const std.rand.Random, buffer: []u8) void {
61+
fn tlsCsprngFill(_: *c_void, buffer: []u8) void {
5962
if (builtin.link_libc and @hasDecl(std.c, "arc4random_buf")) {
6063
// arc4random is already a thread-local CSPRNG.
6164
return std.c.arc4random_buf(buffer.ptr, buffer.len);

lib/std/hash/benchmark.zig

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const MiB = 1024 * KiB;
1111
const GiB = 1024 * MiB;
1212

1313
var prng = std.rand.DefaultPrng.init(0);
14+
const random = prng.random();
1415

1516
const Hash = struct {
1617
ty: type,
@@ -88,7 +89,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
8889
};
8990

9091
var block: [block_size]u8 = undefined;
91-
prng.random.bytes(block[0..]);
92+
random.bytes(block[0..]);
9293

9394
var offset: usize = 0;
9495
var timer = try Timer.start();
@@ -110,7 +111,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
110111
pub fn benchmarkHashSmallKeys(comptime H: anytype, key_size: usize, bytes: usize) !Result {
111112
const key_count = bytes / key_size;
112113
var block: [block_size]u8 = undefined;
113-
prng.random.bytes(block[0..]);
114+
random.bytes(block[0..]);
114115

115116
var i: usize = 0;
116117
var timer = try Timer.start();

lib/std/hash_map.zig

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,10 +1795,11 @@ test "std.hash_map put and remove loop in random order" {
17951795
while (i < size) : (i += 1) {
17961796
try keys.append(i);
17971797
}
1798-
var rng = std.rand.DefaultPrng.init(0);
1798+
var prng = std.rand.DefaultPrng.init(0);
1799+
const random = prng.random();
17991800

18001801
while (i < iterations) : (i += 1) {
1801-
std.rand.Random.shuffle(&rng.random, u32, keys.items);
1802+
random.shuffle(u32, keys.items);
18021803

18031804
for (keys.items) |key| {
18041805
try map.put(key, key);
@@ -1826,14 +1827,15 @@ test "std.hash_map remove one million elements in random order" {
18261827
keys.append(i) catch unreachable;
18271828
}
18281829

1829-
var rng = std.rand.DefaultPrng.init(0);
1830-
std.rand.Random.shuffle(&rng.random, u32, keys.items);
1830+
var prng = std.rand.DefaultPrng.init(0);
1831+
const random = prng.random();
1832+
random.shuffle(u32, keys.items);
18311833

18321834
for (keys.items) |key| {
18331835
map.put(key, key) catch unreachable;
18341836
}
18351837

1836-
std.rand.Random.shuffle(&rng.random, u32, keys.items);
1838+
random.shuffle(u32, keys.items);
18371839
i = 0;
18381840
while (i < n) : (i += 1) {
18391841
const key = keys.items[i];

lib/std/io/test.zig

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ test "write a file, read it, then delete it" {
2020

2121
var data: [1024]u8 = undefined;
2222
var prng = DefaultPrng.init(1234);
23-
prng.random.bytes(data[0..]);
23+
const random = prng.random();
24+
random.bytes(data[0..]);
2425
const tmp_file_name = "temp_test_file.txt";
2526
{
2627
var file = try tmp.dir.createFile(tmp_file_name, .{});

lib/std/math/big/rational.zig

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,9 +589,10 @@ test "big.rational set/to Float round-trip" {
589589
var a = try Rational.init(testing.allocator);
590590
defer a.deinit();
591591
var prng = std.rand.DefaultPrng.init(0x5EED);
592+
const random = prng.random();
592593
var i: usize = 0;
593594
while (i < 512) : (i += 1) {
594-
const r = prng.random.float(f64);
595+
const r = random.float(f64);
595596
try a.setFloat(f64, r);
596597
try testing.expect((try a.toFloat(f64)) == r);
597598
}

lib/std/priority_dequeue.zig

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -850,17 +850,18 @@ test "std.PriorityDequeue: shrinkAndFree" {
850850

851851
test "std.PriorityDequeue: fuzz testing min" {
852852
var prng = std.rand.DefaultPrng.init(0x12345678);
853+
const random = prng.random();
853854

854855
const test_case_count = 100;
855856
const queue_size = 1_000;
856857

857858
var i: usize = 0;
858859
while (i < test_case_count) : (i += 1) {
859-
try fuzzTestMin(&prng.random, queue_size);
860+
try fuzzTestMin(random, queue_size);
860861
}
861862
}
862863

863-
fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void {
864+
fn fuzzTestMin(rng: std.rand.Random, comptime queue_size: usize) !void {
864865
const allocator = testing.allocator;
865866
const items = try generateRandomSlice(allocator, rng, queue_size);
866867

@@ -878,17 +879,18 @@ fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void {
878879

879880
test "std.PriorityDequeue: fuzz testing max" {
880881
var prng = std.rand.DefaultPrng.init(0x87654321);
882+
const random = prng.random();
881883

882884
const test_case_count = 100;
883885
const queue_size = 1_000;
884886

885887
var i: usize = 0;
886888
while (i < test_case_count) : (i += 1) {
887-
try fuzzTestMax(&prng.random, queue_size);
889+
try fuzzTestMax(random, queue_size);
888890
}
889891
}
890892

891-
fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void {
893+
fn fuzzTestMax(rng: std.rand.Random, queue_size: usize) !void {
892894
const allocator = testing.allocator;
893895
const items = try generateRandomSlice(allocator, rng, queue_size);
894896

@@ -906,17 +908,18 @@ fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void {
906908

907909
test "std.PriorityDequeue: fuzz testing min and max" {
908910
var prng = std.rand.DefaultPrng.init(0x87654321);
911+
const random = prng.random();
909912

910913
const test_case_count = 100;
911914
const queue_size = 1_000;
912915

913916
var i: usize = 0;
914917
while (i < test_case_count) : (i += 1) {
915-
try fuzzTestMinMax(&prng.random, queue_size);
918+
try fuzzTestMinMax(random, queue_size);
916919
}
917920
}
918921

919-
fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void {
922+
fn fuzzTestMinMax(rng: std.rand.Random, queue_size: usize) !void {
920923
const allocator = testing.allocator;
921924
const items = try generateRandomSlice(allocator, rng, queue_size);
922925

@@ -943,7 +946,7 @@ fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void {
943946
}
944947
}
945948

946-
fn generateRandomSlice(allocator: *std.mem.Allocator, rng: *std.rand.Random, size: usize) ![]u32 {
949+
fn generateRandomSlice(allocator: *std.mem.Allocator, rng: std.rand.Random, size: usize) ![]u32 {
947950
var array = std.ArrayList(u32).init(allocator);
948951
try array.ensureTotalCapacity(size);
949952

0 commit comments

Comments
 (0)