Skip to content

Commit e3f3943

Browse files
authored
CSHARP-5017: Retry KMS requests on transient errors (#1541)
1 parent c2de507 commit e3f3943

File tree

8 files changed

+254
-40
lines changed

8 files changed

+254
-40
lines changed

src/MongoDB.Driver.Encryption/CryptClientFactory.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ public static CryptClient Create(CryptOptions options)
163163

164164
Library.mongocrypt_setopt_use_need_kms_credentials_state(handle);
165165

166+
Library.mongocrypt_setopt_retry_kms(handle, true);
167+
166168
Library.mongocrypt_init(handle);
167169

168170
if (options.IsCryptSharedLibRequired)

src/MongoDB.Driver.Encryption/CryptContext.cs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,18 +156,12 @@ public Binary FinalizeForEncryption()
156156
}
157157

158158
/// <summary>
159-
/// Gets a collection of KMS message requests to make
159+
/// Gets the next KMS message request
160160
/// </summary>
161-
/// <returns>Collection of KMS Messages</returns>
162-
public KmsRequestCollection GetKmsMessageRequests()
161+
public KmsRequest GetNextKmsMessageRequest()
163162
{
164-
var requests = new List<KmsRequest>();
165-
for (IntPtr request = Library.mongocrypt_ctx_next_kms_ctx(_handle); request != IntPtr.Zero; request = Library.mongocrypt_ctx_next_kms_ctx(_handle))
166-
{
167-
requests.Add(new KmsRequest(request));
168-
}
169-
170-
return new KmsRequestCollection(requests, this);
163+
var request = Library.mongocrypt_ctx_next_kms_ctx(_handle);
164+
return request == IntPtr.Zero ? null : new KmsRequest(request);
171165
}
172166

173167
/// <summary>

src/MongoDB.Driver.Encryption/KmsRequest.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ public string KmsProvider
7777
}
7878
}
7979

80+
/// <summary>
81+
/// The number of milliseconds to wait before sending this request.
82+
/// </summary>
83+
public int Sleep => (int)(Library.mongocrypt_kms_ctx_usleep(_id) / 1000);
84+
8085
/// <summary>
8186
/// Gets the message to send to KMS.
8287
/// </summary>
@@ -88,6 +93,12 @@ public Binary GetMessage()
8893
return binary;
8994
}
9095

96+
/// <summary>
97+
/// Indicates a network-level failure.
98+
/// </summary>
99+
/// <returns>A boolean indicating whether the failed request may be retried.</returns>
100+
public bool Fail() => Library.mongocrypt_kms_ctx_fail(_id);
101+
91102
/// <summary>
92103
/// Feeds the response back to the libmongocrypt
93104
/// </summary>

src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)
211211

212212
private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken)
213213
{
214-
var requests = context.GetKmsMessageRequests();
215-
foreach (var request in requests)
214+
while (context.GetNextKmsMessageRequest() is { } request)
216215
{
217216
SendKmsRequest(request, cancellationToken);
218217
}
219-
requests.MarkDone();
218+
context.MarkKmsDone();
220219
}
221220

222221
private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken)
223222
{
224-
var requests = context.GetKmsMessageRequests();
225-
foreach (var request in requests)
223+
while (context.GetNextKmsMessageRequest() is { } request)
226224
{
227225
await SendKmsRequestAsync(request, cancellationToken).ConfigureAwait(false);
228226
}
229-
requests.MarkDone();
227+
context.MarkKmsDone();
230228
}
231229

232230
private void ProcessNeedMongoKeysState(CryptContext context, CancellationToken cancellationToken)
@@ -278,48 +276,90 @@ private static byte[] ProcessReadyState(CryptContext context)
278276

