Skip to content

Commit aac17a5

Browse files
committedMay 18, 2024
Make error response customizable via env-vars in plugin config
Return empty response if chosen by the user Make log level of audit trail customizable Update gatewayd_plugin.yaml to reflect changes Rename function to prepareRequest
1 parent dc2501e commit aac17a5

File tree

6 files changed

+76
-32
lines changed

6 files changed

+76
-32
lines changed
 

‎gatewayd_plugin.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,17 @@ plugins:
4747
# False (strict): The plugin will block the request if it detects an SQL injection attack.
4848
# This greatly increases the false positive rate.
4949
- LIBINJECTION_PERMISSIVE_MODE=True
50+
# The following env-vars are used to configure the plugin's response.
51+
# Possiblel values: error, empty or terminate
52+
- RESPONSE_TYPE=error
53+
# Possible values: DEBUG, LOG, INFO, NOTICE, WARNING, and EXCEPTION
54+
- ERROR_SEVERITY=EXCEPTION
55+
# Ref: https://www.postgresql.org/docs/current/errcodes-appendix.html
56+
- ERROR_NUMBER=42000
57+
- ERROR_MESSAGE=SQL injection detected
58+
- ERROR_DETAIL=Back off, you're not welcome here.
59+
# Possible values: trace, debug, info, warn, error
60+
# Other values will result in no level being set.
61+
- LOG_LEVEL=error
5062
# Checksum hash to verify the binary before loading
5163
checksum: dee4aa014a722e1865d91744a4fd310772152467d9c6ab4ba17fd9dd40d3f724

‎main.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ func main() {
4646
pluginInstance.Impl.ServingAPIAddress = cast.ToString(cfg["servingAPIAddress"])
4747
pluginInstance.Impl.ModelName = cast.ToString(cfg["modelName"])
4848
pluginInstance.Impl.ModelVersion = cast.ToString(cfg["modelVersion"])
49+
50+
pluginInstance.Impl.ResponseType = cast.ToString(cfg["responseType"])
51+
pluginInstance.Impl.ErrorMessage = cast.ToString(cfg["errorMessage"])
52+
pluginInstance.Impl.ErrorSeverity = cast.ToString(cfg["errorSeverity"])
53+
pluginInstance.Impl.ErrorNumber = cast.ToString(cfg["errorNumber"])
54+
pluginInstance.Impl.ErrorDetail = cast.ToString(cfg["errorDetail"])
55+
pluginInstance.Impl.LogLevel = cast.ToString(cfg["logLevel"])
4956
}
5057

5158
goplugin.Serve(&goplugin.ServeConfig{

‎plugin/metrics.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ var (
2525
Name: "detections_total",
2626
Help: "The total number of malicious requests detected",
2727
}, []string{"detector"})
28-
Preventions = promauto.NewCounter(prometheus.CounterOpts{
28+
Preventions = promauto.NewCounterVec(prometheus.CounterOpts{
2929
Namespace: metrics.Namespace,
3030
Name: "preventions_total",
3131
Help: "The total number of malicious requests prevented",
32-
})
32+
}, []string{"response_type"})
3333
)

