Skip to content

Feat: Add -e and -k, finish -X, -r, -L #447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ The `sqlcmd` project aims to be a complete port of the original ODBC sqlcmd to t
- There are new posix-style versions of each flag, such as `--input-file` for `-i`. `sqlcmd -?` will print those parameter names. Those new names do not preserve backward compatibility with ODBC `sqlcmd`. For example, to specify multiple input file names using `--input-file`, the file names must be comma-delimited, not space-delimited.

The following switches have different behavior in this version of `sqlcmd` compared to the original ODBC based `sqlcmd`.
- `-r` requires a 0 or 1 argument
- `-R` switch is ignored. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms.
- `-I` switch is ignored; quoted identifiers are always set on. To disable quoted identifier behavior, add `SET QUOTED IDENTIFIER OFF` in your scripts.
- `-N` now takes a string value that can be one of `true`, `false`, or `disable` to specify the encryption choice.
Expand All @@ -141,7 +140,6 @@ The following switches have different behavior in this version of `sqlcmd` compa
- If using a single `-i` flag to pass multiple file names, there must be a space after the `-i`. Example: `-i file1.sql file2.sql`
- `-M` switch is ignored. Sqlcmd always enables multi-subnet failover.


### Switches not available in the new sqlcmd (go-sqlcmd) yet