279277
private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
280278
{
281-
var endpoint = CreateKmsEndPoint(request.Endpoint);
282-
283-
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
284-
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
285-
using (var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation))
286-
using (var binary = request.GetMessage())
279+
try
287280
{
281+
var endpoint = CreateKmsEndPoint(request.Endpoint);
282+
283+
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
284+
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
285+
using var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation);
286+
287+
var sleepMs = request.Sleep;
288+
if (sleepMs > 0)
289+
{
290+
Thread.Sleep(sleepMs);
291+
}
292+
293+
using var binary = request.GetMessage();
288294
var requestBytes = binary.ToArray();
289295
sslStream.Write(requestBytes, 0, requestBytes.Length);
290296

291297
while (request.BytesNeeded > 0)
292298
{
293299
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
294300
var count = sslStream.Read(buffer, 0, buffer.Length);
301+
302+
if (count == 0)
303+
{
304+
throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
305+
}
306+
295307
var responseBytes = new byte[count];
296308
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
297309
request.Feed(responseBytes);
298310
}
299311
}
312+
catch (Exception ex) when (ex is IOException or SocketException)
313+
{
314+
if (!request.Fail())
315+
{
316+
throw;
317+
}
318+
}
300319
}
301320

302321
private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation)
303322
{
304-
var endpoint = CreateKmsEndPoint(request.Endpoint);
305-
306-
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
307-
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
308-
using (var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false))
309-
using (var binary = request.GetMessage())
323+
try
310324
{
325+
var endpoint = CreateKmsEndPoint(request.Endpoint);
326+
327+
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
328+
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
329+
using var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false);
330+
331+
var sleepMs = request.Sleep;
332+
if (sleepMs > 0)
333+
{
334+
await Task.Delay(sleepMs, cancellation).ConfigureAwait(false);
335+
}
336+
337+
using var binary = request.GetMessage();
311338
var requestBytes = binary.ToArray();
312339
await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false);
313340

314341
while (request.BytesNeeded > 0)
315342
{
316343
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
317344
var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
345+
346+
if (count == 0)
347+
{
348+
throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
349+
}
350+
318351
var responseBytes = new byte[count];
319352
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
320353
request.Feed(responseBytes);
321354
}
322355
}
356+
catch (Exception ex) when (ex is IOException or SocketException)
357+
{
358+
if (!request.Fail())
359+
{
360+
throw;
361+
}
362+
}
323363
}
324364

325365
// nested type

src/MongoDB.Driver.Encryption/Library.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ static Library()
147147
_mongocrypt_ctx_setopt_query_type = new Lazy<Delegates.mongocrypt_ctx_setopt_query_type>(
148148
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_setopt_query_type>(
149149
("mongocrypt_ctx_setopt_query_type")), true);
150+
_mongocrypt_setopt_retry_kms = new Lazy<Delegates.mongocrypt_setopt_retry_kms>(
151+
() => __loader.Value.GetFunction<Delegates.mongocrypt_setopt_retry_kms>(
152+
("mongocrypt_setopt_retry_kms")), true);
150153

151154
_mongocrypt_ctx_status = new Lazy<Delegates.mongocrypt_ctx_status>(
152155
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_status>(("mongocrypt_ctx_status")), true);
@@ -210,6 +213,11 @@ static Library()
210213
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_destroy>(("mongocrypt_ctx_destroy")), true);
211214
_mongocrypt_kms_ctx_get_kms_provider = new Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider>(
212215
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_get_kms_provider>(("mongocrypt_kms_ctx_get_kms_provider")), true);
216+
217+
_mongocrypt_kms_ctx_usleep = new Lazy<Delegates.mongocrypt_kms_ctx_usleep>(
218+
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_usleep>(("mongocrypt_kms_ctx_usleep")), true);
219+
_mongocrypt_kms_ctx_fail = new Lazy<Delegates.mongocrypt_kms_ctx_fail>(
220+
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_fail>(("mongocrypt_kms_ctx_fail")), true);
213221
}
214222

