From 049563c141703d12ca6f5f1679d492ad31bbb8b2 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 8 Aug 2023 16:27:26 -0500 Subject: [PATCH 1/3] Feat: Add -k and finish -r,-X,-L --- README.md | 2 - cmd/sqlcmd/sqlcmd.go | 160 +++++++++++++++++++++++++++--------- cmd/sqlcmd/sqlcmd_test.go | 44 ++++++---- internal/sql/mssql.go | 2 +- pkg/sqlcmd/commands_test.go | 2 +- pkg/sqlcmd/format.go | 3 +- pkg/sqlcmd/sqlcmd_test.go | 6 +- 7 files changed, 157 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index d78683e3..97a097d8 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,6 @@ The `sqlcmd` project aims to be a complete port of the original ODBC sqlcmd to t - There are new posix-style versions of each flag, such as `--input-file` for `-i`. `sqlcmd -?` will print those parameter names. Those new names do not preserve backward compatibility with ODBC `sqlcmd`. For example, to specify multiple input file names using `--input-file`, the file names must be comma-delimited, not space-delimited. The following switches have different behavior in this version of `sqlcmd` compared to the original ODBC based `sqlcmd`. -- `-r` requires a 0 or 1 argument - `-R` switch is ignored. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms. - `-I` switch is ignored; quoted identifiers are always set on. To disable quoted identifier behavior, add `SET QUOTED IDENTIFIER OFF` in your scripts. - `-N` now takes a string value that can be one of `true`, `false`, or `disable` to specify the encryption choice. @@ -141,7 +140,6 @@ The following switches have different behavior in this version of `sqlcmd` compa - If using a single `-i` flag to pass multiple file names, there must be a space after the `-i`. Example: `-i file1.sql file2.sql` - `-M` switch is ignored. Sqlcmd always enables multi-subnet failover. - ### Switches not available in the new sqlcmd (go-sqlcmd) yet There are a few switches yet to be implemented in the new `sqlcmd` (go-sqlcmd) compared diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index e41abf79..7c8958c5 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -38,8 +38,8 @@ type SQLCmdArguments struct { // Query to run then exit Query string Server string - // Disable syscommands with a warning - DisableCmdAndWarn bool + // Disable syscommands with a warning or error + DisableCmd *int // AuthenticationMethod is new for go-sqlcmd AuthenticationMethod string UseAad bool @@ -55,7 +55,7 @@ type SQLCmdArguments struct { ErrorSeverityLevel uint8 ErrorLevel int Format string - ErrorsToStderr int + ErrorsToStderr *int Headers int UnicodeOutputFile bool Version bool @@ -66,25 +66,58 @@ type SQLCmdArguments struct { TrimSpaces bool Password string DedicatedAdminConnection bool - ListServers bool + ListServers string + RemoveControlCharacters *int // Keep Help at the end of the list Help bool } +func (args *SQLCmdArguments) useEnvVars() bool { + return args.DisableCmd == nil +} + +func (args *SQLCmdArguments) errorOnBlockedCmd() bool { + return args.DisableCmd != nil && *args.DisableCmd > 0 +} + +func (args *SQLCmdArguments) warnOnBlockedCmd() bool { + return args.DisableCmd != nil && *args.DisableCmd <= 0 +} + +func (args *SQLCmdArguments) runStartupScript() bool { + return args.DisableCmd == nil +} + +func (args *SQLCmdArguments) getControlCharacterBehavior() sqlcmd.ControlCharacterBehavior { + if args.RemoveControlCharacters == nil { + return sqlcmd.ControlIgnore + } + switch *args.RemoveControlCharacters { + case 1: + return sqlcmd.ControlReplace + case 2: + return sqlcmd.ControlReplaceConsecutive + } + return sqlcmd.ControlRemove +} + const ( - sqlcmdErrorPrefix = "Sqlcmd: " - applicationIntent = "application-intent" - errorsToStderr = "errors-to-stderr" - format = "format" - encryptConnection = "encrypt-connection" - screenWidth = "screen-width" - fixedTypeWidth = "fixed-type-width" - variableTypeWidth = "variable-type-width" + sqlcmdErrorPrefix = "Sqlcmd: " + applicationIntent = "application-intent" + errorsToStderr = "errors-to-stderr" + format = "format" + encryptConnection = "encrypt-connection" + screenWidth = "screen-width" + fixedTypeWidth = "fixed-type-width" + variableTypeWidth = "variable-type-width" + disableCmdAndWarn = "disable-cmd-and-warn" + listServers = "list-servers" + removeControlCharacters = "remove-control-characters" ) // Validate arguments for settings not describe func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { - if a.ListServers { + if a.ListServers != "" { c.Flags().Visit(func(f *pflag.Flag) { if f.Shorthand != "L" { err = localizer.Errorf("The -L parameter can not be used in combination with other parameters.") @@ -170,9 +203,11 @@ func Execute(version string) { }, Run: func(cmd *cobra.Command, argss []string) { // emulate -L returning no servers - if args.ListServers { - fmt.Println() - fmt.Println(localizer.Sprintf("Servers:")) + if args.ListServers != "" { + if args.ListServers != "c" { + fmt.Println() + fmt.Println(localizer.Sprintf("Servers:")) + } fmt.Println(" ;UID:Login ID=?;PWD:Password=?;Trusted_Connection:Use Integrated Security=?;*APP:AppName=?;*WSID:WorkStation ID=?;") os.Exit(0) } @@ -181,7 +216,7 @@ func Execute(version string) { os.Exit(1) } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) setVars(vars, &args) if args.Version { @@ -225,10 +260,11 @@ func Execute(version string) { } // We need to rewrite the arguments to add -i and -v in front of each space-delimited value to be Cobra-friendly. +// For flags like -r we need to inject the default value if the user omits it func convertOsArgs(args []string) (cargs []string) { flag := "" first := true - for _, a := range args { + for i, a := range args { if flag != "" { // If the user has a file named "-i" the only way they can pass it on the command line // is with triple quotes: sqlcmd -i """-i""" which will convince the flags parser to @@ -240,11 +276,34 @@ func convertOsArgs(args []string) (cargs []string) { } first = false } + var defValue string if isListFlag(a) { flag = a first = true + } else { + defValue = checkDefaultValue(args, i) } cargs = append(cargs, a) + if defValue != "" { + cargs = append(cargs, defValue) + } + } + return +} + +// If args[i] is the given flag and args[i+1] is another flag, returns the value to append after the flag +func checkDefaultValue(args []string, i int) (val string) { + flags := map[rune]string{ + 'r': "0", + 'k': "0", + 'L': "|", // | is the sentinel for no value since users are unlikely to use it. It's "reserved" in most shells + 'X': "0", + } + if isFlag(args[i]) && (len(args) == i+1 || args[i+1][0] == '-') { + if v, ok := flags[rune(args[i][1])]; ok { + val = v + return + } } return } @@ -296,6 +355,9 @@ func SetScreenWidthFlags(args *SQLCmdArguments, rootCmd *cobra.Command) { args.ScreenWidth = getOptionalIntArgument(rootCmd, screenWidth) args.FixedTypeWidth = getOptionalIntArgument(rootCmd, fixedTypeWidth) args.VariableTypeWidth = getOptionalIntArgument(rootCmd, variableTypeWidth) + args.DisableCmd = getOptionalIntArgument(rootCmd, disableCmdAndWarn) + args.ErrorsToStderr = getOptionalIntArgument(rootCmd, errorsToStderr) + args.RemoveControlCharacters = getOptionalIntArgument(rootCmd, removeControlCharacters) } func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { @@ -313,7 +375,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().StringVarP(&args.InitialQuery, "initial-query", "q", "", localizer.Sprintf("Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed")) rootCmd.Flags().StringVarP(&args.Query, "query", "Q", "", localizer.Sprintf("Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed")) rootCmd.Flags().StringVarP(&args.Server, "server", "S", "", localizer.Sprintf("%s Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable %s.", localizer.ConnStrPattern, localizer.ServerEnvVar)) - rootCmd.Flags().BoolVarP(&args.DisableCmdAndWarn, "disable-cmd-and-warn", "X", false, localizer.Sprintf("Disables commands that might compromise system security. Sqlcmd issues a warning and continues")) + _ = rootCmd.Flags().IntP(disableCmdAndWarn, "X", 0, localizer.Sprintf("%s Disables commands that might compromise system security. Passing 1 tells sqlcmd to exit when disabled commands are run.", "-X[1]")) rootCmd.Flags().StringVar(&args.AuthenticationMethod, "authentication-method", "", localizer.Sprintf("Specifies the SQL authentication method to use to connect to Azure SQL Database. One of: ActiveDirectoryDefault, ActiveDirectoryIntegrated, ActiveDirectoryPassword, ActiveDirectoryInteractive, ActiveDirectoryManagedIdentity, ActiveDirectoryServicePrincipal, SqlPassword")) rootCmd.Flags().BoolVarP(&args.UseAad, "use-aad", "G", false, localizer.Sprintf("Tells sqlcmd to use ActiveDirectory authentication. If no user name is provided, authentication method ActiveDirectoryDefault is used. If a password is provided, ActiveDirectoryPassword is used. Otherwise ActiveDirectoryInteractive is used")) rootCmd.Flags().BoolVarP(&args.DisableVariableSubstitution, "disable-variable-substitution", "x", false, localizer.Sprintf("Causes sqlcmd to ignore scripting variables. This parameter is useful when a script contains many %s statements that may contain strings that have the same format as regular variables, such as $(variable_name)", localizer.InsertKeyword)) @@ -328,8 +390,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { // Can't use NoOptDefVal until this fix: https://github.com/spf13/cobra/issues/866 //rootCmd.Flags().Lookup(encryptConnection).NoOptDefVal = "true" rootCmd.Flags().StringVarP(&args.Format, format, "F", "horiz", localizer.Sprintf("Specifies the formatting for results")) - rootCmd.Flags().IntVarP(&args.ErrorsToStderr, errorsToStderr, "r", -1, localizer.Sprintf("Controls which error messages are sent to stdout. Messages that have severity level greater than or equal to this level are sent")) - //rootCmd.Flags().Lookup(errorsToStderr).NoOptDefVal = "0" + _ = rootCmd.Flags().IntP(errorsToStderr, "r", -1, localizer.Sprintf("%s Redirects error messages with severity >= 11 output to stderr. Pass 1 to to redirect all errors including PRINT.", "-r[0 | 1]")) rootCmd.Flags().IntVar(&args.DriverLoggingLevel, "driver-logging-level", 0, localizer.Sprintf("Level of mssql driver messages to print")) rootCmd.Flags().BoolVarP(&args.ExitOnError, "exit-on-error", "b", false, localizer.Sprintf("Specifies that sqlcmd exits and returns a %s value when an error occurs", localizer.DosErrorLevel)) rootCmd.Flags().IntVarP(&args.ErrorLevel, "error-level", "m", 0, localizer.Sprintf("Controls which error messages are sent to %s. Messages that have severity level greater than or equal to this level are sent", localizer.StdoutName)) @@ -350,10 +411,11 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { _ = rootCmd.Flags().IntP(screenWidth, "w", 0, localizer.Sprintf("Specifies the screen width for output")) _ = rootCmd.Flags().IntP(variableTypeWidth, "y", 256, setScriptVariable("SQLCMDMAXVARTYPEWIDTH")) _ = rootCmd.Flags().IntP(fixedTypeWidth, "Y", 0, setScriptVariable("SQLCMDMAXFIXEDTYPEWIDTH")) - rootCmd.Flags().BoolVarP(&args.ListServers, "list-servers", "L", false, "List servers") + rootCmd.Flags().StringVarP(&args.ListServers, listServers, "L", "", localizer.Sprintf("%s List servers. Pass %s to omit 'Servers:' output.", "-L[c]", "c")) rootCmd.Flags().BoolVarP(&args.DedicatedAdminConnection, "dedicated-admin-connection", "A", false, localizer.Sprintf("Dedicated administrator connection")) _ = rootCmd.Flags().BoolP("enable-quoted-identifiers", "I", true, localizer.Sprintf("Provided for backward compatibility. Quoted identifiers are always enabled")) _ = rootCmd.Flags().BoolP("client-regional-setting", "R", false, localizer.Sprintf("Provided for backward compatibility. Client regional settings are not used")) + _ = rootCmd.Flags().IntP(removeControlCharacters, "k", 0, localizer.Sprintf("%s Remove control characters from output. Pass 1 to substitute a space per character, 2 for a space per consecutive characters", "-k [1|2]")) } func setScriptVariable(v string) string { @@ -403,7 +465,32 @@ func normalizeFlags(cmd *cobra.Command) error { err = invalidParameterError("-r", v, "0", "1") return pflag.NormalizedName("") } + case disableCmdAndWarn: + switch v { + case "0", "1": + return pflag.NormalizedName(name) + default: + err = invalidParameterError("-X", v, "1") + return pflag.NormalizedName("") + } + case listServers: + switch v { + case "|", "c": + return pflag.NormalizedName(name) + default: + err = invalidParameterError("-L", v, "c") + return pflag.NormalizedName("") + } + case removeControlCharacters: + switch v { + case "0", "1", "2": + return pflag.NormalizedName(name) + default: + err = invalidParameterError("-k", v, "1", "2") + return pflag.NormalizedName("") + } } + return pflag.NormalizedName(name) }) if err != nil { @@ -558,7 +645,7 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq connect.ApplicationName = "sqlcmd" if len(args.Password) > 0 { connect.Password = args.Password - } else if !args.DisableCmdAndWarn { + } else if args.useEnvVars() { connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD) } connect.ServerName = args.Server @@ -576,7 +663,7 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq connect.UseTrustedConnection = args.UseTrustedConnection connect.TrustServerCertificate = args.TrustServerCertificate connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "") - connect.DisableEnvironmentVariables = args.DisableCmdAndWarn + connect.DisableEnvironmentVariables = !args.useEnvVars() connect.DisableVariableSubstitution = args.DisableVariableSubstitution connect.ApplicationIntent = args.ApplicationIntent connect.LoginTimeoutSeconds = args.LoginTimeout @@ -614,8 +701,8 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { defer s.StopCloseHandler() s.UnicodeOutputFile = args.UnicodeOutputFile - if args.DisableCmdAndWarn { - s.Cmd.DisableSysCommands(false) + if args.DisableCmd != nil { + s.Cmd.DisableSysCommands(args.errorOnBlockedCmd()) } if args.BatchTerminator != "GO" { @@ -629,25 +716,24 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } s.Connect = &connectConfig - s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces) + s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior()) if args.OutputFile != "" { err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) if err != nil { return 1, err } - } else { + } else if args.ErrorsToStderr != nil { var stderrSeverity uint8 = 11 - if args.ErrorsToStderr == 1 { + if *args.ErrorsToStderr == 1 { stderrSeverity = 0 } - if args.ErrorsToStderr >= 0 { - s.PrintError = func(msg string, severity uint8) bool { - if severity >= stderrSeverity { - s.WriteError(os.Stderr, errors.New(msg+sqlcmd.SqlcmdEol)) - return true - } - return false + + s.PrintError = func(msg string, severity uint8) bool { + if severity >= stderrSeverity { + s.WriteError(os.Stderr, errors.New(msg+sqlcmd.SqlcmdEol)) + return true } + return false } } @@ -659,7 +745,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } script := vars.StartupScriptFile() - if !args.DisableCmdAndWarn && len(script) > 0 { + if args.runStartupScript() && len(script) > 0 { f, fileErr := os.Open(script) if fileErr != nil { s.WriteError(s.GetError(), sqlcmd.InvalidVariableValue(sqlcmd.SQLCMDINI, script)) diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 44dcd0db..a38b7f65 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -29,7 +29,7 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { // The long flag names are up for debate. commands := []cmdLineTest{ {[]string{}, func(args SQLCmdArguments) bool { - return args.Server == "" && !args.UseTrustedConnection && args.UserName == "" && args.ScreenWidth == nil && args.ErrorsToStderr == -1 && args.EncryptConnection == "default" + return args.Server == "" && !args.UseTrustedConnection && args.UserName == "" && args.ScreenWidth == nil && args.ErrorsToStderr == nil && args.EncryptConnection == "default" }}, {[]string{"-v", "a=b", "x=y", "-E"}, func(args SQLCmdArguments) bool { return len(args.Variables) == 2 && args.Variables["a"] == "b" && args.Variables["x"] == "y" && args.UseTrustedConnection @@ -50,8 +50,8 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-S", "tcp:someserver,10245"}, func(args SQLCmdArguments) bool { return args.Server == "tcp:someserver,10245" && !args.DisableVariableSubstitution }}, - {[]string{"-X", "-x"}, func(args SQLCmdArguments) bool { - return args.DisableCmdAndWarn && args.DisableVariableSubstitution + {[]string{"-X", "1", "-x"}, func(args SQLCmdArguments) bool { + return args.errorOnBlockedCmd() && args.DisableVariableSubstitution }}, // Notice no "" around the value with a space in it. It seems quotes get stripped out somewhere before Parse when invoking on a real command line {[]string{"-v", "x=y", "-v", `y=a space`}, func(args SQLCmdArguments) bool { @@ -67,7 +67,7 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { return args.Format == "vert" }}, {[]string{"-r", "1"}, func(args SQLCmdArguments) bool { - return args.ErrorsToStderr == 1 + return *args.ErrorsToStderr == 1 }}, {[]string{"-h", "2", "-?"}, func(args SQLCmdArguments) bool { return args.Help && args.Headers == 2 @@ -96,6 +96,9 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-i", `"comma,text.sql"`}, func(args SQLCmdArguments) bool { return args.InputFile[0] == "comma,text.sql" }}, + {[]string{"-k", "-X", "-r"}, func(args SQLCmdArguments) bool { + return args.warnOnBlockedCmd() && !args.useEnvVars() && args.getControlCharacterBehavior() == sqlcmd.ControlRemove && *args.ErrorsToStderr == 0 + }}, } for _, test := range commands { @@ -169,7 +172,7 @@ func TestInvalidCommandLine(t *testing.T) { buf := &memoryBuffer{buf: new(bytes.Buffer)} cmd.SetErr(buf) setFlags(cmd, arguments) - cmd.SetArgs(test.commandLine) + cmd.SetArgs(convertOsArgs(test.commandLine)) err := cmd.Execute() if assert.EqualErrorf(t, err, test.errorMessage, "Command line: %s", test.commandLine) { errBytes := buf.buf.String() @@ -208,7 +211,7 @@ func TestValidateFlags(t *testing.T) { } cmd.SetErr(new(bytes.Buffer)) setFlags(cmd, arguments) - cmd.SetArgs(test.commandLine) + cmd.SetArgs(convertOsArgs(test.commandLine)) err := cmd.Execute() assert.EqualError(t, err, test.errorMessage, "Command line:%v", test.commandLine) } @@ -226,7 +229,7 @@ func TestRunInputFiles(t *testing.T) { if canTestAzureAuth() { args.UseAad = true } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") setVars(vars, &args) @@ -251,7 +254,7 @@ func TestUnicodeOutput(t *testing.T) { if canTestAzureAuth() { args.UseAad = true } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) setVars(vars, &args) exitCode, err := run(vars, &args) @@ -303,7 +306,7 @@ func TestUnicodeInput(t *testing.T) { if canTestAzureAuth() { args.UseAad = true } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) setVars(vars, &args) exitCode, err := run(vars, &args) assert.NoError(t, err, "run") @@ -333,7 +336,7 @@ func TestQueryAndExit(t *testing.T) { if canTestAzureAuth() { args.UseAad = true } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") vars.Set("VAR1", "100") setVars(vars, &args) @@ -353,13 +356,14 @@ func TestQueryAndExit(t *testing.T) { func TestExitOnError(t *testing.T) { args = newArguments() args.InputFile = []string{"testdata/select100.sql"} - args.ErrorsToStderr = 0 + args.ErrorsToStderr = new(int) + *args.ErrorsToStderr = 0 args.ExitOnError = true if canTestAzureAuth() { args.UseAad = true } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) setVars(vars, &args) exitCode, err := run(vars, &args) @@ -368,7 +372,7 @@ func TestExitOnError(t *testing.T) { args.InputFile = []string{"testdata/bad.sql"} - vars = sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars = sqlcmd.InitializeVariables(args.useEnvVars()) setVars(vars, &args) exitCode, err = run(vars, &args) @@ -390,7 +394,7 @@ func TestAzureAuth(t *testing.T) { args.Query = "SELECT 'AZURE'" args.OutputFile = o.Name() args.UseAad = true - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") setVars(vars, &args) exitCode, err := run(vars, &args) @@ -410,7 +414,7 @@ func TestMissingInputFile(t *testing.T) { args.UseAad = true } - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) setVars(vars, &args) exitCode, err := run(vars, &args) @@ -453,10 +457,11 @@ func TestConditionsForPasswordPrompt(t *testing.T) { for _, testcase := range tests { t.Log(testcase.authenticationMethod, testcase.inputFile, testcase.username, testcase.pwd, testcase.expectedResult) args := newArguments() - args.DisableCmdAndWarn = true + args.DisableCmd = new(int) + *args.DisableCmd = 1 args.InputFile = testcase.inputFile args.UserName = testcase.username - vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn) + vars := sqlcmd.InitializeVariables(args.useEnvVars()) setVars(vars, &args) var connectConfig sqlcmd.ConnectSettings setConnect(&connectConfig, &args, vars) @@ -509,6 +514,11 @@ func TestConvertOsArgs(t *testing.T) { []string{"-E", "-v", "a=b", "x=y", "-i", "a.sql", "b.sql", "-v", "f=g", "-i", "c.sql", "-C", "-v", "ab=cd", "ef=hi"}, []string{"-E", "-v", "a=b", "-v", "x=y", "-i", "a.sql", "-i", "b.sql", "-v", "f=g", "-i", "c.sql", "-C", "-v", "ab=cd", "-v", "ef=hi"}, }, + { + "Flags with optional arguments", + []string{"-r", "1", "-X", "-k", "-C"}, + []string{"-r", "1", "-X", "0", "-k", "0", "-C"}, + }, } for _, c := range tests { t.Run(c.name, func(t *testing.T) { diff --git a/internal/sql/mssql.go b/internal/sql/mssql.go index a1c09604..e91b1c05 100644 --- a/internal/sql/mssql.go +++ b/internal/sql/mssql.go @@ -32,7 +32,7 @@ func (m *mssql) Connect( m.console = nil } m.sqlcmd = sqlcmd.New(m.console, "", v) - m.sqlcmd.Format = sqlcmd.NewSQLCmdDefaultFormatter(false) + m.sqlcmd.Format = sqlcmd.NewSQLCmdDefaultFormatter(false, sqlcmd.ControlIgnore) connect := sqlcmd.ConnectSettings{ ServerName: fmt.Sprintf( "%s,%d", diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 9aa04bc3..a81ae8c7 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -188,7 +188,7 @@ func TestListCommandUsesColorizer(t *testing.T) { func TestListColorPrintsStyleSamples(t *testing.T) { vars := InitializeVariables(false) s := New(nil, "", vars) - s.Format = NewSQLCmdDefaultFormatter(false) + s.Format = NewSQLCmdDefaultFormatter(false, ControlIgnore) // force colorizer on s.colorizer = color.New(true) buf := &memoryBuffer{buf: new(bytes.Buffer)} diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index a03e9b9b..4b5b1241 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -86,11 +86,12 @@ type sqlCmdFormatterType struct { } // NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter -func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool) Formatter { +func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool, ccb ControlCharacterBehavior) Formatter { return &sqlCmdFormatterType{ removeTrailingSpaces: removeTrailingSpaces, format: "horizontal", colorizer: color.New(false), + ccb: ccb, } } diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 6af54f6a..c89161de 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -586,7 +586,7 @@ func setupSqlCmdWithMemoryOutput(t testing.TB) (*Sqlcmd, *memoryBuffer) { v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true) + s.Format = NewSQLCmdDefaultFormatter(true, ControlIgnore) buf := &memoryBuffer{buf: new(bytes.Buffer)} s.SetOutput(buf) err := s.ConnectDb(nil, true) @@ -600,7 +600,7 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true) + s.Format = NewSQLCmdDefaultFormatter(true, ControlIgnore) file, err := os.CreateTemp("", "sqlcmdout") assert.NoError(t, err, "os.CreateTemp") s.SetOutput(file) @@ -618,7 +618,7 @@ func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) v.Set(SQLCMDMAXVARTYPEWIDTH, "0") s := New(nil, "", v) s.Connect = newConnect(t) - s.Format = NewSQLCmdDefaultFormatter(true) + s.Format = NewSQLCmdDefaultFormatter(true, ControlIgnore) outfile, err := os.CreateTemp("", "sqlcmdout") assert.NoError(t, err, "os.CreateTemp") errfile, err := os.CreateTemp("", "sqlcmderr") From 2bc62b225f38b35555dc5bb4572b3b2f5db87083 Mon Sep 17 00:00:00 2001 From: davidshi Date: Tue, 8 Aug 2023 16:56:15 -0500 Subject: [PATCH 2/3] Feat: add -e --- cmd/sqlcmd/sqlcmd.go | 4 +++- cmd/sqlcmd/sqlcmd_test.go | 4 ++-- pkg/sqlcmd/commands.go | 6 ++++++ pkg/sqlcmd/commands_test.go | 11 +++++++++++ pkg/sqlcmd/sqlcmd.go | 6 ++++-- 5 files changed, 26 insertions(+), 5 deletions(-) diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 7c8958c5..a390fa51 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -68,6 +68,7 @@ type SQLCmdArguments struct { DedicatedAdminConnection bool ListServers string RemoveControlCharacters *int + EchoInput bool // Keep Help at the end of the list Help bool } @@ -416,6 +417,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { _ = rootCmd.Flags().BoolP("enable-quoted-identifiers", "I", true, localizer.Sprintf("Provided for backward compatibility. Quoted identifiers are always enabled")) _ = rootCmd.Flags().BoolP("client-regional-setting", "R", false, localizer.Sprintf("Provided for backward compatibility. Client regional settings are not used")) _ = rootCmd.Flags().IntP(removeControlCharacters, "k", 0, localizer.Sprintf("%s Remove control characters from output. Pass 1 to substitute a space per character, 2 for a space per consecutive characters", "-k [1|2]")) + rootCmd.Flags().BoolVarP(&args.EchoInput, "echo-input", "e", false, localizer.Sprintf("Echo input")) } func setScriptVariable(v string) string { @@ -704,7 +706,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { if args.DisableCmd != nil { s.Cmd.DisableSysCommands(args.errorOnBlockedCmd()) } - + s.EchoInput = args.EchoInput if args.BatchTerminator != "GO" { err = s.Cmd.SetBatchTerminator(args.BatchTerminator) if err != nil { diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index a38b7f65..dbc9b2da 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -84,8 +84,8 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-s", "|", "-w", "10", "-W"}, func(args SQLCmdArguments) bool { return args.TrimSpaces && args.ColumnSeparator == "|" && *args.ScreenWidth == 10 }}, - {[]string{"-y", "100", "-Y", "200", "-P", "placeholder"}, func(args SQLCmdArguments) bool { - return *args.FixedTypeWidth == 200 && *args.VariableTypeWidth == 100 && args.Password == "placeholder" + {[]string{"-y", "100", "-Y", "200", "-P", "placeholder", "-e"}, func(args SQLCmdArguments) bool { + return *args.FixedTypeWidth == 200 && *args.VariableTypeWidth == 100 && args.Password == "placeholder" && args.EchoInput }}, {[]string{"-E", "-v", "a=b", "x=y", "-i", "a.sql", "b.sql", "-v", "f=g", "-i", "c.sql", "-C", "-v", "ab=cd", "ef=hi"}, func(args SQLCmdArguments) bool { return args.UseTrustedConnection && args.Variables["x"] == "y" && len(args.InputFile) == 3 && args.InputFile[0] == "a.sql" && args.TrustServerCertificate diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 8d8598e5..b0af2ab6 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -232,6 +232,12 @@ func goCommand(s *Sqlcmd, args []string, line uint) error { if err != nil || n < 1 { return InvalidCommandError("GO", line) } + if s.EchoInput { + err = listCommand(s, []string{}, line) + } + if err != nil { + return InvalidCommandError("GO", line) + } query := s.batch.String() if query == "" { return nil diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index a81ae8c7..eff3a509 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -364,3 +364,14 @@ func TestEditCommand(t *testing.T) { assert.Equal(t, "1> select 5000"+SqlcmdEol+"5000"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output from query after :ed command") } } + +func TestEchoInput(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + s.EchoInput = true + defer buf.Close() + c := []string{"set nocount on", "select 100", "go"} + err := runSqlCmd(t, s, c) + if assert.NoError(t, err, "go should not raise error") { + assert.Equal(t, "set nocount on"+SqlcmdEol+"select 100"+SqlcmdEol+"100"+SqlcmdEol+SqlcmdEol, buf.buf.String(), "Incorrect output with echo true") + } +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 051cd6bc..169eac91 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -79,8 +79,10 @@ type Sqlcmd struct { PrintError func(msg string, severity uint8) bool // UnicodeOutputFile is true when UTF16 file output is needed UnicodeOutputFile bool - colorizer color.Colorizer - termchan chan os.Signal + // EchoInput tells the GO command to print the batch text before running the query + EchoInput bool + colorizer color.Colorizer + termchan chan os.Signal } // New creates a new Sqlcmd instance. From a2c973c40f13b0f71099d74c7756f81ccf0786ed Mon Sep 17 00:00:00 2001 From: davidshi Date: Wed, 9 Aug 2023 13:32:45 -0500 Subject: [PATCH 3/3] feat: add -t --- cmd/sqlcmd/sqlcmd.go | 6 +++++- cmd/sqlcmd/sqlcmd_test.go | 5 +++-- pkg/sqlcmd/connect.go | 2 +- pkg/sqlcmd/format.go | 5 +++++ pkg/sqlcmd/sqlcmd.go | 7 +++++++ pkg/sqlcmd/sqlcmd_test.go | 12 ++++++++++++ pkg/sqlcmd/variables.go | 5 +++++ 7 files changed, 38 insertions(+), 4 deletions(-) diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index a390fa51..11c1887d 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -69,6 +69,7 @@ type SQLCmdArguments struct { ListServers string RemoveControlCharacters *int EchoInput bool + QueryTimeout int // Keep Help at the end of the list Help bool } @@ -144,6 +145,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { err = rangeParameterError("-Y", fmt.Sprint(*a.FixedTypeWidth), 0, 8000, true) case a.VariableTypeWidth != nil && (*a.VariableTypeWidth < 0 || *a.VariableTypeWidth > 8000): err = rangeParameterError("-y", fmt.Sprint(*a.VariableTypeWidth), 0, 8000, true) + case a.QueryTimeout < 0 || a.QueryTimeout > 65534: + err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true) } } if err != nil { @@ -418,6 +421,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { _ = rootCmd.Flags().BoolP("client-regional-setting", "R", false, localizer.Sprintf("Provided for backward compatibility. Client regional settings are not used")) _ = rootCmd.Flags().IntP(removeControlCharacters, "k", 0, localizer.Sprintf("%s Remove control characters from output. Pass 1 to substitute a space per character, 2 for a space per consecutive characters", "-k [1|2]")) rootCmd.Flags().BoolVarP(&args.EchoInput, "echo-input", "e", false, localizer.Sprintf("Echo input")) + rootCmd.Flags().IntVarP(&args.QueryTimeout, "query-timeout", "t", 0, "Query timeout") } func setScriptVariable(v string) string { @@ -602,7 +606,7 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) { return "" }, sqlcmd.SQLCMDUSER: func(a *SQLCmdArguments) string { return a.UserName }, - sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return "" }, + sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return fmt.Sprint(a.QueryTimeout) }, sqlcmd.SQLCMDHEADERS: func(a *SQLCmdArguments) string { return fmt.Sprint(a.Headers) }, sqlcmd.SQLCMDCOLSEP: func(a *SQLCmdArguments) string { if a.ColumnSeparator != "" { diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index dbc9b2da..df5c6cad 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -81,8 +81,8 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-w", "10"}, func(args SQLCmdArguments) bool { return args.ScreenWidth != nil && *args.ScreenWidth == 10 && args.FixedTypeWidth == nil && args.VariableTypeWidth == nil }}, - {[]string{"-s", "|", "-w", "10", "-W"}, func(args SQLCmdArguments) bool { - return args.TrimSpaces && args.ColumnSeparator == "|" && *args.ScreenWidth == 10 + {[]string{"-s", "|", "-w", "10", "-W", "-t", "10"}, func(args SQLCmdArguments) bool { + return args.TrimSpaces && args.ColumnSeparator == "|" && *args.ScreenWidth == 10 && args.QueryTimeout == 10 }}, {[]string{"-y", "100", "-Y", "200", "-P", "placeholder", "-e"}, func(args SQLCmdArguments) bool { return *args.FixedTypeWidth == 200 && *args.VariableTypeWidth == 100 && args.Password == "placeholder" && args.EchoInput @@ -149,6 +149,7 @@ func TestInvalidCommandLine(t *testing.T) { {[]string{"-Y", "-2"}, "'-Y -2': value must be greater than or equal to 0 and less than or equal to 8000."}, {[]string{"-P"}, "'-P': Missing argument. Enter '-?' for help."}, {[]string{"-;"}, "';': Unknown Option. Enter '-?' for help."}, + {[]string{"-t", "-2"}, "'-t -2': value must be greater than or equal to 0 and less than or equal to 65534."}, } for _, test := range commands { diff --git a/pkg/sqlcmd/connect.go b/pkg/sqlcmd/connect.go index e445cfdc..1ef5ae8a 100644 --- a/pkg/sqlcmd/connect.go +++ b/pkg/sqlcmd/connect.go @@ -11,7 +11,7 @@ import ( "github.com/microsoft/go-mssqldb/azuread" ) -// ConnectSettings specifies the settings for connections +// ConnectSettings specifies the settings for SQL connections and queries type ConnectSettings struct { // ServerName is the full name including instance and port ServerName string diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 4b5b1241..55bd2e25 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -4,7 +4,9 @@ package sqlcmd import ( + "context" "database/sql" + "errors" "fmt" "io" "strings" @@ -214,6 +216,9 @@ func (f *sqlCmdFormatterType) AddMessage(msg string) { func (f *sqlCmdFormatterType) AddError(err error) { print := true b := new(strings.Builder) + if errors.Is(err, context.DeadlineExceeded) { + err = localizer.Errorf("Timeout expired") + } msg := err.Error() switch e := (err).(type) { case mssql.Error: diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 169eac91..277ae42a 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -17,6 +17,7 @@ import ( "sort" "strings" "syscall" + "time" "github.com/golang-sql/sqlexp" mssql "github.com/microsoft/go-mssqldb" @@ -415,6 +416,12 @@ func (s *Sqlcmd) runQuery(query string) (int, error) { retcode := -101 s.Format.BeginBatch(query, s.vars, s.GetOutput(), s.GetError()) ctx := context.Background() + timeout := s.vars.QueryTimeoutSeconds() + if timeout > 0 { + ct, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + ctx = ct + } retmsg := &sqlexp.ReturnMessage{} rows, qe := s.db.QueryContext(ctx, query, retmsg) if qe != nil { diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index c89161de..fbc5c915 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -551,6 +551,7 @@ func TestSqlCmdOutputAndError(t *testing.T) { func TestVeryLongLineInFile(t *testing.T) { s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() val := strings.Repeat("a1b", (3*1024*1024)/3) line := "set nocount on" + SqlcmdEol + "select('" + val + "')" file, err := os.CreateTemp("", "sqlcmdlongline") @@ -565,6 +566,17 @@ func TestVeryLongLineInFile(t *testing.T) { } } +func TestQueryTimeout(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.vars.Set(SQLCMDSTATTIMEOUT, "1") + i, err := s.runQuery("waitfor delay '00:00:10'") + if assert.NoError(t, err, "runQuery returned an error") { + assert.Equal(t, -100, i, "return from runQuery") + assert.Equal(t, "Timeout expired"+SqlcmdEol, buf.buf.String(), "Query should have timed out") + } +} + // runSqlCmd uses lines as input for sqlcmd instead of relying on file or console input func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error { t.Helper() diff --git a/pkg/sqlcmd/variables.go b/pkg/sqlcmd/variables.go index 98b40b10..aa601627 100644 --- a/pkg/sqlcmd/variables.go +++ b/pkg/sqlcmd/variables.go @@ -198,6 +198,11 @@ func (v Variables) ColorScheme() string { return v[SQLCMDCOLORSCHEME] } +// QueryTimeoutSeconds limits the allowed time for a query to complete. Any value <= 0 specifies unlimited +func (v Variables) QueryTimeoutSeconds() int64 { + return mustValue(v[SQLCMDSTATTIMEOUT]) +} + func mustValue(val string) int64 { var n int64 _, err := fmt.Sscanf(val, "%d", &n)