‎plugin/module.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,21 @@ var (
4545
"threshold": sdkConfig.GetEnv("THRESHOLD", "0.8"),
4646
"enableLibinjection": sdkConfig.GetEnv("ENABLE_LIBINJECTION", "true"),
4747
"libinjectionPermissiveMode": sdkConfig.GetEnv("LIBINJECTION_MODE", "true"),
48+
49+
// Possible values: error or empty
50+
"responseType": sdkConfig.GetEnv("RESPONSE_TYPE", ResponseType),
51+
52+
// This is part of the error response and the audit trail
53+
"errorMessage": sdkConfig.GetEnv("ERROR_MESSAGE", ErrorMessage),
54+
55+
// Response type: error
56+
// Possible severity values: DEBUG, LOG, INFO, NOTICE, WARNING, and EXCEPTION
57+
"errorSeverity": sdkConfig.GetEnv("ERROR_SEVERITY", ErrorSeverity),
58+
"errorNumber": sdkConfig.GetEnv("ERROR_NUMBER", ErrorNumber),
59+
"errorDetail": sdkConfig.GetEnv("ERROR_DETAIL", ErrorDetail),
60+
61+
// Log an audit trail
62+
"logLevel": sdkConfig.GetEnv("LOG_LEVEL", LogLevel),
4863
},
4964
"hooks": []interface{}{
5065
// Converting HookName to int32 is required because the plugin

‎plugin/plugin.go

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/hashicorp/go-hclog"
1616
goplugin "github.com/hashicorp/go-plugin"
1717
"github.com/jackc/pgx/v5/pgproto3"
18+
"github.com/prometheus/client_golang/prometheus"
1819
"github.com/spf13/cast"
1920
"google.golang.org/grpc"
2021
)
@@ -30,15 +31,17 @@ const (
3031
OutputsField string = "outputs"
3132
TokensField string = "tokens"
3233
StringField string = "String"
34+
ResponseTypeField string = "response_type"
3335

3436
DeepLearningModel string = "deep_learning_model"
3537
Libinjection string = "libinjection"
3638

37-
ErrorLevel string = "error"
38-
ExceptionLevel string = "EXCEPTION"
39-
ErrorNumber string = "42000"
40-
DetectionMessage string = "SQL injection detected"
41-
ErrorResponseMessage string = "Back off, you're not welcome here."
39+
ResponseType string = "error"
40+
ErrorSeverity string = "EXCEPTION"
41+
ErrorNumber string = "42000"
42+
ErrorMessage string = "SQL injection detected"
43+
ErrorDetail string = "Back off, you're not welcome here."
44+
LogLevel string = "error"
4245

4346
TokenizeAndSequencePath string = "/tokenize_and_sequence"
4447
PredictPath string = "/v1/models/%s/versions/%s:predict"
@@ -55,6 +58,12 @@ type Plugin struct {
5558
ServingAPIAddress string
5659
ModelName string
5760
ModelVersion string
61+
ResponseType string
62+
ErrorMessage string
63+
ErrorSeverity string
64+
ErrorNumber string
65+
ErrorDetail string
66+
LogLevel string
5867
}
5968

6069
type InjectionDetectionPlugin struct {
@@ -139,7 +148,7 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
139148
if err != nil {
140149
p.Logger.Error("Failed to make POST request", ErrorField, err)
141150
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
142-
return p.errorResponse(
151+
return p.prepareResponse(
143152
req,
144153
map[string]any{
145154
QueryField: queryString,
@@ -163,7 +172,7 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
163172
if err != nil {
164173
p.Logger.Error("Failed to make POST request", ErrorField, err)
165174
if p.isSQLi(queryString) && !p.LibinjectionPermissiveMode {
166-
return p.errorResponse(
175+
return p.prepareResponse(
167176
req,
168177
map[string]any{
169178
QueryField: queryString,
@@ -189,8 +198,8 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
189198
}
190199

191200
Detections.With(map[string]string{DetectorField: DeepLearningModel}).Inc()
192-
p.Logger.Warn(DetectionMessage, ScoreField, score, DetectorField, DeepLearningModel)
193-
return p.errorResponse(
201+
p.Logger.Warn(p.ErrorMessage, ScoreField, score, DetectorField, DeepLearningModel)
202+
return p.prepareResponse(
194203
req,
195204
map[string]any{
196205
QueryField: queryString,
@@ -200,8 +209,8 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
200209
), nil
201210
} else if p.EnableLibinjection && injection && !p.LibinjectionPermissiveMode {
202211
Detections.With(map[string]string{DetectorField: Libinjection}).Inc()
203-
p.Logger.Warn(DetectionMessage, DetectorField, Libinjection)
204-
return p.errorResponse(
212+
p.Logger.Warn(p.ErrorMessage, DetectorField, Libinjection)
213+
return p.prepareResponse(
205214
req,
206215
map[string]any{
207216
QueryField: queryString,
@@ -224,35 +233,36 @@ func (p *Plugin) isSQLi(query string) bool {
224233
// Check if the query is an SQL injection using libinjection.
225234
injection, _ := libinjection.IsSQLi(query)
226235
if injection {
227-
p.Logger.Warn(DetectionMessage, DetectorField, Libinjection)
236+
p.Logger.Warn(p.ErrorMessage, DetectorField, Libinjection)
228237
}
229238
p.Logger.Trace("SQLInjection", IsInjectionField, cast.ToString(injection))
230239
return injection
231240
}
232241

233-
func (p *Plugin) errorResponse(req *v1.Struct, fields map[string]any) *v1.Struct {
234-
Preventions.Inc()
242+
func (p *Plugin) prepareResponse(req *v1.Struct, fields map[string]any) *v1.Struct {
243+
Preventions.With(prometheus.Labels{ResponseTypeField: p.ResponseType}).Inc()
235244

236-
// Create a PostgreSQL error response.
237-
errResp := postgres.ErrorResponse(
238-
DetectionMessage,
239-
ExceptionLevel,
240-
ErrorNumber,
241-
ErrorResponseMessage,
242-
)
245+
var encapsulatedResponse []byte
243246

244-
// Create a ready for query response.
245-
readyForQuery := &pgproto3.ReadyForQuery{TxStatus: 'I'}
246-
// TODO: Decide whether to terminate the connection.
247-
response, err := readyForQuery.Encode(errResp)
248-
if err != nil {
249-
p.Logger.Error("Failed to encode ready for query response", ErrorField, err)
250-
return req
247+
if p.ResponseType == "error" {
248+
// Create a PostgreSQL error response.
249+
encapsulatedResponse = postgres.ErrorResponse(
250+
p.ErrorMessage,
251+
p.ErrorSeverity,
252+
ErrorNumber,
253+
ErrorDetail,
254+
)
255+
} else {
256+
// Create a PostgreSQL empty query response.
257+
encapsulatedResponse, _ = (&pgproto3.EmptyQueryResponse{}).Encode(nil)
251258
}
252259

260+
// Create and encode a ready for query response.
261+
response, _ := (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(encapsulatedResponse)
262+
253263
signals, err := v1.NewList([]any{
254264
sdkAct.Terminate().ToMap(),
255-
sdkAct.Log(ErrorLevel, DetectionMessage, fields).ToMap(),
265+
sdkAct.Log(p.LogLevel, p.ErrorMessage, fields).ToMap(),
256266
})
257267
if err != nil {
258268
p.Logger.Error("Failed to create signals", ErrorField, err)

‎plugin/plugin_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func Test_errorResponse(t *testing.T) {
5252
require.NoError(t, err)
5353
assert.NotNil(t, reqJSON)
5454

55-
resp := p.errorResponse(
55+
resp := p.prepareResponse(
5656
reqJSON,
5757
map[string]any{
5858
"score": 0.9999,

0 commit comments

Comments
 (0)