Skip to content

std.rand: Refactor Random interface #10045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lib/std/atomic/queue.zig
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,11 @@ test "std.atomic.Queue" {

fn startPuts(ctx: *Context) u8 {
var put_count: usize = puts_per_thread;
var r = std.rand.DefaultPrng.init(0xdeadbeef);
var prng = std.rand.DefaultPrng.init(0xdeadbeef);
const random = prng.random();
while (put_count != 0) : (put_count -= 1) {
std.time.sleep(1); // let the os scheduler be our fuzz
const x = @bitCast(i32, r.random.int(u32));
const x = @bitCast(i32, random.int(u32));
const node = ctx.allocator.create(Queue(i32).Node) catch unreachable;
node.* = .{
.prev = undefined,
Expand Down
5 changes: 3 additions & 2 deletions lib/std/atomic/stack.zig
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ test "std.atomic.stack" {

fn startPuts(ctx: *Context) u8 {
var put_count: usize = puts_per_thread;
var r = std.rand.DefaultPrng.init(0xdeadbeef);
var prng = std.rand.DefaultPrng.init(0xdeadbeef);
const random = prng.random();
while (put_count != 0) : (put_count -= 1) {
std.time.sleep(1); // let the os scheduler be our fuzz
const x = @bitCast(i32, r.random.int(u32));
const x = @bitCast(i32, random.int(u32));
const node = ctx.allocator.create(Stack(i32).Node) catch unreachable;
node.* = Stack(i32).Node{
.next = undefined,
Expand Down
21 changes: 11 additions & 10 deletions lib/std/crypto/benchmark.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const KiB = 1024;
const MiB = 1024 * KiB;

var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();

const Crypto = struct {
ty: type,
Expand All @@ -34,7 +35,7 @@ pub fn benchmarkHash(comptime Hash: anytype, comptime bytes: comptime_int) !u64
var h = Hash.init(.{});

var block: [Hash.digest_length]u8 = undefined;
prng.random.bytes(block[0..]);
random.bytes(block[0..]);

var offset: usize = 0;
var timer = try Timer.start();
Expand Down Expand Up @@ -66,11 +67,11 @@ const macs = [_]Crypto{

pub fn benchmarkMac(comptime Mac: anytype, comptime bytes: comptime_int) !u64 {
var in: [512 * KiB]u8 = undefined;
prng.random.bytes(in[0..]);
random.bytes(in[0..]);

const key_length = if (Mac.key_length == 0) 32 else Mac.key_length;
var key: [key_length]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);

var mac: [Mac.mac_length]u8 = undefined;
var offset: usize = 0;
Expand All @@ -94,10 +95,10 @@ pub fn benchmarkKeyExchange(comptime DhKeyExchange: anytype, comptime exchange_c
std.debug.assert(DhKeyExchange.shared_length >= DhKeyExchange.secret_length);

var secret: [DhKeyExchange.shared_length]u8 = undefined;
prng.random.bytes(secret[0..]);
random.bytes(secret[0..]);

var public: [DhKeyExchange.shared_length]u8 = undefined;
prng.random.bytes(public[0..]);
random.bytes(public[0..]);

var timer = try Timer.start();
const start = timer.lap();
Expand Down Expand Up @@ -211,15 +212,15 @@ const aeads = [_]Crypto{

pub fn benchmarkAead(comptime Aead: anytype, comptime bytes: comptime_int) !u64 {
var in: [512 * KiB]u8 = undefined;
prng.random.bytes(in[0..]);
random.bytes(in[0..]);

var tag: [Aead.tag_length]u8 = undefined;

var key: [Aead.key_length]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);

var nonce: [Aead.nonce_length]u8 = undefined;
prng.random.bytes(nonce[0..]);
random.bytes(nonce[0..]);

var offset: usize = 0;
var timer = try Timer.start();
Expand All @@ -244,7 +245,7 @@ const aes = [_]Crypto{

pub fn benchmarkAes(comptime Aes: anytype, comptime count: comptime_int) !u64 {
var key: [Aes.key_bits / 8]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);
const ctx = Aes.initEnc(key);

var in = [_]u8{0} ** 16;
Expand Down Expand Up @@ -273,7 +274,7 @@ const aes8 = [_]Crypto{

pub fn benchmarkAes8(comptime Aes: anytype, comptime count: comptime_int) !u64 {
var key: [Aes.key_bits / 8]u8 = undefined;
prng.random.bytes(key[0..]);
random.bytes(key[0..]);
const ctx = Aes.initEnc(key);

var in = [_]u8{0} ** (8 * 16);
Expand Down
7 changes: 5 additions & 2 deletions lib/std/crypto/tlcsprng.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ const os = std.os;

/// We use this as a layer of indirection because global const pointers cannot
/// point to thread-local variables.
pub var interface = std.rand.Random{ .fillFn = tlsCsprngFill };
pub const interface = std.rand.Random{
.ptr = undefined,
.fillFn = tlsCsprngFill,
};

const os_has_fork = switch (builtin.os.tag) {
.dragonfly,
Expand Down Expand Up @@ -55,7 +58,7 @@ var install_atfork_handler = std.once(struct {

threadlocal var wipe_mem: []align(mem.page_size) u8 = &[_]u8{};

fn tlsCsprngFill(_: *const std.rand.Random, buffer: []u8) void {
fn tlsCsprngFill(_: *c_void, buffer: []u8) void {
if (builtin.link_libc and @hasDecl(std.c, "arc4random_buf")) {
// arc4random is already a thread-local CSPRNG.
return std.c.arc4random_buf(buffer.ptr, buffer.len);
Expand Down
5 changes: 3 additions & 2 deletions lib/std/hash/benchmark.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const MiB = 1024 * KiB;
const GiB = 1024 * MiB;

var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();

const Hash = struct {
ty: type,
Expand Down Expand Up @@ -88,7 +89,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
};

var block: [block_size]u8 = undefined;
prng.random.bytes(block[0..]);
random.bytes(block[0..]);

var offset: usize = 0;
var timer = try Timer.start();
Expand All @@ -110,7 +111,7 @@ pub fn benchmarkHash(comptime H: anytype, bytes: usize) !Result {
pub fn benchmarkHashSmallKeys(comptime H: anytype, key_size: usize, bytes: usize) !Result {
const key_count = bytes / key_size;
var block: [block_size]u8 = undefined;
prng.random.bytes(block[0..]);
random.bytes(block[0..]);

var i: usize = 0;
var timer = try Timer.start();
Expand Down
12 changes: 7 additions & 5 deletions lib/std/hash_map.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1795,10 +1795,11 @@ test "std.hash_map put and remove loop in random order" {
while (i < size) : (i += 1) {
try keys.append(i);
}
var rng = std.rand.DefaultPrng.init(0);
var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();

while (i < iterations) : (i += 1) {
std.rand.Random.shuffle(&rng.random, u32, keys.items);
random.shuffle(u32, keys.items);

for (keys.items) |key| {
try map.put(key, key);
Expand Down Expand Up @@ -1826,14 +1827,15 @@ test "std.hash_map remove one million elements in random order" {
keys.append(i) catch unreachable;
}

var rng = std.rand.DefaultPrng.init(0);
std.rand.Random.shuffle(&rng.random, u32, keys.items);
var prng = std.rand.DefaultPrng.init(0);
const random = prng.random();
random.shuffle(u32, keys.items);

for (keys.items) |key| {
map.put(key, key) catch unreachable;
}

std.rand.Random.shuffle(&rng.random, u32, keys.items);
random.shuffle(u32, keys.items);
i = 0;
while (i < n) : (i += 1) {
const key = keys.items[i];
Expand Down
3 changes: 2 additions & 1 deletion lib/std/io/test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ test "write a file, read it, then delete it" {

var data: [1024]u8 = undefined;
var prng = DefaultPrng.init(1234);
prng.random.bytes(data[0..]);
const random = prng.random();
random.bytes(data[0..]);
const tmp_file_name = "temp_test_file.txt";
{
var file = try tmp.dir.createFile(tmp_file_name, .{});
Expand Down
3 changes: 2 additions & 1 deletion lib/std/math/big/rational.zig
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,10 @@ test "big.rational set/to Float round-trip" {
var a = try Rational.init(testing.allocator);
defer a.deinit();
var prng = std.rand.DefaultPrng.init(0x5EED);
const random = prng.random();
var i: usize = 0;
while (i < 512) : (i += 1) {
const r = prng.random.float(f64);
const r = random.float(f64);
try a.setFloat(f64, r);
try testing.expect((try a.toFloat(f64)) == r);
}
Expand Down
17 changes: 10 additions & 7 deletions lib/std/priority_dequeue.zig
Original file line number Diff line number Diff line change
Expand Up @@ -850,17 +850,18 @@ test "std.PriorityDequeue: shrinkAndFree" {

test "std.PriorityDequeue: fuzz testing min" {
var prng = std.rand.DefaultPrng.init(0x12345678);
const random = prng.random();

const test_case_count = 100;
const queue_size = 1_000;

var i: usize = 0;
while (i < test_case_count) : (i += 1) {
try fuzzTestMin(&prng.random, queue_size);
try fuzzTestMin(random, queue_size);
}
}

fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void {
fn fuzzTestMin(rng: std.rand.Random, comptime queue_size: usize) !void {
const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size);

Expand All @@ -878,17 +879,18 @@ fn fuzzTestMin(rng: *std.rand.Random, comptime queue_size: usize) !void {

test "std.PriorityDequeue: fuzz testing max" {
var prng = std.rand.DefaultPrng.init(0x87654321);
const random = prng.random();

const test_case_count = 100;
const queue_size = 1_000;

var i: usize = 0;
while (i < test_case_count) : (i += 1) {
try fuzzTestMax(&prng.random, queue_size);
try fuzzTestMax(random, queue_size);
}
}

fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void {
fn fuzzTestMax(rng: std.rand.Random, queue_size: usize) !void {
const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size);

Expand All @@ -906,17 +908,18 @@ fn fuzzTestMax(rng: *std.rand.Random, queue_size: usize) !void {

test "std.PriorityDequeue: fuzz testing min and max" {
var prng = std.rand.DefaultPrng.init(0x87654321);
const random = prng.random();

const test_case_count = 100;
const queue_size = 1_000;

var i: usize = 0;
while (i < test_case_count) : (i += 1) {
try fuzzTestMinMax(&prng.random, queue_size);
try fuzzTestMinMax(random, queue_size);
}
}

fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void {
fn fuzzTestMinMax(rng: std.rand.Random, queue_size: usize) !void {
const allocator = testing.allocator;
const items = try generateRandomSlice(allocator, rng, queue_size);

Expand All @@ -943,7 +946,7 @@ fn fuzzTestMinMax(rng: *std.rand.Random, queue_size: usize) !void {
}
}

fn generateRandomSlice(allocator: *std.mem.Allocator, rng: *std.rand.Random, size: usize) ![]u32 {
fn generateRandomSlice(allocator: *std.mem.Allocator, rng: std.rand.Random, size: usize) ![]u32 {
var array = std.ArrayList(u32).init(allocator);
try array.ensureTotalCapacity(size);

Expand Down
Loading