There are a few switches yet to be implemented in the new `sqlcmd` (go-sqlcmd) compared
Expand Down
170 changes: 131 additions & 39 deletions cmd/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ type SQLCmdArguments struct {
// Query to run then exit
Query string
Server string
// Disable syscommands with a warning
DisableCmdAndWarn bool
// Disable syscommands with a warning or error
DisableCmd *int
// AuthenticationMethod is new for go-sqlcmd
AuthenticationMethod string
UseAad bool
Expand All @@ -55,7 +55,7 @@ type SQLCmdArguments struct {
ErrorSeverityLevel uint8
ErrorLevel int
Format string
ErrorsToStderr int
ErrorsToStderr *int
Headers int
UnicodeOutputFile bool
Version bool
Expand All @@ -66,25 +66,60 @@ type SQLCmdArguments struct {
TrimSpaces bool
Password string
DedicatedAdminConnection bool
ListServers bool
ListServers string
RemoveControlCharacters *int
EchoInput bool
QueryTimeout int
// Keep Help at the end of the list
Help bool
}

func (args *SQLCmdArguments) useEnvVars() bool {
return args.DisableCmd == nil
}

func (args *SQLCmdArguments) errorOnBlockedCmd() bool {
return args.DisableCmd != nil && *args.DisableCmd > 0
}

func (args *SQLCmdArguments) warnOnBlockedCmd() bool {
return args.DisableCmd != nil && *args.DisableCmd <= 0
}

func (args *SQLCmdArguments) runStartupScript() bool {
return args.DisableCmd == nil
}

func (args *SQLCmdArguments) getControlCharacterBehavior() sqlcmd.ControlCharacterBehavior {
if args.RemoveControlCharacters == nil {
return sqlcmd.ControlIgnore
}
switch *args.RemoveControlCharacters {
case 1:
return sqlcmd.ControlReplace
case 2:
return sqlcmd.ControlReplaceConsecutive
}
return sqlcmd.ControlRemove
}

const (
sqlcmdErrorPrefix = "Sqlcmd: "
applicationIntent = "application-intent"
errorsToStderr = "errors-to-stderr"
format = "format"
encryptConnection = "encrypt-connection"
screenWidth = "screen-width"
fixedTypeWidth = "fixed-type-width"
variableTypeWidth = "variable-type-width"
sqlcmdErrorPrefix = "Sqlcmd: "
applicationIntent = "application-intent"
errorsToStderr = "errors-to-stderr"
format = "format"
encryptConnection = "encrypt-connection"
screenWidth = "screen-width"
fixedTypeWidth = "fixed-type-width"
variableTypeWidth = "variable-type-width"
disableCmdAndWarn = "disable-cmd-and-warn"
listServers = "list-servers"
removeControlCharacters = "remove-control-characters"
)

// Validate arguments for settings not describe
func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) {
if a.ListServers {
if a.ListServers != "" {
c.Flags().Visit(func(f *pflag.Flag) {
if f.Shorthand != "L" {
err = localizer.Errorf("The -L parameter can not be used in combination with other parameters.")
Expand All @@ -110,6 +145,8 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) {
err = rangeParameterError("-Y", fmt.Sprint(*a.FixedTypeWidth), 0, 8000, true)
case a.VariableTypeWidth != nil && (*a.VariableTypeWidth < 0 || *a.VariableTypeWidth > 8000):
err = rangeParameterError("-y", fmt.Sprint(*a.VariableTypeWidth), 0, 8000, true)
case a.QueryTimeout < 0 || a.QueryTimeout > 65534:
err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true)
}
}
if err != nil {
Expand Down Expand Up @@ -170,9 +207,11 @@ func Execute(version string) {
},
Run: func(cmd *cobra.Command, argss []string) {
// emulate -L returning no servers
if args.ListServers {
fmt.Println()
fmt.Println(localizer.Sprintf("Servers:"))
if args.ListServers != "" {
if args.ListServers != "c" {
fmt.Println()
fmt.Println(localizer.Sprintf("Servers:"))
}
fmt.Println(" ;UID:Login ID=?;PWD:Password=?;Trusted_Connection:Use Integrated Security=?;*APP:AppName=?;*WSID:WorkStation ID=?;")
os.Exit(0)
}
Expand All @@ -181,7 +220,7 @@ func Execute(version string) {
os.Exit(1)
}

vars := sqlcmd.InitializeVariables(!args.DisableCmdAndWarn)
vars := sqlcmd.InitializeVariables(args.useEnvVars())
setVars(vars, &args)

if args.Version {
Expand Down Expand Up @@ -225,10 +264,11 @@ func Execute(version string) {
}

// We need to rewrite the arguments to add -i and -v in front of each space-delimited value to be Cobra-friendly.
// For flags like -r we need to inject the default value if the user omits it
func convertOsArgs(args []string) (cargs []string) {
flag := ""
first := true
for _, a := range args {
for i, a := range args {
if flag != "" {
// If the user has a file named "-i" the only way they can pass it on the command line
// is with triple quotes: sqlcmd -i """-i""" which will convince the flags parser to
Expand All @@ -240,11 +280,34 @@ func convertOsArgs(args []string) (cargs []string) {
}
first = false
}
var defValue string
if isListFlag(a) {
flag = a
first = true
} else {
defValue = checkDefaultValue(args, i)
}
cargs = append(cargs, a)
if defValue != "" {
cargs = append(cargs, defValue)
}
}
return
}

// If args[i] is the given flag and args[i+1] is another flag, returns the value to append after the flag
func checkDefaultValue(args []string, i int) (val string) {
flags := map[rune]string{
'r': "0",
'k': "0",
'L': "|", // | is the sentinel for no value since users are unlikely to use it. It's "reserved" in most shells
'X': "0",
}
if isFlag(args[i]) && (len(args) == i+1 || args[i+1][0] == '-') {
if v, ok := flags[rune(args[i][1])]; ok {
val = v
return
}
}
return
}
Expand Down Expand Up @@ -296,6 +359,9 @@ func SetScreenWidthFlags(args *SQLCmdArguments, rootCmd *cobra.Command) {
args.ScreenWidth = getOptionalIntArgument(rootCmd, screenWidth)
args.FixedTypeWidth = getOptionalIntArgument(rootCmd, fixedTypeWidth)
args.VariableTypeWidth = getOptionalIntArgument(rootCmd, variableTypeWidth)
args.DisableCmd = getOptionalIntArgument(rootCmd, disableCmdAndWarn)
args.ErrorsToStderr = getOptionalIntArgument(rootCmd, errorsToStderr)
args.RemoveControlCharacters = getOptionalIntArgument(rootCmd, removeControlCharacters)
}

func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
Expand All @@ -313,7 +379,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
rootCmd.Flags().StringVarP(&args.InitialQuery, "initial-query", "q", "", localizer.Sprintf("Executes a query when sqlcmd starts, but does not exit sqlcmd when the query has finished running. Multiple-semicolon-delimited queries can be executed"))
rootCmd.Flags().StringVarP(&args.Query, "query", "Q", "", localizer.Sprintf("Executes a query when sqlcmd starts and then immediately exits sqlcmd. Multiple-semicolon-delimited queries can be executed"))
rootCmd.Flags().StringVarP(&args.Server, "server", "S", "", localizer.Sprintf("%s Specifies the instance of SQL Server to which to connect. It sets the sqlcmd scripting variable %s.", localizer.ConnStrPattern, localizer.ServerEnvVar))
rootCmd.Flags().BoolVarP(&args.DisableCmdAndWarn, "disable-cmd-and-warn", "X", false, localizer.Sprintf("Disables commands that might compromise system security. Sqlcmd issues a warning and continues"))
_ = rootCmd.Flags().IntP(disableCmdAndWarn, "X", 0, localizer.Sprintf("%s Disables commands that might compromise system security. Passing 1 tells sqlcmd to exit when disabled commands are run.", "-X[1]"))
rootCmd.Flags().StringVar(&args.AuthenticationMethod, "authentication-method", "", localizer.Sprintf("Specifies the SQL authentication method to use to connect to Azure SQL Database. One of: ActiveDirectoryDefault, ActiveDirectoryIntegrated, ActiveDirectoryPassword, ActiveDirectoryInteractive, ActiveDirectoryManagedIdentity, ActiveDirectoryServicePrincipal, SqlPassword"))
rootCmd.Flags().BoolVarP(&args.UseAad, "use-aad", "G", false, localizer.Sprintf("Tells sqlcmd to use ActiveDirectory authentication. If no user name is provided, authentication method ActiveDirectoryDefault is used. If a password is provided, ActiveDirectoryPassword is used. Otherwise ActiveDirectoryInteractive is used"))
rootCmd.Flags().BoolVarP(&args.DisableVariableSubstitution, "disable-variable-substitution", "x", false, localizer.Sprintf("Causes sqlcmd to ignore scripting variables. This parameter is useful when a script contains many %s statements that may contain strings that have the same format as regular variables, such as $(variable_name)", localizer.InsertKeyword))
Expand All @@ -328,8 +394,7 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
// Can't use NoOptDefVal until this fix: https://github.com/spf13/cobra/issues/866
//rootCmd.Flags().Lookup(encryptConnection).NoOptDefVal = "true"
rootCmd.Flags().StringVarP(&args.Format, format, "F", "horiz", localizer.Sprintf("Specifies the formatting for results"))
rootCmd.Flags().IntVarP(&args.ErrorsToStderr, errorsToStderr, "r", -1, localizer.Sprintf("Controls which error messages are sent to stdout. Messages that have severity level greater than or equal to this level are sent"))
//rootCmd.Flags().Lookup(errorsToStderr).NoOptDefVal = "0"
_ = rootCmd.Flags().IntP(errorsToStderr, "r", -1, localizer.Sprintf("%s Redirects error messages with severity >= 11 output to stderr. Pass 1 to to redirect all errors including PRINT.", "-r[0 | 1]"))
rootCmd.Flags().IntVar(&args.DriverLoggingLevel, "driver-logging-level", 0, localizer.Sprintf("Level of mssql driver messages to print"))
rootCmd.Flags().BoolVarP(&args.ExitOnError, "exit-on-error", "b", false, localizer.Sprintf("Specifies that sqlcmd exits and returns a %s value when an error occurs", localizer.DosErrorLevel))
rootCmd.Flags().IntVarP(&args.ErrorLevel, "error-level", "m", 0, localizer.Sprintf("Controls which error messages are sent to %s. Messages that have severity level greater than or equal to this level are sent", localizer.StdoutName))
Expand All @@ -350,10 +415,13 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
_ = rootCmd.Flags().IntP(screenWidth, "w", 0, localizer.Sprintf("Specifies the screen width for output"))
_ = rootCmd.Flags().IntP(variableTypeWidth, "y", 256, setScriptVariable("SQLCMDMAXVARTYPEWIDTH"))
_ = rootCmd.Flags().IntP(fixedTypeWidth, "Y", 0, setScriptVariable("SQLCMDMAXFIXEDTYPEWIDTH"))
rootCmd.Flags().BoolVarP(&args.ListServers, "list-servers", "L", false, "List servers")
rootCmd.Flags().StringVarP(&args.ListServers, listServers, "L", "", localizer.Sprintf("%s List servers. Pass %s to omit 'Servers:' output.", "-L[c]", "c"))
rootCmd.Flags().BoolVarP(&args.DedicatedAdminConnection, "dedicated-admin-connection", "A", false, localizer.Sprintf("Dedicated administrator connection"))
_ = rootCmd.Flags().BoolP("enable-quoted-identifiers", "I", true, localizer.Sprintf("Provided for backward compatibility. Quoted identifiers are always enabled"))
_ = rootCmd.Flags().BoolP("client-regional-setting", "R", false, localizer.Sprintf("Provided for backward compatibility. Client regional settings are not used"))
_ = rootCmd.Flags().IntP(removeControlCharacters, "k", 0, localizer.Sprintf("%s Remove control characters from output. Pass 1 to substitute a space per character, 2 for a space per consecutive characters", "-k [1|2]"))
rootCmd.Flags().BoolVarP(&args.EchoInput, "echo-input", "e", false, localizer.Sprintf("Echo input"))
rootCmd.Flags().IntVarP(&args.QueryTimeout, "query-timeout", "t", 0, "Query timeout")
}

func setScriptVariable(v string) string {
Expand Down Expand Up @@ -403,7 +471,32 @@ func normalizeFlags(cmd *cobra.Command) error {
err = invalidParameterError("-r", v, "0", "1")
return pflag.NormalizedName("")
}
case disableCmdAndWarn:
switch v {
case "0", "1":
return pflag.NormalizedName(name)
default:
err = invalidParameterError("-X", v, "1")
return pflag.NormalizedName("")
}
case listServers:
switch v {
case "|", "c":
return pflag.NormalizedName(name)
default:
err = invalidParameterError("-L", v, "c")
return pflag.NormalizedName("")
}
case removeControlCharacters:
switch v {
case "0", "1", "2":
return pflag.NormalizedName(name)
default:
err = invalidParameterError("-k", v, "1", "2")
return pflag.NormalizedName("")
}
}

return pflag.NormalizedName(name)
})
if err != nil {
Expand Down Expand Up @@ -513,7 +606,7 @@ func setVars(vars *sqlcmd.Variables, args *SQLCmdArguments) {
return ""
},
sqlcmd.SQLCMDUSER: func(a *SQLCmdArguments) string { return a.UserName },
sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return "" },
sqlcmd.SQLCMDSTATTIMEOUT: func(a *SQLCmdArguments) string { return fmt.Sprint(a.QueryTimeout) },
sqlcmd.SQLCMDHEADERS: func(a *SQLCmdArguments) string { return fmt.Sprint(a.Headers) },
sqlcmd.SQLCMDCOLSEP: func(a *SQLCmdArguments) string {
if a.ColumnSeparator != "" {
Expand Down Expand Up @@ -558,7 +651,7 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
connect.ApplicationName = "sqlcmd"
if len(args.Password) > 0 {
connect.Password = args.Password
} else if !args.DisableCmdAndWarn {
} else if args.useEnvVars() {
connect.Password = os.Getenv(sqlcmd.SQLCMDPASSWORD)
}
connect.ServerName = args.Server
Expand All @@ -576,7 +669,7 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
connect.UseTrustedConnection = args.UseTrustedConnection
connect.TrustServerCertificate = args.TrustServerCertificate
connect.AuthenticationMethod = args.authenticationMethod(connect.Password != "")
connect.DisableEnvironmentVariables = args.DisableCmdAndWarn
connect.DisableEnvironmentVariables = !args.useEnvVars()
connect.DisableVariableSubstitution = args.DisableVariableSubstitution
connect.ApplicationIntent = args.ApplicationIntent
connect.LoginTimeoutSeconds = args.LoginTimeout
Expand Down Expand Up @@ -614,10 +707,10 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
defer s.StopCloseHandler()
s.UnicodeOutputFile = args.UnicodeOutputFile

if args.DisableCmdAndWarn {
s.Cmd.DisableSysCommands(false)
if args.DisableCmd != nil {
s.Cmd.DisableSysCommands(args.errorOnBlockedCmd())
}

s.EchoInput = args.EchoInput
if args.BatchTerminator != "GO" {
err = s.Cmd.SetBatchTerminator(args.BatchTerminator)
if err != nil {
Expand All @@ -629,25 +722,24 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
}

s.Connect = &connectConfig
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces)
s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior())
if args.OutputFile != "" {
err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile})
if err != nil {
return 1, err
}
} else {
} else if args.ErrorsToStderr != nil {
var stderrSeverity uint8 = 11
if args.ErrorsToStderr == 1 {
if *args.ErrorsToStderr == 1 {
stderrSeverity = 0
}
if args.ErrorsToStderr >= 0 {
s.PrintError = func(msg string, severity uint8) bool {
if severity >= stderrSeverity {
s.WriteError(os.Stderr, errors.New(msg+sqlcmd.SqlcmdEol))
return true
}
return false

s.PrintError = func(msg string, severity uint8) bool {
if severity >= stderrSeverity {
s.WriteError(os.Stderr, errors.New(msg+sqlcmd.SqlcmdEol))
return true
}
return false
}
}

Expand All @@ -659,7 +751,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
}

script := vars.StartupScriptFile()
if !args.DisableCmdAndWarn && len(script) > 0 {
if args.runStartupScript() && len(script) > 0 {
f, fileErr := os.Open(script)
if fileErr != nil {
s.WriteError(s.GetError(), sqlcmd.InvalidVariableValue(sqlcmd.SQLCMDINI, script))
Expand Down
Loading