1
1
package software .amazon .lambda .powertools .sqs .internal ;
2
2
3
- import org .slf4j .Logger ;
4
- import org .slf4j .LoggerFactory ;
5
-
6
3
import java .util .ArrayList ;
4
+ import java .util .Arrays ;
5
+ import java .util .HashMap ;
7
6
import java .util .List ;
7
+ import java .util .Map ;
8
+ import java .util .Optional ;
8
9
10
+ import com .fasterxml .jackson .core .JsonProcessingException ;
11
+ import com .fasterxml .jackson .databind .JsonNode ;
12
+ import org .slf4j .Logger ;
13
+ import org .slf4j .LoggerFactory ;
14
+ import software .amazon .awssdk .core .SdkBytes ;
9
15
import software .amazon .awssdk .services .sqs .SqsClient ;
10
16
import software .amazon .awssdk .services .sqs .model .DeleteMessageBatchRequest ;
11
17
import software .amazon .awssdk .services .sqs .model .DeleteMessageBatchRequestEntry ;
12
18
import software .amazon .awssdk .services .sqs .model .DeleteMessageBatchResponse ;
13
- import software .amazon .awssdk .services .sqs .model .GetQueueUrlRequest ;
19
+ import software .amazon .awssdk .services .sqs .model .GetQueueAttributesRequest ;
20
+ import software .amazon .awssdk .services .sqs .model .GetQueueAttributesResponse ;
21
+ import software .amazon .awssdk .services .sqs .model .MessageAttributeValue ;
22
+ import software .amazon .awssdk .services .sqs .model .QueueAttributeName ;
23
+ import software .amazon .awssdk .services .sqs .model .SendMessageBatchRequestEntry ;
24
+ import software .amazon .awssdk .services .sqs .model .SendMessageBatchResponse ;
14
25
import software .amazon .lambda .powertools .sqs .SQSBatchProcessingException ;
26
+ import software .amazon .lambda .powertools .sqs .SqsUtils ;
15
27
16
28
import static com .amazonaws .services .lambda .runtime .events .SQSEvent .SQSMessage ;
17
29
import static java .lang .String .format ;
30
+ import static java .util .Optional .ofNullable ;
18
31
import static java .util .stream .Collectors .toList ;
19
32
20
33
public final class BatchContext {
21
34
private static final Logger LOG = LoggerFactory .getLogger (BatchContext .class );
35
+ private static final Map <String , String > QUEUE_ARN_TO_DLQ_URL_MAPPING = new HashMap <>();
22
36
37
+ private final Map <SQSMessage , Exception > messageToException = new HashMap <>();
23
38
private final List <SQSMessage > success = new ArrayList <>();
24
- private final List <SQSMessage > failures = new ArrayList <>();
25
- private final List <Exception > exceptions = new ArrayList <>();
39
+
26
40
private final SqsClient client ;
27
41
28
42
public BatchContext (SqsClient client ) {
@@ -34,53 +48,170 @@ public void addSuccess(SQSMessage event) {
34
48
}
35
49
36
50
public void addFailure (SQSMessage event , Exception e ) {
37
- failures .add (event );
38
- exceptions .add (e );
51
+ messageToException .put (event , e );
39
52
}
40
53
41
- public <T > void processSuccessAndHandleFailed (final List <T > successReturns ,
42
- final boolean suppressException ) {
54
+ @ SafeVarargs
55
+ public final <T > void processSuccessAndHandleFailed (final List <T > successReturns ,
56
+ final boolean suppressException ,
57
+ final boolean deleteNonRetryableMessageFromQueue ,
58
+ final Class <? extends Exception >... nonRetryableExceptions ) {
43
59
if (hasFailures ()) {
44
- deleteSuccessMessage ();
45
60
46
- if (suppressException ) {
47
- List <String > messageIds = failures .stream ().
48
- map (SQSMessage ::getMessageId )
49
- .collect (toList ());
61
+ List <Exception > exceptions = new ArrayList <>();
62
+ List <SQSMessage > failedMessages = new ArrayList <>();
63
+ Map <SQSMessage , Exception > nonRetryableMessageToException = new HashMap <>();
50
64
51
- LOG .debug (format ("[%s] records failed processing, but exceptions are suppressed. " +
52
- "Failed messages %s" , failures .size (), messageIds ));
65
+ if (nonRetryableExceptions .length == 0 ) {
66
+ exceptions .addAll (messageToException .values ());
67
+ failedMessages .addAll (messageToException .keySet ());
53
68
} else {
54
- throw new SQSBatchProcessingException (exceptions , failures , successReturns );
69
+ messageToException .forEach ((sqsMessage , exception ) -> {
70
+ boolean nonRetryableException = isNonRetryableException (exception , nonRetryableExceptions );
71
+
72
+ if (nonRetryableException ) {
73
+ nonRetryableMessageToException .put (sqsMessage , exception );
74
+ } else {
75
+ exceptions .add (exception );
76
+ failedMessages .add (sqsMessage );
77
+ }
78
+ });
79
+ }
80
+
81
+ List <SQSMessage > messagesToBeDeleted = new ArrayList <>(success );
82
+
83
+ if (!nonRetryableMessageToException .isEmpty () && deleteNonRetryableMessageFromQueue ) {
84
+ messagesToBeDeleted .addAll (nonRetryableMessageToException .keySet ());
85
+ } else if (!nonRetryableMessageToException .isEmpty ()) {
86
+
87
+ boolean isMovedToDlq = moveNonRetryableMessagesToDlqIfConfigured (nonRetryableMessageToException );
88
+
89
+ if (!isMovedToDlq ) {
90
+ exceptions .addAll (nonRetryableMessageToException .values ());
91
+ failedMessages .addAll (nonRetryableMessageToException .keySet ());
92
+ }
55
93
}
94
+
95
+ deleteMessagesFromQueue (messagesToBeDeleted );
96
+
97
+ processFailedMessages (successReturns , suppressException , exceptions , failedMessages );
98
+ }
99
+ }
100
+
101
+ private <T > void processFailedMessages (List <T > successReturns ,
102
+ boolean suppressException ,
103
+ List <Exception > exceptions ,
104
+ List <SQSMessage > failedMessages ) {
105
+ if (failedMessages .isEmpty ()) {
106
+ return ;
107
+ }
108
+
109
+ if (suppressException ) {
110
+ List <String > messageIds = failedMessages .stream ().
111
+ map (SQSMessage ::getMessageId )
112
+ .collect (toList ());
113
+
114
+ LOG .debug (format ("[%s] records failed processing, but exceptions are suppressed. " +
115
+ "Failed messages %s" , failedMessages .size (), messageIds ));
116
+ } else {
117
+ throw new SQSBatchProcessingException (exceptions , failedMessages , successReturns );
56
118
}
57
119
}
58
120
121
+ private boolean isNonRetryableException (Exception exception , Class <? extends Exception >[] nonRetryableExceptions ) {
122
+ return Arrays .stream (nonRetryableExceptions )
123
+ .anyMatch (aClass -> aClass .isInstance (exception ));
124
+ }
125
+
126
+ private boolean moveNonRetryableMessagesToDlqIfConfigured (Map <SQSMessage , Exception > nonRetryableMessageToException ) {
127
+ Optional <String > dlqUrl = fetchDlqUrl (nonRetryableMessageToException );
128
+
129
+ if (!dlqUrl .isPresent ()) {
130
+ return false ;
131
+ }
132
+
133
+ List <SendMessageBatchRequestEntry > dlqMessages = nonRetryableMessageToException .keySet ().stream ()
134
+ .map (sqsMessage -> {
135
+ Map <String , MessageAttributeValue > messageAttributesMap = new HashMap <>();
136
+
137
+ sqsMessage .getMessageAttributes ().forEach ((s , messageAttribute ) -> {
138
+ MessageAttributeValue .Builder builder = MessageAttributeValue .builder ();
139
+
140
+ builder
141
+ .dataType (messageAttribute .getDataType ())
142
+ .stringValue (messageAttribute .getStringValue ());
143
+
144
+ if (null != messageAttribute .getBinaryValue ()) {
145
+ builder .binaryValue (SdkBytes .fromByteBuffer (messageAttribute .getBinaryValue ()));
146
+ }
147
+
148
+ messageAttributesMap .put (s , builder .build ());
149
+ });
150
+
151
+ return SendMessageBatchRequestEntry .builder ()
152
+ .messageBody (sqsMessage .getBody ())
153
+ .id (sqsMessage .getMessageId ())
154
+ .messageAttributes (messageAttributesMap )
155
+ .build ();
156
+ })
157
+ .collect (toList ());
158
+
159
+ SendMessageBatchResponse sendMessageBatchResponse = client .sendMessageBatch (builder -> builder .queueUrl (dlqUrl .get ())
160
+ .entries (dlqMessages ));
161
+
162
+ LOG .debug ("Response from send batch message to DLQ request {}" , sendMessageBatchResponse );
163
+
164
+ return true ;
165
+ }
166
+
167
+ private Optional <String > fetchDlqUrl (Map <SQSMessage , Exception > nonRetryableMessageToException ) {
168
+ return nonRetryableMessageToException .keySet ().stream ()
169
+ .findFirst ()
170
+ .map (sqsMessage -> QUEUE_ARN_TO_DLQ_URL_MAPPING .computeIfAbsent (sqsMessage .getEventSourceArn (), sourceArn -> {
171
+ String queueUrl = url (sourceArn );
172
+
173
+ GetQueueAttributesResponse queueAttributes = client .getQueueAttributes (GetQueueAttributesRequest .builder ()
174
+ .attributeNames (QueueAttributeName .REDRIVE_POLICY )
175
+ .queueUrl (queueUrl )
176
+ .build ());
177
+
178
+ return ofNullable (queueAttributes .attributes ().get (QueueAttributeName .REDRIVE_POLICY ))
179
+ .map (policy -> {
180
+ try {
181
+ return SqsUtils .objectMapper ().readTree (policy );
182
+ } catch (JsonProcessingException e ) {
183
+ LOG .debug ("Unable to parse Re drive policy for queue {}. Even if DLQ exists, failed messages will be send back to main queue." , queueUrl , e );
184
+ return null ;
185
+ }
186
+ })
187
+ .map (node -> node .get ("deadLetterTargetArn" ))
188
+ .map (JsonNode ::asText )
189
+ .map (this ::url )
190
+ .orElse (null );
191
+ }));
192
+ }
193
+
59
194
private boolean hasFailures () {
60
- return !failures .isEmpty ();
195
+ return !messageToException .isEmpty ();
61
196
}
62
197
63
- private void deleteSuccessMessage ( ) {
64
- if (!success .isEmpty ()) {
198
+ private void deleteMessagesFromQueue ( final List < SQSMessage > messages ) {
199
+ if (!messages .isEmpty ()) {
65
200
DeleteMessageBatchRequest request = DeleteMessageBatchRequest .builder ()
66
- .queueUrl (url ())
67
- .entries (success .stream ().map (m -> DeleteMessageBatchRequestEntry .builder ()
201
+ .queueUrl (url (messages . get ( 0 ). getEventSourceArn () ))
202
+ .entries (messages .stream ().map (m -> DeleteMessageBatchRequestEntry .builder ()
68
203
.id (m .getMessageId ())
69
204
.receiptHandle (m .getReceiptHandle ())
70
205
.build ()).collect (toList ()))
71
206
.build ();
72
207
73
208
DeleteMessageBatchResponse deleteMessageBatchResponse = client .deleteMessageBatch (request );
74
- LOG .debug (format ( "Response from delete request %s " , deleteMessageBatchResponse ) );
209
+ LOG .debug ("Response from delete request {} " , deleteMessageBatchResponse );
75
210
}
76
211
}
77
212
78
- private String url () {
79
- String [] arnArray = success .get (0 ).getEventSourceArn ().split (":" );
80
- return client .getQueueUrl (GetQueueUrlRequest .builder ()
81
- .queueOwnerAWSAccountId (arnArray [4 ])
82
- .queueName (arnArray [5 ])
83
- .build ())
84
- .queueUrl ();
213
+ private String url (String queueArn ) {
214
+ String [] arnArray = queueArn .split (":" );
215
+ return String .format ("https://sqs.%s.amazonaws.com/%s/%s" , arnArray [3 ], arnArray [4 ], arnArray [5 ]);
85
216
}
86
217
}
0 commit comments