diff --git a/conn.go b/conn.go index 1d9e4a22..707c3ae8 100644 --- a/conn.go +++ b/conn.go @@ -22,6 +22,7 @@ import ( "time" "unicode" + "github.com/lib/pq/internal" "github.com/lib/pq/oid" "github.com/lib/pq/scram" ) @@ -857,7 +858,7 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { } defer cn.errRecover(&err) - if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { + if len(q) >= 4 && internal.StartsWithCOPY(q[:4]) { s, err := cn.prepareCopyIn(q) if err == nil { cn.inCopy = true diff --git a/internal/copy.go b/internal/copy.go new file mode 100644 index 00000000..b9337212 --- /dev/null +++ b/internal/copy.go @@ -0,0 +1,102 @@ +package internal + +func consumeWhitespace(input string, i int, n int) int { + for i < n && (input[i] == ' ' || input[i] == '\t' || input[i] == '\n' || input[i] == '\r') { + i++ + } + return i +} + +func consumeLineComment(input string, i int, n int) int { + i += 2 // skip '--' + for i < n && input[i] != '\n' { + i++ + } + return i +} + +func consumeBlockComment(input string, i int, n int) int { + i += 2 // skip '/*' + for i < n-1 { + if input[i] == '*' && input[i+1] == '/' { + i += 2 + return i + } + i++ + } + // Unterminated comment? Consider as done consuming. + return i +} + +func StartsWithCOPY(input string) bool { + const ( + Start = iota + WhitespaceOrComment + C + O + P + Y + Done + Fail + ) + + state := Start + i := 0 + n := len(input) + + for state != Done && state != Fail { + if i >= n { + if state == Y { + state = Done + } else { + state = Fail + } + break + } + + switch state { + case Start, WhitespaceOrComment: + i = consumeWhitespace(input, i, n) + if i+1 < n && input[i] == '-' && input[i+1] == '-' { + i = consumeLineComment(input, i, n) + } else if i+1 < n && input[i] == '/' && input[i+1] == '*' { + i = consumeBlockComment(input, i, n) + } else if i < n { + switch input[i] { + case 'C', 'c': + state = C + i++ + case ' ', '\t', '\n', '\r': + // handled in consumeWhitespace + default: + state = Fail + } + } + case C: + if i < n && (input[i] == 'O' || input[i] == 'o') { + state = O + i++ + } else { + state = Fail + } + case O: + if i < n && (input[i] == 'P' || input[i] == 'p') { + state = P + i++ + } else { + state = Fail + } + case P: + if i < n && (input[i] == 'Y' || input[i] == 'y') { + state = Y + i++ + } else { + state = Fail + } + case Y: + state = Done + } + } + + return state == Done +} diff --git a/internal/copy_test.go b/internal/copy_test.go new file mode 100644 index 00000000..4e22c3a9 --- /dev/null +++ b/internal/copy_test.go @@ -0,0 +1,69 @@ +package internal + +import ( + "strings" + "testing" +) + +func TestStartsWithCOPY(t *testing.T) { + tests := []struct { + input string + valid bool + }{ + { + input: "COPY data;", + valid: true, + }, + { + input: " COPY", + valid: true, + }, + { + input: "SELECT * FROM users;", + valid: false, + }, + { + input: "-- comment only\n/* and another */COPY table", + valid: true, + }, + { + input: "\n\n/* header */ COPY my_table FROM stdin;", + valid: true, + }, + { + input: " -- some comment\n /* block */ COPY table FROM stdin;", + valid: true, + }, + { + input: "-- some comment not terminated on purpose (or not) COPY table FROM stdin;", + valid: false, + }, + { + input: "-- COPY table FROM stdin;\nSELECT * FROM users;", + valid: false, + }, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + valid := StartsWithCOPY(test.input) + if valid != test.valid { + t.Errorf("Expected %q to be %v, got %v", test.input, test.valid, valid) + } + }) + } +} + +func BenchmarkStartsWithCOPY(b *testing.B) { + sql := " -- comment\n /* block */ COPY table FROM stdin;" + for i := 0; i < b.N; i++ { + _ = StartsWithCOPY(sql) + } +} + +func BenchmarkEqualFold(b *testing.B) { + sql := "COPY table FROM stdin;" + for i := 0; i < b.N; i++ { + _ = strings.EqualFold(sql, "COPY") + } +}