Skip to content

Commit 3ae6766

Browse files
committed
Windows: don't use CommandLineToArgv to get os.Args
CommandLineToArgV is hosted in shell32.dll, which is very expensive to load. Most go programs never need to load shell32.dll for any other reason than to call this function, so implement the same algorithm directly. This reduces simple process startup time from 22ms to 16ms on my machine.
1 parent e6ec820 commit 3ae6766

File tree

2 files changed

+136
-10
lines changed

2 files changed

+136
-10
lines changed

src/os/exec_windows.go

+81-10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sync/atomic"
1111
"syscall"
1212
"time"
13+
"unicode/utf16"
1314
"unsafe"
1415
)
1516

@@ -94,17 +95,87 @@ func findProcess(pid int) (p *Process, err error) {
9495
return newProcess(pid, uintptr(h)), nil
9596
}
9697

97-
func init() {
98-
var argc int32
99-
cmd := syscall.GetCommandLine()
100-
argv, e := syscall.CommandLineToArgv(cmd, &argc)
101-
if e != nil {
102-
return
98+
func isCmdSpace(c byte) bool {
99+
return c == ' ' || c == '\t'
100+
}
101+
102+
// CommandLineToArgv splits a command line into individual argument strings, following the
103+
// Windows conventions documented at https://msdn.microsoft.com/en-us/library/17w5ykft.aspx.
104+
func CommandLineToArgv(cmd string) []string {
105+
var argv []string
106+
i := 0
107+
for {
108+
for i < len(cmd) && isCmdSpace(cmd[i]) {
109+
i++
110+
}
111+
if i == len(cmd) {
112+
break
113+
}
114+
115+
var arg []byte
116+
inquote := false
117+
for {
118+
nslash := 0
119+
for i < len(cmd) && cmd[i] == '\\' {
120+
i++
121+
nslash++
122+
}
123+
124+
if i < len(cmd) && cmd[i] == '"' {
125+
for s := 0; s < nslash/2; s++ {
126+
arg = append(arg, '\\')
127+
}
128+
129+
if nslash%2 == 0 {
130+
inquote = !inquote
131+
// Special case: if the next character is also a quote,
132+
// then this quote gets included.
133+
if !inquote && i+1 < len(cmd) && cmd[i+1] == '"' {
134+
arg = append(arg, '"')
135+
i++
136+
}
137+
} else {
138+
arg = append(arg, '"')
139+
}
140+
} else {
141+
for nslash > 0 {
142+
arg = append(arg, '\\')
143+
nslash--
144+
}
145+
146+
if i == len(cmd) || (!inquote && isCmdSpace(cmd[i])) {
147+
break
148+
}
149+
150+
arg = append(arg, cmd[i])
151+
}
152+
153+
i++
154+
}
155+
156+
argv = append(argv, string(arg))
103157
}
104-
defer syscall.LocalFree(syscall.Handle(uintptr(unsafe.Pointer(argv))))
105-
Args = make([]string, argc)
106-
for i, v := range (*argv)[:argc] {
107-
Args[i] = syscall.UTF16ToString((*v)[:])
158+
return argv
159+
}
160+
161+
func init() {
162+
cmd := syscall.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(syscall.GetCommandLine()))[:])
163+
164+
if len(cmd) != 0 {
165+
Args = CommandLineToArgv(cmd)
166+
} else {
167+
// No command line was provided, so get argv[0] from the module name.
168+
dll := syscall.MustLoadDLL("kernel32.dll")
169+
defer dll.Release()
170+
fn := dll.MustFindProc("GetModuleFileNameW")
171+
172+
p := make([]uint16, syscall.MAX_PATH)
173+
r, _, err := fn.Call(0, uintptr(unsafe.Pointer(&p[0])), uintptr(len(p)))
174+
n := uint32(r)
175+
if n == 0 || n >= uint32(len(p)) {
176+
panic(err)
177+
}
178+
Args = []string{string(utf16.Decode(p[:n]))}
108179
}
109180
}
110181

src/os/exec_windows_test.go

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package os_test
2+
3+
import (
4+
. "os"
5+
"syscall"
6+
"testing"
7+
"unsafe"
8+
)
9+
10+
func TestCommandLineToArgv(t *testing.T) {
11+
cmds := []string{
12+
`test`,
13+
`test a b c`,
14+
`test "`,
15+
`test ""`,
16+
`test """`,
17+
`test "" a`,
18+
`test "123"`,
19+
`test \"123\"`,
20+
`test \"123 456\"`,
21+
`test \\"`,
22+
`test \\\"`,
23+
`test \\\\\"`,
24+
`test \\\"x`,
25+
`test """"\""\\\"`,
26+
`"cmd line" abc`,
27+
`test \\\\\""x"""y z`,
28+
"test\tb\t\"x\ty\"",
29+
}
30+
31+
for _, cmd := range cmds {
32+
var argc int32
33+
c, err := syscall.CommandLineToArgv(&syscall.StringToUTF16(cmd)[0], &argc)
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
38+
out := CommandLineToArgv(cmd)
39+
outwin := make([]string, len(out))
40+
41+
valid := len(outwin) == len(out)
42+
for i := range outwin {
43+
outwin[i] = syscall.UTF16ToString(c[i][:])
44+
if i < len(out) && out[i] != outwin[i] {
45+
valid = false
46+
}
47+
}
48+
49+
if !valid {
50+
t.Errorf("%#v: %#v vs %#v", cmd, out, outwin)
51+
}
52+
53+
syscall.LocalFree(syscall.Handle(unsafe.Pointer(c)))
54+
}
55+
}

0 commit comments

Comments
 (0)