@@ -15,6 +15,7 @@ import (
15
15
"github.com/hashicorp/go-hclog"
16
16
goplugin "github.com/hashicorp/go-plugin"
17
17
"github.com/jackc/pgx/v5/pgproto3"
18
+ "github.com/prometheus/client_golang/prometheus"
18
19
"github.com/spf13/cast"
19
20
"google.golang.org/grpc"
20
21
)
@@ -30,15 +31,17 @@ const (
30
31
OutputsField string = "outputs"
31
32
TokensField string = "tokens"
32
33
StringField string = "String"
34
+ ResponseTypeField string = "response_type"
33
35
34
36
DeepLearningModel string = "deep_learning_model"
35
37
Libinjection string = "libinjection"
36
38
37
- ErrorLevel string = "error"
38
- ExceptionLevel string = "EXCEPTION"
39
- ErrorNumber string = "42000"
40
- DetectionMessage string = "SQL injection detected"
41
- ErrorResponseMessage string = "Back off, you're not welcome here."
39
+ ResponseType string = "error"
40
+ ErrorSeverity string = "EXCEPTION"
41
+ ErrorNumber string = "42000"
42
+ ErrorMessage string = "SQL injection detected"
43
+ ErrorDetail string = "Back off, you're not welcome here."
44
+ LogLevel string = "error"
42
45
43
46
TokenizeAndSequencePath string = "/tokenize_and_sequence"
44
47
PredictPath string = "/v1/models/%s/versions/%s:predict"
@@ -55,6 +58,12 @@ type Plugin struct {
55
58
ServingAPIAddress string
56
59
ModelName string
57
60
ModelVersion string
61
+ ResponseType string
62
+ ErrorMessage string
63
+ ErrorSeverity string
64
+ ErrorNumber string
65
+ ErrorDetail string
66
+ LogLevel string
58
67
}
59
68
60
69
type InjectionDetectionPlugin struct {
@@ -139,7 +148,7 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
139
148
if err != nil {
140
149
p .Logger .Error ("Failed to make POST request" , ErrorField , err )
141
150
if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
142
- return p .errorResponse (
151
+ return p .prepareResponse (
143
152
req ,
144
153
map [string ]any {
145
154
QueryField : queryString ,
@@ -163,7 +172,7 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
163
172
if err != nil {
164
173
p .Logger .Error ("Failed to make POST request" , ErrorField , err )
165
174
if p .isSQLi (queryString ) && ! p .LibinjectionPermissiveMode {
166
- return p .errorResponse (
175
+ return p .prepareResponse (
167
176
req ,
168
177
map [string ]any {
169
178
QueryField : queryString ,
@@ -189,8 +198,8 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
189
198
}
190
199
191
200
Detections .With (map [string ]string {DetectorField : DeepLearningModel }).Inc ()
192
- p .Logger .Warn (DetectionMessage , ScoreField , score , DetectorField , DeepLearningModel )
193
- return p .errorResponse (
201
+ p .Logger .Warn (p . ErrorMessage , ScoreField , score , DetectorField , DeepLearningModel )
202
+ return p .prepareResponse (
194
203
req ,
195
204
map [string ]any {
196
205
QueryField : queryString ,
@@ -200,8 +209,8 @@ func (p *Plugin) OnTrafficFromClient(ctx context.Context, req *v1.Struct) (*v1.S
200
209
), nil
201
210
} else if p .EnableLibinjection && injection && ! p .LibinjectionPermissiveMode {
202
211
Detections .With (map [string ]string {DetectorField : Libinjection }).Inc ()
203
- p .Logger .Warn (DetectionMessage , DetectorField , Libinjection )
204
- return p .errorResponse (
212
+ p .Logger .Warn (p . ErrorMessage , DetectorField , Libinjection )
213
+ return p .prepareResponse (
205
214
req ,
206
215
map [string ]any {
207
216
QueryField : queryString ,
@@ -224,35 +233,36 @@ func (p *Plugin) isSQLi(query string) bool {
224
233
// Check if the query is an SQL injection using libinjection.
225
234
injection , _ := libinjection .IsSQLi (query )
226
235
if injection {
227
- p .Logger .Warn (DetectionMessage , DetectorField , Libinjection )
236
+ p .Logger .Warn (p . ErrorMessage , DetectorField , Libinjection )
228
237
}
229
238
p .Logger .Trace ("SQLInjection" , IsInjectionField , cast .ToString (injection ))
230
239
return injection
231
240
}
232
241
233
- func (p * Plugin ) errorResponse (req * v1.Struct , fields map [string ]any ) * v1.Struct {
234
- Preventions .Inc ()
242
+ func (p * Plugin ) prepareResponse (req * v1.Struct , fields map [string ]any ) * v1.Struct {
243
+ Preventions .With (prometheus. Labels { ResponseTypeField : p . ResponseType }). Inc ()
235
244
236
- // Create a PostgreSQL error response.
237
- errResp := postgres .ErrorResponse (
238
- DetectionMessage ,
239
- ExceptionLevel ,
240
- ErrorNumber ,
241
- ErrorResponseMessage ,
242
- )
245
+ var encapsulatedResponse []byte
243
246
244
- // Create a ready for query response.
245
- readyForQuery := & pgproto3.ReadyForQuery {TxStatus : 'I' }
246
- // TODO: Decide whether to terminate the connection.
247
- response , err := readyForQuery .Encode (errResp )
248
- if err != nil {
249
- p .Logger .Error ("Failed to encode ready for query response" , ErrorField , err )
250
- return req
247
+ if p .ResponseType == "error" {
248
+ // Create a PostgreSQL error response.
249
+ encapsulatedResponse = postgres .ErrorResponse (
250
+ p .ErrorMessage ,
251
+ p .ErrorSeverity ,
252
+ ErrorNumber ,
253
+ ErrorDetail ,
254
+ )
255
+ } else {
256
+ // Create a PostgreSQL empty query response.
257
+ encapsulatedResponse , _ = (& pgproto3.EmptyQueryResponse {}).Encode (nil )
251
258
}
252
259
260
+ // Create and encode a ready for query response.
261
+ response , _ := (& pgproto3.ReadyForQuery {TxStatus : 'I' }).Encode (encapsulatedResponse )
262
+
253
263
signals , err := v1 .NewList ([]any {
254
264
sdkAct .Terminate ().ToMap (),
255
- sdkAct .Log (ErrorLevel , DetectionMessage , fields ).ToMap (),
265
+ sdkAct .Log (p . LogLevel , p . ErrorMessage , fields ).ToMap (),
256
266
})
257
267
if err != nil {
258
268
p .Logger .Error ("Failed to create signals" , ErrorField , err )
0 commit comments