diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index c2b1f6fa91..3e12b61a61 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -36,15 +36,6 @@ public sealed partial class SqlCommand : DbCommand, ICloneable private const int MaxRPCNameLength = 1046; internal readonly int ObjectID = Interlocked.Increment(ref _objectTypeCount); private string _commandText; - private static readonly Func s_beginExecuteReaderAsync = BeginExecuteReaderAsyncCallback; - private static readonly Func s_endExecuteReaderAsync = EndExecuteReaderAsyncCallback; - private static readonly Action> s_cleanupExecuteReaderAsync = CleanupExecuteReaderAsyncCallback; - private static readonly Func s_internalEndExecuteNonQuery = InternalEndExecuteNonQueryCallback; - private static readonly Func s_internalEndExecuteReader = InternalEndExecuteReaderCallback; - private static readonly Func s_beginExecuteReaderInternal = BeginExecuteReaderInternalCallback; - private static readonly Func s_beginExecuteXmlReaderInternal = BeginExecuteXmlReaderInternalCallback; - private static readonly Func s_beginExecuteNonQueryInternal = BeginExecuteNonQueryInternalCallback; - internal sealed class ExecuteReaderAsyncCallContext : AAsyncCallContext { public Guid OperationID; @@ -93,7 +84,31 @@ protected override void Clear() protected override void AfterCleared(SqlCommand owner) { - + owner?.SetCachedCommandExecuteNonQueryAsyncContext(this); + } + } + + internal sealed class ExecuteXmlReaderAsyncCallContext : AAsyncCallContext + { + public Guid OperationID; + + public SqlCommand Command => _owner; + public TaskCompletionSource TaskCompletionSource => _source; + + public void Set(SqlCommand command, TaskCompletionSource source, CancellationTokenRegistration disposable, Guid operationID) + { + base.Set(command, source, disposable); + OperationID = operationID; + } + + protected override void Clear() + { + OperationID = default; + } + + protected override void AfterCleared(SqlCommand owner) + { + owner?.SetCachedCommandExecuteXmlReaderContext(this); } } @@ -1307,7 +1322,27 @@ private IAsyncResult BeginExecuteNonQueryInternal(CommandBehavior behavior, Asyn // When we use query caching for parameter encryption we need to retry on specific errors. // In these cases finalize the call internally and trigger a retry when needed. - if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, s_internalEndExecuteNonQuery, s_beginExecuteNonQueryInternal, nameof(EndExecuteNonQuery))) + if ( + !TriggerInternalEndAndRetryIfNecessary( + behavior, + stateObject, + timeout, + usedCache, + inRetry, + asyncWrite, + globalCompletion, + localCompletion, + endFunc: static (SqlCommand command, IAsyncResult asyncResult, bool isInternal, string endMethod) => + { + return command.InternalEndExecuteNonQuery(asyncResult, isInternal, endMethod); + }, + retryFunc: static (SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) => + { + return command.BeginExecuteNonQueryInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); + }, + nameof(EndExecuteNonQuery) + ) + ) { globalCompletion = localCompletion; } @@ -1316,7 +1351,7 @@ private IAsyncResult BeginExecuteNonQueryInternal(CommandBehavior behavior, Asyn if (callback != null) { globalCompletion.Task.ContinueWith( - static (task, state) => ((AsyncCallback)state)(task), + static (Task task, object state) => ((AsyncCallback)state)(task), state: callback ); } @@ -1817,7 +1852,27 @@ private IAsyncResult BeginExecuteXmlReaderInternal(CommandBehavior behavior, Asy // When we use query caching for parameter encryption we need to retry on specific errors. // In these cases finalize the call internally and trigger a retry when needed. - if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, s_internalEndExecuteReader, s_beginExecuteXmlReaderInternal, endMethod: nameof(EndExecuteXmlReader))) + if ( + !TriggerInternalEndAndRetryIfNecessary( + behavior, + stateObject, + timeout, + usedCache, + inRetry, + asyncWrite, + globalCompletion, + localCompletion, + endFunc: static (SqlCommand command, IAsyncResult asyncResult, bool isInternal, string endMethod) => + { + return command.InternalEndExecuteReader(asyncResult, isInternal, endMethod); + }, + retryFunc: static (SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) => + { + return command.BeginExecuteXmlReaderInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); + }, + endMethod: nameof(EndExecuteXmlReader) + ) + ) { globalCompletion = localCompletion; } @@ -1826,7 +1881,7 @@ private IAsyncResult BeginExecuteXmlReaderInternal(CommandBehavior behavior, Asy if (callback != null) { localCompletion.Task.ContinueWith( - static (task, state) => ((AsyncCallback)state)(task), + static (Task task, object state) => ((AsyncCallback)state)(task), state: callback ); } @@ -2179,29 +2234,6 @@ private void CleanupExecuteReaderAsync(Task task, TaskCompletionS } } - private static IAsyncResult BeginExecuteReaderAsyncCallback(AsyncCallback callback, object stateObject) - { - ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)stateObject; - return args.Command.BeginExecuteReaderInternal(args.CommandBehavior, callback, stateObject, args.Command.CommandTimeout, inRetry: false, asyncWrite: true); - } - - private static SqlDataReader EndExecuteReaderAsyncCallback(IAsyncResult asyncResult) - { - ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)asyncResult.AsyncState; - return args.Command.EndExecuteReaderAsync(asyncResult); - } - - private static void CleanupExecuteReaderAsyncCallback(Task task) - { - ExecuteReaderAsyncCallContext context = (ExecuteReaderAsyncCallContext)task.AsyncState; - SqlCommand command = context.Command; - Guid operationId = context.OperationID; - TaskCompletionSource source = context.TaskCompletionSource; - context.Dispose(); - - command.CleanupExecuteReaderAsync(task, source, operationId); - } - private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite = false) { TaskCompletionSource globalCompletion = new TaskCompletionSource(stateObject); @@ -2266,7 +2298,27 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC // When we use query caching for parameter encryption we need to retry on specific errors. // In these cases finalize the call internally and trigger a retry when needed. - if (!TriggerInternalEndAndRetryIfNecessary(behavior, stateObject, timeout, usedCache, inRetry, asyncWrite, globalCompletion, localCompletion, s_internalEndExecuteReader, s_beginExecuteReaderInternal, nameof(EndExecuteReader))) + if ( + !TriggerInternalEndAndRetryIfNecessary( + behavior, + stateObject, + timeout, + usedCache, + inRetry, + asyncWrite, + globalCompletion, + localCompletion, + endFunc: static (SqlCommand command, IAsyncResult asyncResult, bool isInternal, string endMethod) => + { + return command.InternalEndExecuteReader(asyncResult, isInternal, endMethod); + }, + retryFunc: static (SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) => + { + return command.BeginExecuteReaderInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); + }, + nameof(EndExecuteReader) + ) + ) { globalCompletion = localCompletion; } @@ -2275,7 +2327,7 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC if (callback != null) { globalCompletion.Task.ContinueWith( - static (task, state) => ((AsyncCallback)state)(task), + static (Task task, object state) => ((AsyncCallback)state)(task), state: callback ); } @@ -2288,42 +2340,6 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC } } - /// - /// used to convert an invocation through a cached static delegate back to an instance call - /// - private static SqlDataReader InternalEndExecuteReaderCallback(SqlCommand command, IAsyncResult asyncResult, bool isInternal, string endMethod) - { - return command.InternalEndExecuteReader(asyncResult, isInternal, endMethod); - } - /// - /// used to convert an invocation through a cached static delegate back to an instance call - /// - private static IAsyncResult BeginExecuteReaderInternalCallback(SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) - { - return command.BeginExecuteReaderInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); - } - /// - /// used to convert an invocation through a cached static delegate back to an instance call - /// - private static IAsyncResult BeginExecuteXmlReaderInternalCallback(SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) - { - return command.BeginExecuteXmlReaderInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); - } - /// - /// used to convert an invocation through a cached static delegate back to an instance call - /// - private static object InternalEndExecuteNonQueryCallback(SqlCommand command, IAsyncResult asyncResult, bool isInternal, string endMethod) - { - return command.InternalEndExecuteNonQuery(asyncResult, isInternal, endMethod); - } - /// - /// used to convert an invocation through a cached static delegate back to an instance call - /// - private static IAsyncResult BeginExecuteNonQueryInternalCallback(SqlCommand command, CommandBehavior behavior, AsyncCallback callback, object stateObject, int timeout, bool inRetry, bool asyncWrite) - { - return command.BeginExecuteNonQueryInternal(behavior, callback, stateObject, timeout, inRetry, asyncWrite); - } - private bool TriggerInternalEndAndRetryIfNecessary( CommandBehavior behavior, object stateObject, @@ -2597,41 +2613,27 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok try { Task.Factory.FromAsync( - static (AsyncCallback callback, object stateObject) => ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject), - static (IAsyncResult result) => ((ExecuteNonQueryAsyncCallContext)result.AsyncState).Command.EndExecuteNonQueryAsync(result), + beginMethod: static (AsyncCallback callback, object stateObject) => // with c# 10/NET6 add [StackTraceHidden] to this + { + return ((ExecuteNonQueryAsyncCallContext)stateObject).Command.BeginExecuteNonQueryAsync(callback, stateObject); + }, + endMethod: static (IAsyncResult asyncResult) => // with c# 10/NET6 add [StackTraceHidden] to this + { + return ((ExecuteNonQueryAsyncCallContext)asyncResult.AsyncState).Command.EndExecuteNonQueryAsync(asyncResult); + }, state: context - ).ContinueWith( - static (Task task, object state) => + ) + .ContinueWith( + static (Task task) => // with c# 10/NET6 add [StackTraceHidden] to this { - ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)state; - - Guid operationId = context.OperationID; + ExecuteNonQueryAsyncCallContext context = (ExecuteNonQueryAsyncCallContext)task.AsyncState; SqlCommand command = context.Command; + Guid operationId = context.OperationID; TaskCompletionSource source = context.TaskCompletionSource; - context.Dispose(); - context = null; - if (task.IsFaulted) - { - Exception e = task.Exception.InnerException; - s_diagnosticListener.WriteCommandError(operationId, command, command._transaction, e); - source.SetException(e); - } - else - { - if (task.IsCanceled) - { - source.SetCanceled(); - } - else - { - source.SetResult(task.Result); - } - s_diagnosticListener.WriteCommandAfter(operationId, command, command._transaction); - } + command.CleanupAfterExecuteNonQueryAsync(task, source, operationId); }, - state: context, scheduler: TaskScheduler.Default ); } @@ -2645,6 +2647,28 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok return returnedTask; } + private void CleanupAfterExecuteNonQueryAsync(Task task, TaskCompletionSource source, Guid operationId) + { + if (task.IsFaulted) + { + Exception e = task.Exception.InnerException; + s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); + source.SetException(e); + } + else + { + if (task.IsCanceled) + { + source.SetCanceled(); + } + else + { + source.SetResult(task.Result); + } + s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); + } + } + /// protected override Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) { @@ -2733,12 +2757,29 @@ private Task InternalExecuteReaderAsync(CommandBehavior behavior, context.Set(this, source, registration, behavior, operationId); Task.Factory.FromAsync( - beginMethod: s_beginExecuteReaderAsync, - endMethod: s_endExecuteReaderAsync, + beginMethod: static (AsyncCallback callback, object stateObject) => + { + ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)stateObject; + return args.Command.BeginExecuteReaderInternal(args.CommandBehavior, callback, stateObject, args.Command.CommandTimeout, inRetry: false, asyncWrite: true); + }, + endMethod: static (IAsyncResult asyncResult) => + { + ExecuteReaderAsyncCallContext args = (ExecuteReaderAsyncCallContext)asyncResult.AsyncState; + return args.Command.EndExecuteReaderAsync(asyncResult); + }, state: context ).ContinueWith( - continuationAction: s_cleanupExecuteReaderAsync, - TaskScheduler.Default + continuationAction: static (Task task) => + { + ExecuteReaderAsyncCallContext context = (ExecuteReaderAsyncCallContext)task.AsyncState; + SqlCommand command = context.Command; + Guid operationId = context.OperationID; + TaskCompletionSource source = context.TaskCompletionSource; + context.Dispose(); + + command.CleanupExecuteReaderAsync(task, source, operationId); + }, + scheduler: TaskScheduler.Default ); } catch (Exception e) @@ -2763,6 +2804,22 @@ private void SetCachedCommandExecuteReaderAsyncContext(ExecuteReaderAsyncCallCon } } + private void SetCachedCommandExecuteNonQueryAsyncContext(ExecuteNonQueryAsyncCallContext instance) + { + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + Interlocked.CompareExchange(ref sqlInternalConnection.CachedCommandExecuteNonQueryAsyncContext, instance, null); + } + } + + private void SetCachedCommandExecuteXmlReaderContext(ExecuteXmlReaderAsyncCallContext instance) + { + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + Interlocked.CompareExchange(ref sqlInternalConnection.CachedCommandExecuteXmlReaderAsyncContext, instance, null); + } + } + /// public override Task ExecuteScalarAsync(CancellationToken cancellationToken) => // Do not use retry logic here as internal call to ExecuteReaderAsync handles retry logic. @@ -2891,35 +2948,44 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella registration = cancellationToken.Register(s_cancelIgnoreFailure, this); } + ExecuteXmlReaderAsyncCallContext context = null; + if (_activeConnection?.InnerConnection is SqlInternalConnection sqlInternalConnection) + { + context = Interlocked.Exchange(ref sqlInternalConnection.CachedCommandExecuteXmlReaderAsyncContext, null); + } + if (context is null) + { + context = new ExecuteXmlReaderAsyncCallContext(); + } + context.Set(this, source, registration, operationId); + + Task returnedTask = source.Task; try { returnedTask = RegisterForConnectionCloseNotification(returnedTask); - Task.Factory.FromAsync(BeginExecuteXmlReaderAsync, EndExecuteXmlReaderAsync, null) - .ContinueWith((Task task) => + Task.Factory.FromAsync( + beginMethod: static (AsyncCallback callback, object stateObject) => // with c# 10/NET6 add [StackTraceHidden] to this { - registration.Dispose(); - if (task.IsFaulted) - { - Exception e = task.Exception.InnerException; - s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); - source.SetException(e); - } - else - { - s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); - if (task.IsCanceled) - { - source.SetCanceled(); - } - else - { - source.SetResult(task.Result); - } - - } - }, + return ((ExecuteXmlReaderAsyncCallContext)stateObject).Command.BeginExecuteXmlReaderAsync(callback, stateObject); + }, + endMethod: static (IAsyncResult asyncResult) => // with c# 10/NET6 add [StackTraceHidden] to this + { + return ((ExecuteXmlReaderAsyncCallContext)asyncResult.AsyncState).Command.EndExecuteXmlReaderAsync(asyncResult); + }, + state: context + ).ContinueWith( + static (Task task) => + { + ExecuteXmlReaderAsyncCallContext context = (ExecuteXmlReaderAsyncCallContext)task.AsyncState; + SqlCommand command = context.Command; + Guid operationId = context.OperationID; + TaskCompletionSource source = context.TaskCompletionSource; + context.Dispose(); + + command.CleanupAfterExecuteXmlReaderAsync(task, source, operationId); + }, TaskScheduler.Default ); } @@ -2932,6 +2998,28 @@ private Task InternalExecuteXmlReaderAsync(CancellationToken cancella return returnedTask; } + private void CleanupAfterExecuteXmlReaderAsync(Task task, TaskCompletionSource source, Guid operationId) + { + if (task.IsFaulted) + { + Exception e = task.Exception.InnerException; + s_diagnosticListener.WriteCommandError(operationId, this, _transaction, e); + source.SetException(e); + } + else + { + if (task.IsCanceled) + { + source.SetCanceled(); + } + else + { + source.SetResult(task.Result); + } + s_diagnosticListener.WriteCommandAfter(operationId, this, _transaction); + } + } + /// public void RegisterColumnEncryptionKeyStoreProvidersOnCommand(IDictionary customProviders) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs index a6da618583..e6faafc405 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnection.cs @@ -29,6 +29,9 @@ internal abstract class SqlInternalConnection : DbConnectionInternal #if NETCOREAPP || NETSTANDARD internal SqlCommand.ExecuteReaderAsyncCallContext CachedCommandExecuteReaderAsyncContext; + internal SqlCommand.ExecuteNonQueryAsyncCallContext CachedCommandExecuteNonQueryAsyncContext; + internal SqlCommand.ExecuteXmlReaderAsyncCallContext CachedCommandExecuteXmlReaderAsyncContext; + internal SqlDataReader.Snapshot CachedDataReaderSnapshot; internal SqlDataReader.IsDBNullAsyncCallContext CachedDataReaderIsDBNullContext; internal SqlDataReader.ReadAsyncCallContext CachedDataReaderReadAsyncContext;