215223
/// <summary>
@@ -287,6 +295,7 @@ public static string Version
287295
internal static Delegates.mongocrypt_ctx_setopt_algorithm_range mongocrypt_ctx_setopt_algorithm_range => _mongocrypt_ctx_setopt_algorithm_range.Value;
288296
internal static Delegates.mongocrypt_ctx_setopt_contention_factor mongocrypt_ctx_setopt_contention_factor => _mongocrypt_ctx_setopt_contention_factor.Value;
289297
internal static Delegates.mongocrypt_ctx_setopt_query_type mongocrypt_ctx_setopt_query_type => _mongocrypt_ctx_setopt_query_type.Value;
298+
internal static Delegates.mongocrypt_setopt_retry_kms mongocrypt_setopt_retry_kms => _mongocrypt_setopt_retry_kms.Value;
290299

291300
internal static Delegates.mongocrypt_ctx_state mongocrypt_ctx_state => _mongocrypt_ctx_state.Value;
292301
internal static Delegates.mongocrypt_ctx_mongo_op mongocrypt_ctx_mongo_op => _mongocrypt_ctx_mongo_op.Value;
@@ -305,6 +314,9 @@ public static string Version
305314
internal static Delegates.mongocrypt_ctx_destroy mongocrypt_ctx_destroy => _mongocrypt_ctx_destroy.Value;
306315
internal static Delegates.mongocrypt_kms_ctx_get_kms_provider mongocrypt_kms_ctx_get_kms_provider => _mongocrypt_kms_ctx_get_kms_provider.Value;
307316

317+
internal static Delegates.mongocrypt_kms_ctx_usleep mongocrypt_kms_ctx_usleep => _mongocrypt_kms_ctx_usleep.Value;
318+
internal static Delegates.mongocrypt_kms_ctx_fail mongocrypt_kms_ctx_fail => _mongocrypt_kms_ctx_fail.Value;
319+
308320
private static readonly Lazy<LibraryLoader> __loader = new Lazy<LibraryLoader>(
309321
() => new LibraryLoader(), true);
310322
private static readonly Lazy<Delegates.mongocrypt_version> _mongocrypt_version;
@@ -392,6 +404,10 @@ public static string Version
392404
private static readonly Lazy<Delegates.mongocrypt_ctx_destroy> _mongocrypt_ctx_destroy;
393405
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider> _mongocrypt_kms_ctx_get_kms_provider;
394406

407+
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_usleep> _mongocrypt_kms_ctx_usleep;
408+
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_fail> _mongocrypt_kms_ctx_fail;
409+
private static readonly Lazy<Delegates.mongocrypt_setopt_retry_kms> _mongocrypt_setopt_retry_kms;
410+
395411
// nested types
396412
internal enum StatusType
397413
{
@@ -640,6 +656,9 @@ public delegate bool
640656
[return: MarshalAs(UnmanagedType.I1)]
641657
public delegate bool mongocrypt_ctx_setopt_query_type(ContextSafeHandle ctx, [MarshalAs(UnmanagedType.LPStr)] string query_type, int length);
642658

659+
[return: MarshalAs(UnmanagedType.I1)]
660+
public delegate bool mongocrypt_setopt_retry_kms(MongoCryptSafeHandle handle, bool enable);
661+
643662
public delegate CryptContext.StateCode mongocrypt_ctx_state(ContextSafeHandle handle);
644663

645664
[return: MarshalAs(UnmanagedType.I1)]
@@ -681,6 +700,11 @@ public delegate bool
681700

682701
public delegate void mongocrypt_ctx_destroy(IntPtr ptr);
683702
public delegate IntPtr mongocrypt_kms_ctx_get_kms_provider(IntPtr handle, out uint length);
703+
704+
public delegate long mongocrypt_kms_ctx_usleep(IntPtr handle);
705+
706+
[return: MarshalAs(UnmanagedType.I1)]
707+
public delegate bool mongocrypt_kms_ctx_fail(IntPtr handle);
684708
}
685709
}
686710
}

