Skip to content

Commit fbdaab9

Browse files
committed
sync: add implementation from upstream Go for OnceFunc, OnceValue, and OnceValues
Signed-off-by: deadprogram <[email protected]>
1 parent f5f7a78 commit fbdaab9

File tree

2 files changed

+256
-0
lines changed

2 files changed

+256
-0
lines changed

src/sync/oncefunc.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright 2022 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package sync
6+
7+
// OnceFunc returns a function that invokes f only once. The returned function
8+
// may be called concurrently.
9+
//
10+
// If f panics, the returned function will panic with the same value on every call.
11+
func OnceFunc(f func()) func() {
12+
var (
13+
once Once
14+
valid bool
15+
p any
16+
)
17+
// Construct the inner closure just once to reduce costs on the fast path.
18+
g := func() {
19+
defer func() {
20+
p = recover()
21+
if !valid {
22+
// Re-panic immediately so on the first call the user gets a
23+
// complete stack trace into f.
24+
panic(p)
25+
}
26+
}()
27+
f()
28+
valid = true // Set only if f does not panic
29+
}
30+
return func() {
31+
once.Do(g)
32+
if !valid {
33+
panic(p)
34+
}
35+
}
36+
}
37+
38+
// OnceValue returns a function that invokes f only once and returns the value
39+
// returned by f. The returned function may be called concurrently.
40+
//
41+
// If f panics, the returned function will panic with the same value on every call.
42+
func OnceValue[T any](f func() T) func() T {
43+
var (
44+
once Once
45+
valid bool
46+
p any
47+
result T
48+
)
49+
g := func() {
50+
defer func() {
51+
p = recover()
52+
if !valid {
53+
panic(p)
54+
}
55+
}()
56+
result = f()
57+
valid = true
58+
}
59+
return func() T {
60+
once.Do(g)
61+
if !valid {
62+
panic(p)
63+
}
64+
return result
65+
}
66+
}
67+
68+
// OnceValues returns a function that invokes f only once and returns the values
69+
// returned by f. The returned function may be called concurrently.
70+
//
71+
// If f panics, the returned function will panic with the same value on every call.
72+
func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) {
73+
var (
74+
once Once
75+
valid bool
76+
p any
77+
r1 T1
78+
r2 T2
79+
)
80+
g := func() {
81+
defer func() {
82+
p = recover()
83+
if !valid {
84+
panic(p)
85+
}
86+
}()
87+
r1, r2 = f()
88+
valid = true
89+
}
90+
return func() (T1, T2) {
91+
once.Do(g)
92+
if !valid {
93+
panic(p)
94+
}
95+
return r1, r2
96+
}
97+
}

src/sync/oncefunc_test.go

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// Copyright 2022 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package sync_test
6+
7+
import (
8+
"runtime"
9+
"sync"
10+
"testing"
11+
)
12+
13+
// We assume that the Once.Do tests have already covered parallelism.
14+
15+
func TestOnceFunc(t *testing.T) {
16+
calls := 0
17+
f := sync.OnceFunc(func() { calls++ })
18+
allocs := testing.AllocsPerRun(10, f)
19+
if calls != 1 {
20+
t.Errorf("want calls==1, got %d", calls)
21+
}
22+
if allocs != 0 {
23+
t.Errorf("want 0 allocations per call, got %v", allocs)
24+
}
25+
}
26+
27+
func TestOnceValue(t *testing.T) {
28+
calls := 0
29+
f := sync.OnceValue(func() int {
30+
calls++
31+
return calls
32+
})
33+
allocs := testing.AllocsPerRun(10, func() { f() })
34+
value := f()
35+
if calls != 1 {
36+
t.Errorf("want calls==1, got %d", calls)
37+
}
38+
if value != 1 {
39+
t.Errorf("want value==1, got %d", value)
40+
}
41+
if allocs != 0 {
42+
t.Errorf("want 0 allocations per call, got %v", allocs)
43+
}
44+
}
45+
46+
func TestOnceValues(t *testing.T) {
47+
calls := 0
48+
f := sync.OnceValues(func() (int, int) {
49+
calls++
50+
return calls, calls + 1
51+
})
52+
allocs := testing.AllocsPerRun(10, func() { f() })
53+
v1, v2 := f()
54+
if calls != 1 {
55+
t.Errorf("want calls==1, got %d", calls)
56+
}
57+
if v1 != 1 || v2 != 2 {
58+
t.Errorf("want v1==1 and v2==2, got %d and %d", v1, v2)
59+
}
60+
if allocs != 0 {
61+
t.Errorf("want 0 allocations per call, got %v", allocs)
62+
}
63+
}
64+
65+
func testOncePanicX(t *testing.T, calls *int, f func()) {
66+
testOncePanicWith(t, calls, f, func(label string, p any) {
67+
if p != "x" {
68+
t.Fatalf("%s: want panic %v, got %v", label, "x", p)
69+
}
70+
})
71+
}
72+
73+
func testOncePanicWith(t *testing.T, calls *int, f func(), check func(label string, p any)) {
74+
// Check that the each call to f panics with the same value, but the
75+
// underlying function is only called once.
76+
for _, label := range []string{"first time", "second time"} {
77+
var p any
78+
panicked := true
79+
func() {
80+
defer func() {
81+
p = recover()
82+
}()
83+
f()
84+
panicked = false
85+
}()
86+
if !panicked {
87+
t.Fatalf("%s: f did not panic", label)
88+
}
89+
check(label, p)
90+
}
91+
if *calls != 1 {
92+
t.Errorf("want calls==1, got %d", *calls)
93+
}
94+
}
95+
96+
func TestOnceFuncPanic(t *testing.T) {
97+
calls := 0
98+
f := sync.OnceFunc(func() {
99+
calls++
100+
panic("x")
101+
})
102+
testOncePanicX(t, &calls, f)
103+
}
104+
105+
func TestOnceValuePanic(t *testing.T) {
106+
calls := 0
107+
f := sync.OnceValue(func() int {
108+
calls++
109+
panic("x")
110+
})
111+
testOncePanicX(t, &calls, func() { f() })
112+
}
113+
114+
func TestOnceValuesPanic(t *testing.T) {
115+
calls := 0
116+
f := sync.OnceValues(func() (int, int) {
117+
calls++
118+
panic("x")
119+
})
120+
testOncePanicX(t, &calls, func() { f() })
121+
}
122+
123+
func TestOnceFuncPanicNil(t *testing.T) {
124+
calls := 0
125+
f := sync.OnceFunc(func() {
126+
calls++
127+
panic(nil)
128+
})
129+
testOncePanicWith(t, &calls, f, func(label string, p any) {
130+
switch p.(type) {
131+
case nil, *runtime.PanicNilError:
132+
return
133+
}
134+
t.Fatalf("%s: want nil panic, got %v", label, p)
135+
})
136+
}
137+
138+
func TestOnceFuncGoexit(t *testing.T) {
139+
// If f calls Goexit, the results are unspecified. But check that f doesn't
140+
// get called twice.
141+
calls := 0
142+
f := sync.OnceFunc(func() {
143+
calls++
144+
runtime.Goexit()
145+
})
146+
var wg sync.WaitGroup
147+
for i := 0; i < 2; i++ {
148+
wg.Add(1)
149+
go func() {
150+
defer wg.Done()
151+
defer func() { recover() }()
152+
f()
153+
}()
154+
wg.Wait()
155+
}
156+
if calls != 1 {
157+
t.Errorf("want calls==1, got %d", calls)
158+
}
159+
}

0 commit comments

Comments
 (0)