diff --git a/src/Features/CSharpTest/SemanticSearch/CSharpSemanticSearchServiceTests.cs b/src/Features/CSharpTest/SemanticSearch/CSharpSemanticSearchServiceTests.cs index e3ef701703d41..9747ebb970aa4 100644 --- a/src/Features/CSharpTest/SemanticSearch/CSharpSemanticSearchServiceTests.cs +++ b/src/Features/CSharpTest/SemanticSearch/CSharpSemanticSearchServiceTests.cs @@ -7,7 +7,6 @@ using System.Diagnostics; using System.IO; using System.Linq; -using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis.Classification; @@ -64,14 +63,33 @@ public void VisibleMethod(int param) { } """; - [ConditionalFact(typeof(CoreClrOnly))] - public async Task CompilationQuery() + private static async Task VerifyCompileAndExecuteQueryAsync( + TestWorkspace workspace, + string query, + string[] expectedItems) { - using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); + var items = new List(); + var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = items.Add }; var solution = workspace.CurrentSolution; - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); + var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); + var traceSource = new TraceSource("test"); + + var compileResult = service.CompileQuery(solution.Services, query, s_referenceAssembliesDir, traceSource, CancellationToken.None); + Assert.Equal(LanguageNames.CSharp, compileResult.QueryId.Language); + Assert.Empty(compileResult.CompilationErrors); + + var executeResult = await service.ExecuteQueryAsync(solution, compileResult.QueryId, observer, options, traceSource, CancellationToken.None); + Assert.Null(executeResult.ErrorMessage); + + AssertEx.Equal(expectedItems, items.Select(Inspect).OrderBy(s => s)); + } + + [ConditionalFact(typeof(CoreClrOnly))] + public async Task CompilationQuery() + { + using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); var query = """ static IEnumerable Find(Compilation compilation) @@ -80,15 +98,7 @@ static IEnumerable Find(Compilation compilation) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal(["namespace N"], results.Select(Inspect)); + await VerifyCompileAndExecuteQueryAsync(workspace, query, ["namespace N"]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -96,10 +106,6 @@ public async Task NamespaceQuery() { using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static IEnumerable Find(INamespaceSymbol n) { @@ -107,15 +113,7 @@ static IEnumerable Find(INamespaceSymbol n) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal(["class C"], results.Select(Inspect)); + await VerifyCompileAndExecuteQueryAsync(workspace, query, ["class C"]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -123,10 +121,6 @@ public async Task NamedTypeQuery() { using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static IEnumerable Find(INamedTypeSymbol type) { @@ -134,15 +128,7 @@ static IEnumerable Find(INamedTypeSymbol type) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal(["int C.F"], results.Select(Inspect)); + await VerifyCompileAndExecuteQueryAsync(workspace, query, ["int C.F"]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -150,10 +136,6 @@ public async Task MethodQuery() { using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static IEnumerable Find(IMethodSymbol method) { @@ -161,22 +143,14 @@ static IEnumerable Find(IMethodSymbol method) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal( + await VerifyCompileAndExecuteQueryAsync(workspace, query, [ "C.C()", "int C.P.get", "void C.E.add", "void C.E.remove", "void C.VisibleMethod(int)", - ], results.Select(Inspect).OrderBy(s => s)); + ]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -184,10 +158,6 @@ public async Task FieldQuery() { using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static IEnumerable Find(IFieldSymbol field) { @@ -195,19 +165,11 @@ static IEnumerable Find(IFieldSymbol field) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal( + await VerifyCompileAndExecuteQueryAsync(workspace, query, [ "int C.F", "readonly int C.P.field", - ], results.Select(Inspect).OrderBy(s => s)); + ]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -215,10 +177,6 @@ public async Task PropertyQuery() { using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static IEnumerable Find(IPropertySymbol prop) { @@ -226,15 +184,10 @@ static IEnumerable Find(IPropertySymbol prop) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal(["int C.P { get; }"], results.Select(Inspect)); + await VerifyCompileAndExecuteQueryAsync(workspace, query, + [ + "int C.P { get; }" + ]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -242,10 +195,6 @@ public async Task EventQuery() { using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static IEnumerable Find(IEventSymbol e) { @@ -253,15 +202,10 @@ static IEnumerable Find(IEventSymbol e) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal(["event Action C.E"], results.Select(Inspect)); + await VerifyCompileAndExecuteQueryAsync(workspace, query, + [ + "event Action C.E" + ]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -288,10 +232,6 @@ class D """, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static async IAsyncEnumerable Find(IMethodSymbol e) { @@ -308,19 +248,11 @@ static async IAsyncEnumerable Find(IMethodSymbol e) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - AssertEx.Equal( + await VerifyCompileAndExecuteQueryAsync(workspace, query, [ "void D.R1()", "void D.R2()" - ], results.Select(Inspect)); + ]); } [ConditionalFact(typeof(CoreClrOnly))] @@ -328,10 +260,6 @@ public async Task NullReturn() { using var workspace = TestWorkspace.Create(DefaultWorkspaceXml, composition: FeaturesTestCompositions.Features); - var solution = workspace.CurrentSolution; - - var service = solution.Services.GetRequiredLanguageService(LanguageNames.CSharp); - var query = """ static IEnumerable Find(Compilation compilation) { @@ -339,15 +267,7 @@ static IEnumerable Find(Compilation compilation) } """; - var results = new List(); - var observer = new MockSemanticSearchResultsObserver() { OnDefinitionFoundImpl = results.Add }; - var traceSource = new TraceSource("test"); - - var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); - - Assert.Null(result.ErrorMessage); - Assert.Empty(results); + await VerifyCompileAndExecuteQueryAsync(workspace, query, expectedItems: []); } [ConditionalFact(typeof(CoreClrOnly))] @@ -384,8 +304,11 @@ static IEnumerable Find(Compilation compilation) var traceSource = new TraceSource("test"); var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); + var compileResult = service.CompileQuery(solution.Services, query, s_referenceAssembliesDir, traceSource, CancellationToken.None); + Assert.Empty(compileResult.CompilationErrors); + await Assert.ThrowsAsync( - () => service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, cancellationSource.Token)); + () => service.ExecuteQueryAsync(solution, compileResult.QueryId, observer, options, traceSource, cancellationSource.Token)); Assert.Empty(exceptions); } @@ -430,7 +353,10 @@ void F(long x) var traceSource = new TraceSource("test"); var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); + var compileResult = service.CompileQuery(solution.Services, query, s_referenceAssembliesDir, traceSource, CancellationToken.None); + Assert.Empty(compileResult.CompilationErrors); + + var result = await service.ExecuteQueryAsync(solution, compileResult.QueryId, observer, options, traceSource, CancellationToken.None); var expectedMessage = new InsufficientExecutionStackException().Message; AssertEx.Equal(string.Format(FeaturesResources.Semantic_search_query_terminated_with_exception, "CSharpAssembly1", expectedMessage), result.ErrorMessage); @@ -496,7 +422,10 @@ static ISymbol F(ISymbol s) var traceSource = new TraceSource("test"); var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); + var compileResult = service.CompileQuery(solution.Services, query, s_referenceAssembliesDir, traceSource, CancellationToken.None); + Assert.Empty(compileResult.CompilationErrors); + + var result = await service.ExecuteQueryAsync(solution, compileResult.QueryId, observer, options, traceSource, CancellationToken.None); var expectedMessage = new NullReferenceException().Message; AssertEx.Equal(string.Format(FeaturesResources.Semantic_search_query_terminated_with_exception, "CSharpAssembly1", expectedMessage), result.ErrorMessage); @@ -553,7 +482,11 @@ static IEnumerable Find(IMethodSymbol method) var traceSource = new TraceSource("test"); var options = workspace.GlobalOptions.GetClassificationOptionsProvider(); - var result = await service.ExecuteQueryAsync(solution, query, s_referenceAssembliesDir, observer, options, traceSource, CancellationToken.None); + + var compileResult = service.CompileQuery(solution.Services, query, s_referenceAssembliesDir, traceSource, CancellationToken.None); + Assert.Empty(compileResult.CompilationErrors); + + var result = await service.ExecuteQueryAsync(solution, compileResult.QueryId, observer, options, traceSource, CancellationToken.None); Assert.Null(result.ErrorMessage); AssertEx.Equal(["void C.VisibleMethod(int)"], results.Select(Inspect)); diff --git a/src/Features/Core/Portable/SemanticSearch/AbstractSemanticSearchService.cs b/src/Features/Core/Portable/SemanticSearch/AbstractSemanticSearchService.cs index 3a82b10666c14..f415365cf24f8 100644 --- a/src/Features/Core/Portable/SemanticSearch/AbstractSemanticSearchService.cs +++ b/src/Features/Core/Portable/SemanticSearch/AbstractSemanticSearchService.cs @@ -54,6 +54,19 @@ protected override IntPtr LoadUnmanagedDll(string unmanagedDllName) => IntPtr.Zero; } + private readonly struct CompiledQuery(MemoryStream peStream, MemoryStream pdbStream, SourceText text) : IDisposable + { + public MemoryStream PEStream { get; } = peStream; + public MemoryStream PdbStream { get; } = pdbStream; + public SourceText Text { get; } = text; + + public void Dispose() + { + PEStream.Dispose(); + PdbStream.Dispose(); + } + } + /// /// Mapping from the parameter type of the Find method to the value. /// @@ -66,76 +79,93 @@ protected override IntPtr LoadUnmanagedDll(string unmanagedDllName) .Add(typeof(IPropertySymbol), QueryKind.Property) .Add(typeof(IEventSymbol), QueryKind.Event); + private ImmutableDictionary _compiledQueries = ImmutableDictionary.Empty; + protected abstract Compilation CreateCompilation(SourceText query, IEnumerable references, SolutionServices services, out SyntaxTree queryTree, CancellationToken cancellationToken); - public async Task ExecuteQueryAsync( - Solution solution, + public CompileQueryResult CompileQuery( + SolutionServices services, string query, string referenceAssembliesDir, - ISemanticSearchResultsObserver observer, - OptionsProvider classificationOptions, TraceSource traceSource, CancellationToken cancellationToken) { - try - { - // add progress items - one for compilation, one for emit and one for each project: - var remainingProgressItemCount = 2 + solution.ProjectIds.Count; - await observer.AddItemsAsync(remainingProgressItemCount, cancellationToken).ConfigureAwait(false); - - var metadataService = solution.Services.GetRequiredService(); - var metadataReferences = SemanticSearchUtilities.GetMetadataReferences(metadataService, referenceAssembliesDir); - var queryText = SemanticSearchUtilities.CreateSourceText(query); - var queryCompilation = CreateCompilation(queryText, metadataReferences, solution.Services, out var queryTree, cancellationToken); + var metadataService = services.GetRequiredService(); + var metadataReferences = SemanticSearchUtilities.GetMetadataReferences(metadataService, referenceAssembliesDir); + var queryText = SemanticSearchUtilities.CreateSourceText(query); + var queryCompilation = CreateCompilation(queryText, metadataReferences, services, out var queryTree, cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); + cancellationToken.ThrowIfCancellationRequested(); - // complete compilation progress item: - remainingProgressItemCount--; - await observer.ItemsCompletedAsync(1, cancellationToken).ConfigureAwait(false); + var emitOptions = new EmitOptions( + debugInformationFormat: DebugInformationFormat.PortablePdb, + instrumentationKinds: [InstrumentationKind.StackOverflowProbing, InstrumentationKind.ModuleCancellation]); - var emitOptions = new EmitOptions( - debugInformationFormat: DebugInformationFormat.PortablePdb, - instrumentationKinds: [InstrumentationKind.StackOverflowProbing, InstrumentationKind.ModuleCancellation]); + var peStream = new MemoryStream(); + var pdbStream = new MemoryStream(); - using var peStream = new MemoryStream(); - using var pdbStream = new MemoryStream(); - - var emitDifferenceTimer = SharedStopwatch.StartNew(); - var emitResult = queryCompilation.Emit(peStream, pdbStream, options: emitOptions, cancellationToken: cancellationToken); - var emitTime = emitDifferenceTimer.Elapsed; - - var executionTime = TimeSpan.Zero; + var emitDifferenceTimer = SharedStopwatch.StartNew(); + var emitResult = queryCompilation.Emit(peStream, pdbStream, options: emitOptions, cancellationToken: cancellationToken); + var emitTime = emitDifferenceTimer.Elapsed; - cancellationToken.ThrowIfCancellationRequested(); + CompiledQueryId queryId; + ImmutableArray errors; + if (emitResult.Success) + { + queryId = CompiledQueryId.Create(queryCompilation.Language); + Contract.ThrowIfFalse(ImmutableInterlocked.TryAdd(ref _compiledQueries, queryId, new CompiledQuery(peStream, pdbStream, queryText))); - // complete compilation progress item: - remainingProgressItemCount--; - await observer.ItemsCompletedAsync(1, cancellationToken).ConfigureAwait(false); + errors = []; + } + else + { + queryId = default; - if (!emitResult.Success) + foreach (var diagnostic in emitResult.Diagnostics) { - foreach (var diagnostic in emitResult.Diagnostics) + if (diagnostic.Severity == DiagnosticSeverity.Error) { - if (diagnostic.Severity == DiagnosticSeverity.Error) - { - traceSource.TraceInformation($"Semantic search query compilation failed: {diagnostic}"); - } + traceSource.TraceInformation($"Semantic search query compilation failed: {diagnostic}"); } + } - var errors = emitResult.Diagnostics.SelectAsArray( - d => d.Severity == DiagnosticSeverity.Error, - d => new QueryCompilationError(d.Id, d.GetMessage(), (d.Location.SourceTree == queryTree) ? d.Location.SourceSpan : default)); + errors = emitResult.Diagnostics.SelectAsArray( + d => d.Severity == DiagnosticSeverity.Error, + d => new QueryCompilationError(d.Id, d.GetMessage(), (d.Location.SourceTree == queryTree) ? d.Location.SourceSpan : default)); + } - return CreateResult(errors, FeaturesResources.Semantic_search_query_failed_to_compile); - } + return new CompileQueryResult(queryId, errors, emitTime); + } + + public void DiscardQuery(CompiledQueryId queryId) + { + Contract.ThrowIfFalse(ImmutableInterlocked.TryRemove(ref _compiledQueries, queryId, out var compiledQuery)); + compiledQuery.Dispose(); + } - peStream.Position = 0; - pdbStream.Position = 0; + public async Task ExecuteQueryAsync( + Solution solution, + CompiledQueryId queryId, + ISemanticSearchResultsObserver observer, + OptionsProvider classificationOptions, + TraceSource traceSource, + CancellationToken cancellationToken) + { + Contract.ThrowIfFalse(ImmutableInterlocked.TryRemove(ref _compiledQueries, queryId, out var query)); + + try + { + var executionTime = TimeSpan.Zero; + + var remainingProgressItemCount = solution.ProjectIds.Count; + await observer.AddItemsAsync(remainingProgressItemCount, cancellationToken).ConfigureAwait(false); + + query.PEStream.Position = 0; + query.PdbStream.Position = 0; var loadContext = new LoadContext(); try { - var queryAssembly = loadContext.LoadFromStream(peStream, pdbStream); + var queryAssembly = loadContext.LoadFromStream(query.PEStream, query.PdbStream); SetModuleCancellationToken(queryAssembly, cancellationToken); SetToolImplementations( @@ -146,17 +176,17 @@ public async Task ExecuteQueryAsync( if (!TryGetFindMethod(queryAssembly, out var findMethod, out var queryKind, out var errorMessage, out var errorMessageArgs)) { traceSource.TraceInformation($"Semantic search failed: {errorMessage}"); - return CreateResult(compilationErrors: [], errorMessage, errorMessageArgs); + return CreateResult(errorMessage, errorMessageArgs); } - var invocationContext = new QueryExecutionContext(queryText, findMethod, observer, classificationOptions, traceSource); + var invocationContext = new QueryExecutionContext(query.Text, findMethod, observer, classificationOptions, traceSource); try { await invocationContext.InvokeAsync(solution, queryKind, cancellationToken).ConfigureAwait(false); if (invocationContext.TerminatedWithException) { - return CreateResult(compilationErrors: [], FeaturesResources.Semantic_search_query_terminated_with_exception); + return CreateResult(FeaturesResources.Semantic_search_query_terminated_with_exception); } } finally @@ -176,15 +206,19 @@ public async Task ExecuteQueryAsync( } } - return CreateResult(compilationErrors: [], errorMessage: null); + return CreateResult(errorMessage: null); - ExecuteQueryResult CreateResult(ImmutableArray compilationErrors, string? errorMessage, params string[]? args) - => new(compilationErrors, errorMessage, args, emitTime, executionTime); + ExecuteQueryResult CreateResult(string? errorMessage, params string[]? args) + => new(errorMessage, args, executionTime); } catch (Exception e) when (FatalError.ReportAndPropagateUnlessCanceled(e, cancellationToken, ErrorSeverity.Critical)) { throw ExceptionUtilities.Unreachable(); } + finally + { + query.Dispose(); + } } private static void SetModuleCancellationToken(Assembly queryAssembly, CancellationToken cancellationToken) diff --git a/src/Features/Core/Portable/SemanticSearch/ExecuteQueryResult.cs b/src/Features/Core/Portable/SemanticSearch/ExecuteQueryResult.cs index 7c8a819e2268c..b4defad2404e1 100644 --- a/src/Features/Core/Portable/SemanticSearch/ExecuteQueryResult.cs +++ b/src/Features/Core/Portable/SemanticSearch/ExecuteQueryResult.cs @@ -5,25 +5,59 @@ using System; using System.Collections.Immutable; using System.Runtime.Serialization; +using System.Threading; namespace Microsoft.CodeAnalysis.SemanticSearch; /// /// The result of Semantic Search query execution. /// -/// Compilation errors. /// An error message if the execution failed. /// /// Arguments to be substituted to . /// Use when the values may contain PII that needs to be obscured in telemetry. /// Otherwise, should contain the formatted message. /// -/// Time it took to emit the query compilation. /// Time it took to execute the query. [DataContract] internal readonly record struct ExecuteQueryResult( - [property: DataMember(Order = 0)] ImmutableArray compilationErrors, - [property: DataMember(Order = 1)] string? ErrorMessage, - [property: DataMember(Order = 2)] string[]? ErrorMessageArgs = null, - [property: DataMember(Order = 3)] TimeSpan EmitTime = default, - [property: DataMember(Order = 4)] TimeSpan ExecutionTime = default); + [property: DataMember(Order = 0)] string? ErrorMessage, + [property: DataMember(Order = 1)] string[]? ErrorMessageArgs = null, + [property: DataMember(Order = 2)] TimeSpan ExecutionTime = default); + +/// +/// The result of Semantic Search query compilation. +/// +/// Id of the compiled query if the compilation was successful. +/// Compilation errors. +/// Time it took to emit the query compilation. +[DataContract] +internal readonly record struct CompileQueryResult( + [property: DataMember(Order = 0)] CompiledQueryId QueryId, + [property: DataMember(Order = 1)] ImmutableArray CompilationErrors, + [property: DataMember(Order = 2)] TimeSpan EmitTime = default); + +[DataContract] +internal readonly record struct CompiledQueryId +{ + private static int s_id; + + [DataMember(Order = 0)] +#pragma warning disable IDE0052 // Remove unread private members (https://github.com/dotnet/roslyn/issues/77907) + private readonly int _id; +#pragma warning restore IDE0052 + + [DataMember(Order = 1)] +#pragma warning disable IDE0052 // Remove unread private members (https://github.com/dotnet/roslyn/issues/77907) + public readonly string Language; +#pragma warning restore IDE0052 + + private CompiledQueryId(int id, string language) + { + _id = id; + Language = language; + } + + public static CompiledQueryId Create(string language) + => new(Interlocked.Increment(ref s_id), language); +} diff --git a/src/Features/Core/Portable/SemanticSearch/IRemoteSemanticSearchService.cs b/src/Features/Core/Portable/SemanticSearch/IRemoteSemanticSearchService.cs index 513f95af37dd1..107beb02f6098 100644 --- a/src/Features/Core/Portable/SemanticSearch/IRemoteSemanticSearchService.cs +++ b/src/Features/Core/Portable/SemanticSearch/IRemoteSemanticSearchService.cs @@ -10,6 +10,7 @@ using Microsoft.CodeAnalysis.Classification; using Microsoft.CodeAnalysis.ErrorReporting; using Microsoft.CodeAnalysis.FindUsages; +using Microsoft.CodeAnalysis.Host; using Microsoft.CodeAnalysis.Host.Mef; using Microsoft.CodeAnalysis.Remote; @@ -26,7 +27,9 @@ internal interface ICallback ValueTask ItemsCompletedAsync(RemoteServiceCallbackId callbackId, int itemCount, CancellationToken cancellationToken); } - ValueTask ExecuteQueryAsync(Checksum solutionChecksum, RemoteServiceCallbackId callbackId, string language, string query, string referenceAssembliesDir, CancellationToken cancellationToken); + ValueTask CompileQueryAsync(string query, string language, string referenceAssembliesDir, CancellationToken cancellationToken); + ValueTask ExecuteQueryAsync(Checksum solutionChecksum, RemoteServiceCallbackId callbackId, CompiledQueryId queryId, CancellationToken cancellationToken); + ValueTask DiscardQueryAsync(CompiledQueryId queryId, CancellationToken cancellationToken); } internal static class RemoteSemanticSearchServiceProxy @@ -112,19 +115,41 @@ public async ValueTask GetClassificationOptionsAsync(stri } } - public static async ValueTask ExecuteQueryAsync(Solution solution, string language, string query, string referenceAssembliesDir, ISemanticSearchResultsObserver results, OptionsProvider classificationOptions, CancellationToken cancellationToken) + public static async ValueTask CompileQueryAsync(SolutionServices services, string query, string language, string referenceAssembliesDir, CancellationToken cancellationToken) { - var client = await RemoteHostClient.TryGetClientAsync(solution.Services, cancellationToken).ConfigureAwait(false); + var client = await RemoteHostClient.TryGetClientAsync(services, cancellationToken).ConfigureAwait(false); if (client == null) { - return new ExecuteQueryResult(compilationErrors: [], FeaturesResources.Semantic_search_only_supported_on_net_core); + return null; } + var result = await client.TryInvokeAsync( + (service, cancellationToken) => service.CompileQueryAsync(query, language, referenceAssembliesDir, cancellationToken), + cancellationToken).ConfigureAwait(false); + + return result.Value; + } + + public static async ValueTask DiscardQueryAsync(SolutionServices services, CompiledQueryId queryId, CancellationToken cancellationToken) + { + var client = await RemoteHostClient.TryGetClientAsync(services, cancellationToken).ConfigureAwait(false); + Contract.ThrowIfNull(client); + + await client.TryInvokeAsync( + (service, cancellationToken) => service.DiscardQueryAsync(queryId, cancellationToken), + cancellationToken).ConfigureAwait(false); + } + + public static async ValueTask ExecuteQueryAsync(Solution solution, CompiledQueryId queryId, ISemanticSearchResultsObserver results, OptionsProvider classificationOptions, CancellationToken cancellationToken) + { + var client = await RemoteHostClient.TryGetClientAsync(solution.Services, cancellationToken).ConfigureAwait(false); + Contract.ThrowIfNull(client); + var serverCallback = new ServerCallback(solution, results, classificationOptions); var result = await client.TryInvokeAsync( solution, - (service, solutionInfo, callbackId, cancellationToken) => service.ExecuteQueryAsync(solutionInfo, callbackId, language, query, referenceAssembliesDir, cancellationToken), + (service, solutionInfo, callbackId, cancellationToken) => service.ExecuteQueryAsync(solutionInfo, callbackId, queryId, cancellationToken), callbackTarget: serverCallback, cancellationToken).ConfigureAwait(false); diff --git a/src/Features/Core/Portable/SemanticSearch/ISemanticSearchService.cs b/src/Features/Core/Portable/SemanticSearch/ISemanticSearchService.cs index d1b4257da82fa..3b80cf8d066b5 100644 --- a/src/Features/Core/Portable/SemanticSearch/ISemanticSearchService.cs +++ b/src/Features/Core/Portable/SemanticSearch/ISemanticSearchService.cs @@ -13,21 +13,36 @@ namespace Microsoft.CodeAnalysis.SemanticSearch; internal interface ISemanticSearchService : ILanguageService { /// - /// Executes given query against . + /// Compiles a query. The query has to be executed or discarded. /// - /// The solution snapshot. /// Query (top-level code). /// Directory that contains refernece assemblies to be used for compilation of the query. + CompileQueryResult CompileQuery( + SolutionServices services, + string query, + string referenceAssembliesDir, + TraceSource traceSource, + CancellationToken cancellationToken); + + /// + /// Executes given query against and discards it. + /// + /// The solution snapshot. + /// Id of a compiled query. /// Observer of the found symbols. /// Options to use to classify the textual representation of the found symbols. /// Cancellation token. - /// Error message on failure. Task ExecuteQueryAsync( Solution solution, - string query, - string referenceAssembliesDir, + CompiledQueryId queryId, ISemanticSearchResultsObserver observer, OptionsProvider classificationOptions, TraceSource traceSource, CancellationToken cancellationToken); + + /// + /// Discards resources associated with compiled query. + /// Only call if the query is not executed. + /// + void DiscardQuery(CompiledQueryId queryId); } diff --git a/src/Features/ExternalAccess/Copilot/Internal/SemanticSearch/CopilotSemanticSearchQueryExecutor.cs b/src/Features/ExternalAccess/Copilot/Internal/SemanticSearch/CopilotSemanticSearchQueryExecutor.cs index e4ce02fedc3e9..fbb6d02277719 100644 --- a/src/Features/ExternalAccess/Copilot/Internal/SemanticSearch/CopilotSemanticSearchQueryExecutor.cs +++ b/src/Features/ExternalAccess/Copilot/Internal/SemanticSearch/CopilotSemanticSearchQueryExecutor.cs @@ -78,11 +78,38 @@ public async Task ExecuteAsync(string query, try { - var result = await RemoteSemanticSearchServiceProxy.ExecuteQueryAsync( - _workspace.CurrentSolution, - LanguageNames.CSharp, + var compileResult = await RemoteSemanticSearchServiceProxy.CompileQueryAsync( + _workspace.CurrentSolution.Services, query, + language: LanguageNames.CSharp, SemanticSearchUtilities.ReferenceAssembliesDirectory, + cancellationSource.Token).ConfigureAwait(false); + + if (compileResult == null) + { + return new CopilotSemanticSearchQueryResults() + { + Symbols = observer.Results, + CompilationErrors = [], + Error = FeaturesResources.Semantic_search_only_supported_on_net_core, + LimitReached = false, + }; + } + + if (!compileResult.Value.CompilationErrors.IsEmpty) + { + return new CopilotSemanticSearchQueryResults() + { + Symbols = observer.Results, + CompilationErrors = compileResult.Value.CompilationErrors.SelectAsArray(e => (e.Id, e.Message)), + Error = null, + LimitReached = false, + }; + } + + var executeResult = await RemoteSemanticSearchServiceProxy.ExecuteQueryAsync( + _workspace.CurrentSolution, + compileResult.Value.QueryId, observer, DefaultClassificationOptionsProvider.Instance, cancellationSource.Token).ConfigureAwait(false); @@ -90,8 +117,8 @@ public async Task ExecuteAsync(string query, return new CopilotSemanticSearchQueryResults() { Symbols = observer.Results, - CompilationErrors = result.compilationErrors.SelectAsArray(e => (e.Id, e.Message)), - Error = (result.ErrorMessage != null) ? string.Format(result.ErrorMessage, result.ErrorMessageArgs ?? []) : null, + CompilationErrors = [], + Error = (executeResult.ErrorMessage != null) ? string.Format(executeResult.ErrorMessage, executeResult.ErrorMessageArgs ?? []) : null, LimitReached = false, }; } diff --git a/src/VisualStudio/CSharp/Impl/SemanticSearch/SemanticSearchQueryExecutor.cs b/src/VisualStudio/CSharp/Impl/SemanticSearch/SemanticSearchQueryExecutor.cs index 09550126798c7..49cd41ef12230 100644 --- a/src/VisualStudio/CSharp/Impl/SemanticSearch/SemanticSearchQueryExecutor.cs +++ b/src/VisualStudio/CSharp/Impl/SemanticSearch/SemanticSearchQueryExecutor.cs @@ -64,29 +64,49 @@ public async Task ExecuteAsync(string? query, Document? queryDocument, Solution ExecuteQueryResult result = default; var canceled = false; + var emitTime = TimeSpan.Zero; + try { - result = await RemoteSemanticSearchServiceProxy.ExecuteQueryAsync( - solution, - LanguageNames.CSharp, + var compileResult = await RemoteSemanticSearchServiceProxy.CompileQueryAsync( + solution.Services, query, + language: LanguageNames.CSharp, SemanticSearchUtilities.ReferenceAssembliesDirectory, - resultsObserver, - _classificationOptionsProvider, cancellationToken).ConfigureAwait(false); - foreach (var error in result.compilationErrors) + if (compileResult == null) + { + result = new ExecuteQueryResult(FeaturesResources.Semantic_search_only_supported_on_net_core); + return; + } + + emitTime = compileResult.Value.EmitTime; + + if (!compileResult.Value.CompilationErrors.IsEmpty) { - await presenterContext.OnDefinitionFoundAsync(new SearchCompilationFailureDefinitionItem(error, queryDocument), cancellationToken).ConfigureAwait(false); + foreach (var error in compileResult.Value.CompilationErrors) + { + await presenterContext.OnDefinitionFoundAsync(new SearchCompilationFailureDefinitionItem(error, queryDocument), cancellationToken).ConfigureAwait(false); + } + + return; } + + result = await RemoteSemanticSearchServiceProxy.ExecuteQueryAsync( + solution, + compileResult.Value.QueryId, + resultsObserver, + _classificationOptionsProvider, + cancellationToken).ConfigureAwait(false); } catch (Exception e) when (FatalError.ReportAndPropagateUnlessCanceled(e, cancellationToken, ErrorSeverity.Critical)) { - result = new ExecuteQueryResult(compilationErrors: [], e.Message); + result = new ExecuteQueryResult(e.Message); } catch (OperationCanceledException) { - result = new ExecuteQueryResult(compilationErrors: [], ServicesVSResources.Search_cancelled); + result = new ExecuteQueryResult(ServicesVSResources.Search_cancelled); canceled = true; } finally @@ -110,11 +130,11 @@ await presenterContext.ReportMessageAsync( // Notify the presenter even if the search has been cancelled. await presenterContext.OnCompletedAsync(CancellationToken.None).ConfigureAwait(false); - ReportTelemetry(query, result, canceled); + ReportTelemetry(query, result, emitTime, canceled); } } - private static void ReportTelemetry(string queryString, ExecuteQueryResult result, bool canceled) + private static void ReportTelemetry(string queryString, ExecuteQueryResult result, TimeSpan emitTime, bool canceled) { Logger.Log(FunctionId.SemanticSearch_QueryExecution, KeyValueLogMessage.Create(map => { @@ -135,7 +155,7 @@ private static void ReportTelemetry(string queryString, ExecuteQueryResult resul } map["ExecutionTimeMilliseconds"] = (long)result.ExecutionTime.TotalMilliseconds; - map["EmitTime"] = (long)result.EmitTime.TotalMilliseconds; + map["EmitTime"] = (long)emitTime.TotalMilliseconds; })); } } diff --git a/src/Workspaces/Remote/ServiceHub/Services/SemanticSearch/RemoteSemanticSearchService.cs b/src/Workspaces/Remote/ServiceHub/Services/SemanticSearch/RemoteSemanticSearchService.cs index 7714a0dabd2d5..465b6965a227f 100644 --- a/src/Workspaces/Remote/ServiceHub/Services/SemanticSearch/RemoteSemanticSearchService.cs +++ b/src/Workspaces/Remote/ServiceHub/Services/SemanticSearch/RemoteSemanticSearchService.cs @@ -8,6 +8,7 @@ using Microsoft.CodeAnalysis.Classification; using Microsoft.CodeAnalysis.FindUsages; using Microsoft.CodeAnalysis.SemanticSearch; +using Roslyn.Utilities; namespace Microsoft.CodeAnalysis.Remote; @@ -45,28 +46,59 @@ public ValueTask OnUserCodeExceptionAsync(UserCodeExceptionInfo exception, Cance => callback.InvokeAsync((callback, cancellationToken) => callback.OnUserCodeExceptionAsync(callbackId, exception, cancellationToken), cancellationToken); } + /// + /// Remote API. + /// + public ValueTask CompileQueryAsync( + string query, + string language, + string referenceAssembliesDir, + CancellationToken cancellationToken) + { + return RunServiceAsync(cancellationToken => + { + var services = GetWorkspaceServices(); + var service = services.GetLanguageServices(language).GetRequiredService(); + var result = service.CompileQuery(services, query, referenceAssembliesDir, TraceLogger, cancellationToken); + + return ValueTaskFactory.FromResult(result); + }, cancellationToken); + } + + /// + /// Remote API. + /// + public ValueTask DiscardQueryAsync(CompiledQueryId queryId, CancellationToken cancellationToken) + { + return RunServiceAsync(cancellationToken => + { + var service = GetWorkspaceServices().GetLanguageServices(queryId.Language).GetRequiredService(); + service.DiscardQuery(queryId); + + return default; + }, cancellationToken); + } + /// /// Remote API. /// public ValueTask ExecuteQueryAsync( Checksum solutionChecksum, RemoteServiceCallbackId callbackId, - string language, - string query, - string referenceAssembliesDir, + CompiledQueryId queryId, CancellationToken cancellationToken) { return RunServiceAsync(solutionChecksum, async solution => { - var service = solution.Services.GetLanguageServices(language).GetService(); + var service = solution.Services.GetLanguageServices(queryId.Language).GetService(); if (service == null) { - return new ExecuteQueryResult(compilationErrors: [], FeaturesResources.Semantic_search_only_supported_on_net_core); + return new ExecuteQueryResult(FeaturesResources.Semantic_search_only_supported_on_net_core); } var clientCallbacks = new ClientCallbacks(callback, callbackId); - return await service.ExecuteQueryAsync(solution, query, referenceAssembliesDir, clientCallbacks, clientCallbacks, TraceLogger, cancellationToken).ConfigureAwait(false); + return await service.ExecuteQueryAsync(solution, queryId, clientCallbacks, clientCallbacks, TraceLogger, cancellationToken).ConfigureAwait(false); }, cancellationToken); } }