src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
<Target Name="DownloadNativeBinaries_MacOS" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/osx/native/libmongocrypt.dylib')">
1717
<PropertyGroup>
18-
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
18+
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
1919
<LibMongoCryptSourcePath>lib/libmongocrypt.dylib</LibMongoCryptSourcePath>
2020
<LibMongoCryptPackagePath>runtimes/osx/native</LibMongoCryptPackagePath>
2121
</PropertyGroup>
@@ -27,7 +27,7 @@
2727

2828
<Target Name="DownloadNativeBinaries_UbuntuX64" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/linux/native/x64/libmongocrypt.so')">
2929
<PropertyGroup>
30-
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
30+
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
3131
<LibMongoCryptSourcePath>nocrypto/lib/libmongocrypt.so</LibMongoCryptSourcePath>
3232
<LibMongoCryptPackagePath>runtimes/linux/native/x64</LibMongoCryptPackagePath>
3333
</PropertyGroup>
@@ -39,7 +39,7 @@
3939

4040
<Target Name="DownloadNativeBinaries_UbuntuARM64" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/linux/native/arm64/libmongocrypt.so')">
4141
<PropertyGroup>
42-
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
42+
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
4343
<LibMongoCryptSourcePath>nocrypto/lib/libmongocrypt.so</LibMongoCryptSourcePath>
4444
<LibMongoCryptPackagePath>runtimes/linux/native/arm64</LibMongoCryptPackagePath>
4545
</PropertyGroup>
@@ -51,7 +51,7 @@
5151

5252
<Target Name="DownloadNativeBinaries_Alpine" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/linux/native/alpine/libmongocrypt.so')">
5353
<PropertyGroup>
54-
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
54+
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
5555
<LibMongoCryptSourcePath>nocrypto/lib/libmongocrypt.so</LibMongoCryptSourcePath>
5656
<LibMongoCryptPackagePath>runtimes/linux/native/alpine</LibMongoCryptPackagePath>
5757
</PropertyGroup>
@@ -63,7 +63,7 @@
6363

6464
<Target Name="DownloadNativeBinaries_Windows" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/win/native/mongocrypt.dll')">
6565
<PropertyGroup>
66-
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
66+
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
6767
<LibMongoCryptSourcePath>bin/mongocrypt.dll</LibMongoCryptSourcePath>
6868
<LibMongoCryptPackagePath>runtimes/win/native</LibMongoCryptPackagePath>
6969
</PropertyGroup>

tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ public void TestGetKmsProviderName(string kmsName)
431431
using (var cryptClient = CryptClientFactory.Create(cryptOptions))
432432
using (var context = cryptClient.StartCreateDataKeyContext(keyId))
433433
{
434-
var request = context.GetKmsMessageRequests().Single();
434+
var request = context.GetNextKmsMessageRequest();
435435
request.KmsProvider.Should().Be(kmsName);
436436
}
437437
}
@@ -632,22 +632,21 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs
632632

633633
case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS:
634634
{
635-
var requests = context.GetKmsMessageRequests();
636-
foreach (var req in requests)
635+
while (context.GetNextKmsMessageRequest() is { } request)
637636
{
638-
using var binary = req.GetMessage();
637+
using var binary = request.GetMessage();
639638
_output.WriteLine("Key Document: " + binary);
640639
var postRequest = binary.ToString();
641640
// TODO: add different hosts handling
642641
postRequest.Should().Contain("Host:kms.us-east-1.amazonaws.com"); // only AWS
643642

644643
var reply = ReadHttpTestFile(isKmsDecrypt ? "kms-decrypt-reply.txt" : "kms-encrypt-reply.txt");
645644
_output.WriteLine("Reply: " + reply);
646-
req.Feed(Encoding.UTF8.GetBytes(reply));
647-
req.BytesNeeded.Should().Be(0);
645+
request.Feed(Encoding.UTF8.GetBytes(reply));
646+
request.BytesNeeded.Should().Be(0);
648647
}
649648

650-
requests.MarkDone();
649+
context.MarkKmsDone();
651650
return (CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS, null, null);
652651
}
653652

0 commit comments

Comments
 (0)