diff --git a/src/Analysis/Ast/Impl/Analyzer/Symbols/SymbolCollector.cs b/src/Analysis/Ast/Impl/Analyzer/Symbols/SymbolCollector.cs index a58f9f9b8..27f104229 100644 --- a/src/Analysis/Ast/Impl/Analyzer/Symbols/SymbolCollector.cs +++ b/src/Analysis/Ast/Impl/Analyzer/Symbols/SymbolCollector.cs @@ -57,6 +57,12 @@ public override bool Walk(ClassDefinition cd) { if (!string.IsNullOrEmpty(cd.NameExpression?.Name)) { var classInfo = CreateClass(cd); + if (classInfo == null) { + // we can't create class info for this node. + // don't walk down + return false; + } + // The variable is transient (non-user declared) hence it does not have location. // Class type is tracking locations for references and renaming. _eval.DeclareVariable(cd.Name, classInfo, VariableSource.Declaration); @@ -68,7 +74,9 @@ public override bool Walk(ClassDefinition cd) { } public override void PostWalk(ClassDefinition cd) { - if (!IsDeprecated(cd) && !string.IsNullOrEmpty(cd.NameExpression?.Name)) { + if (!IsDeprecated(cd) && + !string.IsNullOrEmpty(cd.NameExpression?.Name) && + _typeMap.ContainsKey(cd)) { _scopes.Pop().Dispose(); } base.PostWalk(cd); @@ -95,9 +103,14 @@ public override void PostWalk(FunctionDefinition fd) { private PythonClassType CreateClass(ClassDefinition cd) { PythonType declaringType = null; - if(!(cd.Parent is PythonAst)) { - Debug.Assert(_typeMap.ContainsKey(cd.Parent)); - _typeMap.TryGetValue(cd.Parent, out declaringType); + if (!(cd.Parent is PythonAst)) { + if (!_typeMap.TryGetValue(cd.Parent, out declaringType)) { + // we can get into this situation if parent is defined twice and we preserve + // only one of them. + // for example, code has function definition with exact same signature + // and class is defined under one of that function + return null; + } } var cls = new PythonClassType(cd, declaringType, _eval.GetLocationOfName(cd), _eval.SuppressBuiltinLookup ? BuiltinTypeId.Unknown : BuiltinTypeId.Type); @@ -120,6 +133,10 @@ private void AddFunction(FunctionDefinition fd, PythonType declaringType) { f = new PythonFunctionType(fd, declaringType, _eval.GetLocationOfName(fd)); // The variable is transient (non-user declared) hence it does not have location. // Function type is tracking locations for references and renaming. + + // if there are multiple functions with same name exist, only the very first one will be + // maintained in the scope. we should improve this if possible. + // https://github.com/microsoft/python-language-server/issues/1693 _eval.DeclareVariable(fd.Name, f, VariableSource.Declaration); _typeMap[fd] = f; declaringType?.AddMember(f.Name, f, overwrite: true); diff --git a/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs b/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs index 5ef5f018a..747adf52b 100644 --- a/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs +++ b/src/Analysis/Ast/Impl/Dependencies/DependencyResolver.cs @@ -23,6 +23,9 @@ namespace Microsoft.Python.Analysis.Dependencies { internal sealed class DependencyResolver : IDependencyResolver { + // optimization to only analyze one that is reachable from root + private readonly bool _checkVertexReachability = true; + private readonly Dictionary _keys = new Dictionary(); private readonly List> _vertices = new List>(); private readonly object _syncObj = new object(); @@ -287,7 +290,7 @@ private bool TryCreateWalkingGraph(in ImmutableArray>.Create(nodesByVertexIndex.Values); return true; + + bool ReachableFromRoot(ImmutableArray reachable, int index) { + const int inaccessibleFromRoot = -1; + + // one of usage case for this optimization is not analyzing module that is not reachable + // from user code + return _checkVertexReachability && reachable[index] != inaccessibleFromRoot; + } } private static ImmutableArray CalculateDepths(in ImmutableArray> vertices) { diff --git a/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs b/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs index a74e221c3..91a034f98 100644 --- a/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs +++ b/src/Analysis/Ast/Impl/Modules/Definitions/IModuleManagement.cs @@ -14,6 +14,8 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; +using System.Threading; using Microsoft.Python.Analysis.Caching; using Microsoft.Python.Analysis.Core.Interpreter; using Microsoft.Python.Analysis.Types; @@ -86,5 +88,7 @@ public interface IModuleManagement : IModuleResolution { ImmutableArray LibraryPaths { get; } bool SetUserConfiguredPaths(ImmutableArray paths); + + IEnumerable GetImportedModules(CancellationToken cancellationToken); } } diff --git a/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs b/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs index 96dc997d9..4329daec3 100644 --- a/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs +++ b/src/Analysis/Ast/Impl/Modules/Resolution/MainModuleResolution.cs @@ -55,6 +55,20 @@ public MainModuleResolution(string root, IServiceContainer services, ImmutableAr public IBuiltinsPythonModule BuiltinsModule { get; private set; } + public IEnumerable GetImportedModules(CancellationToken cancellationToken) { + foreach (var module in _specialized.Values) { + cancellationToken.ThrowIfCancellationRequested(); + yield return module; + } + + foreach (var moduleRef in Modules.Values) { + cancellationToken.ThrowIfCancellationRequested(); + if (moduleRef.Value != null) { + yield return moduleRef.Value; + } + } + } + protected override IPythonModule CreateModule(string name) { var moduleImport = CurrentPathResolver.GetModuleImportFromModuleName(name); if (moduleImport == null) { diff --git a/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs b/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs index 184b2027c..0a3782b24 100644 --- a/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs +++ b/src/Analysis/Ast/Impl/Modules/Resolution/ModuleResolutionBase.cs @@ -81,11 +81,6 @@ public IPythonModule GetOrLoadModule(string name) { return module; } - module = Interpreter.ModuleResolution.GetSpecializedModule(name); - if (module != null) { - return module; - } - // Now try regular case. if (Modules.TryGetValue(name, out var moduleRef)) { return moduleRef.GetOrCreate(name, this); diff --git a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs index a9977c2b5..f97e90212 100644 --- a/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs +++ b/src/Analysis/Core/Impl/DependencyResolution/PathResolverSnapshot.cs @@ -82,15 +82,27 @@ private PathResolverSnapshot(PythonLanguageVersion pythonLanguageVersion, string } public ImmutableArray GetAllImportableModuleNames(bool includeImplicitPackages = true) { + return GetAllImportableModuleInfo(n => !string.IsNullOrEmpty(n.FullModuleName), n => n.FullModuleName, includeImplicitPackages); + } + + public ImmutableArray GetAllImportableModulesByName(string name, bool includeImplicitPackages = true) { + return GetAllImportableModuleInfo(n => string.Equals(n.Name, name), n => n.FullModuleName, includeImplicitPackages); + } + + public ImmutableArray GetAllImportableModuleFilePaths(bool includeImplicitPackages = true) { + return GetAllImportableModuleInfo(n => !string.IsNullOrEmpty(n.ModulePath), n => n.ModulePath, includeImplicitPackages); + } + + private ImmutableArray GetAllImportableModuleInfo(Func predicate, Func valueGetter, bool includeImplicitPackages = true) { var roots = _roots.Prepend(_nonRooted); var items = new Queue(roots); - var names = ImmutableArray.Empty; + var stringValues = ImmutableArray.Empty; while (items.Count > 0) { var item = items.Dequeue(); if (item != null) { - if (!string.IsNullOrEmpty(item.FullModuleName) && (item.IsModule || includeImplicitPackages)) { - names = names.Add(item.FullModuleName); + if (predicate(item) && (item.IsModule || includeImplicitPackages)) { + stringValues = stringValues.Add(valueGetter(item)); } foreach (var child in item.Children.ExcludeDefault()) { @@ -99,13 +111,17 @@ public ImmutableArray GetAllImportableModuleNames(bool includeImplicitPa } } - return names.AddRange( + return stringValues.AddRange( _builtins.Children - .Where(b => !string.IsNullOrEmpty(b.FullModuleName)) - .Select(b => b.FullModuleName) + .Where(b => predicate(b)) + .Select(b => valueGetter(b)) ); } + public string GetModuleNameByPath(string modulePath) { + return TryFindModule(modulePath, out var edge, out _) ? edge.End.FullModuleName : null; + } + public ModuleImport GetModuleImportFromModuleName(in string fullModuleName) { for (var rootIndex = 0; rootIndex < _roots.Count; rootIndex++) { if (TryFindModuleByName(rootIndex, fullModuleName, out var lastEdge) && TryCreateModuleImport(lastEdge, out var moduleImports)) { diff --git a/src/Core/Impl/Extensions/ArrayExtensions.cs b/src/Core/Impl/Extensions/ArrayExtensions.cs index 20c6478b2..ce9185487 100644 --- a/src/Core/Impl/Extensions/ArrayExtensions.cs +++ b/src/Core/Impl/Extensions/ArrayExtensions.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System; +using System.Collections.Generic; namespace Microsoft.Python.Core { public static class ArrayExtensions { @@ -36,5 +37,26 @@ public static int IndexOf(this T[] array, TValue value, Func(this TCollection list, TItem item) + where TCollection : ICollection + where TItem : class { + if (item == null) { + return list; + } + + list.Add(item); + return list; + } + + public static TCollection AddIfNotNull(this TCollection list, params TItem[] items) + where TCollection : ICollection + where TItem : class { + foreach (var item in items) { + list.AddIfNotNull(item); + } + + return list; + } } } diff --git a/src/Core/Impl/Text/Position.cs b/src/Core/Impl/Text/Position.cs index d7ad18649..5f0a4de11 100644 --- a/src/Core/Impl/Text/Position.cs +++ b/src/Core/Impl/Text/Position.cs @@ -18,7 +18,7 @@ namespace Microsoft.Python.Core.Text { [Serializable] - public struct Position { + public struct Position : IEquatable { /// /// Line position in a document (zero-based). /// @@ -39,7 +39,14 @@ public struct Position { public static bool operator >(Position p1, Position p2) => p1.line > p2.line || p1.line == p2.line && p1.character > p2.character; public static bool operator <(Position p1, Position p2) => p1.line < p2.line || p1.line == p2.line && p1.character < p2.character; + public static bool operator ==(Position p1, Position p2) => p1.Equals(p2); + public static bool operator !=(Position p1, Position p2) => !p1.Equals(p2); + public bool Equals(Position other) => line == other.line && character == other.character; + + public override bool Equals(object obj) => obj is Position other ? Equals(other) : false; + + public override int GetHashCode() => 0; public override string ToString() => $"({line}, {character})"; } } diff --git a/src/LanguageServer/Impl/CodeActions/MissingImportCodeActionProvider.cs b/src/LanguageServer/Impl/CodeActions/MissingImportCodeActionProvider.cs new file mode 100644 index 000000000..9ae6c9e64 --- /dev/null +++ b/src/LanguageServer/Impl/CodeActions/MissingImportCodeActionProvider.cs @@ -0,0 +1,727 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Analysis; +using Microsoft.Python.Analysis.Analyzer; +using Microsoft.Python.Analysis.Analyzer.Expressions; +using Microsoft.Python.Analysis.Core.DependencyResolution; +using Microsoft.Python.Analysis.Core.Interpreter; +using Microsoft.Python.Analysis.Diagnostics; +using Microsoft.Python.Analysis.Modules; +using Microsoft.Python.Analysis.Types; +using Microsoft.Python.Analysis.Values; +using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; +using Microsoft.Python.Core.Text; +using Microsoft.Python.LanguageServer.Diagnostics; +using Microsoft.Python.LanguageServer.Indexing; +using Microsoft.Python.LanguageServer.Protocol; +using Microsoft.Python.LanguageServer.Utilities; +using Microsoft.Python.Parsing.Ast; +using Range = Microsoft.Python.Core.Text.Range; + +namespace Microsoft.Python.LanguageServer.CodeActions { + internal sealed class MissingImportCodeActionProvider : ICodeActionProvider { + public static readonly ICodeActionProvider Instance = new MissingImportCodeActionProvider(); + + // right now, it is a static. in future, we might consider giving an option to users to customize this list + // also, right now, it is text based. so if module has same name, they will get same suggestion even if + // the module is not something user expected + private static readonly Dictionary WellKnownAbbreviationMap = new Dictionary() { + { "numpy", "np" }, + { "pandas", "pd" }, + { "tensorflow", "tf" }, + { "matplotlib.pyplot", "plt" }, + { "matplotlib", "mpl" }, + { "math", "m" }, + { "scipy.io", "spio" }, + { "scipy", "sp" }, + }; + + private MissingImportCodeActionProvider() { + } + + public ImmutableArray FixableDiagnostics => ImmutableArray.Create( + ErrorCodes.UndefinedVariable, ErrorCodes.VariableNotDefinedGlobally, ErrorCodes.VariableNotDefinedNonLocal); + + public async Task> GetCodeActionsAsync(IDocumentAnalysis analysis, DiagnosticsEntry diagnostic, CancellationToken cancellationToken) { + var finder = new ExpressionFinder(analysis.Ast, new FindExpressionOptions() { Names = true }); + var node = finder.GetExpression(diagnostic.SourceSpan); + if (!(node is NameExpression nex)) { + return Enumerable.Empty(); + } + + var identifier = nex.Name; + if (string.IsNullOrEmpty(identifier)) { + return Enumerable.Empty(); + } + + var codeActions = new List(); + var diagnostics = new[] { diagnostic.ToDiagnostic() }; + + // see whether it is one of abbreviation we specialize + foreach (var moduleFullName in WellKnownAbbreviationMap.Where(kv => kv.Value == identifier).Select(kv => kv.Key)) { + var moduleName = GetModuleName(moduleFullName); + + await GetCodeActionsAsync(analysis, diagnostics, new Input(node, moduleName, moduleFullName), codeActions, cancellationToken); + } + + // add then search given name as it is + await GetCodeActionsAsync(analysis, diagnostics, new Input(node, identifier), codeActions, cancellationToken); + + return codeActions; + + string GetModuleName(string moduleFullName) { + var index = moduleFullName.LastIndexOf("."); + return index < 0 ? moduleFullName : moduleFullName.Substring(index + 1); + } + } + + private async Task GetCodeActionsAsync(IDocumentAnalysis analysis, + Diagnostic[] diagnostics, + Input input, + List codeActions, + CancellationToken cancellationToken) { + var importFullNameMap = new Dictionary(); + await AddCandidatesFromIndexAsync(analysis, input.Identifier, importFullNameMap, cancellationToken); + + var interpreter = analysis.Document.Interpreter; + var pathResolver = interpreter.ModuleResolution.CurrentPathResolver; + + // find installed modules matching the given name. this will include submodules + var languageVersion = Parsing.PythonLanguageVersionExtensions.ToVersion(interpreter.LanguageVersion); + var includeImplicit = !ModulePath.PythonVersionRequiresInitPyFiles(languageVersion); + + foreach (var moduleFullName in pathResolver.GetAllImportableModulesByName(input.Identifier, includeImplicit)) { + cancellationToken.ThrowIfCancellationRequested(); + importFullNameMap[moduleFullName] = new ImportInfo(moduleImported: false, memberImported: false, isModule: true); + } + + // find members matching the given name from modules already loaded. + var moduleInfo = new ModuleInfo(analysis); + foreach (var module in interpreter.ModuleResolution.GetImportedModules(cancellationToken)) { + if (module.ModuleType == ModuleType.Unresolved) { + continue; + } + + // module name is full module name that you can use in import xxxx directly + CollectCandidates(moduleInfo.Reset(module), input.Identifier, importFullNameMap, cancellationToken); + Debug.Assert(moduleInfo.NameParts.Count == 1 && moduleInfo.NameParts[0] == module.Name); + } + + // check quick bail out case where we know what module we are looking for + if (input.ModuleFullNameOpt != null) { + if (importFullNameMap.ContainsKey(input.ModuleFullNameOpt)) { + // add code action if the module exist, otherwise, bail out empty + codeActions.AddIfNotNull(CreateCodeAction(analysis, input.Context, input.ModuleFullNameOpt, diagnostics, locallyInserted: false, cancellationToken)); + } + return; + } + + // regular case + FilterCandidatesBasedOnContext(analysis, input.Context, importFullNameMap, cancellationToken); + + // this will create actual code fix with certain orders + foreach (var fullName in OrderFullNames(importFullNameMap)) { + cancellationToken.ThrowIfCancellationRequested(); + codeActions.AddIfNotNull(CreateCodeAction(analysis, input.Context, fullName, diagnostics, locallyInserted: false, cancellationToken)); + } + } + + private void FilterCandidatesBasedOnContext(IDocumentAnalysis analysis, Node node, Dictionary importFullNameMap, CancellationToken cancellationToken) { + var ancestors = GetAncestorsOrThis(analysis.Ast.Body, node, cancellationToken); + var index = ancestors.LastIndexOf(node); + if (index <= 0) { + // nothing to filter on + return; + } + + var parent = ancestors[index - 1]; + if (!(parent is CallExpression)) { + // nothing to filter on + return; + } + + // do simple filtering + // remove all modules from candidates + foreach (var kv in importFullNameMap.ToList()) { + if (kv.Value.IsModule) { + importFullNameMap.Remove(kv.Key); + } + } + } + + private IEnumerable OrderFullNames(Dictionary importFullNameMap) { + // use some heuristic to improve code fix ordering + + // put simple name module at the top + foreach (var fullName in OrderImportNames(importFullNameMap.Where(FilterSimpleName).Select(kv => kv.Key))) { + importFullNameMap.Remove(fullName); + yield return fullName; + } + + // heuristic is we put entries with decl without any exports (imported member with __all__) at the top + // such as array. another example will be chararray. + // this will make numpy chararray at the top and numpy defchararray at the bottom. + // if we want, we can add more info to hide intermediate ones. + // for example, numpy.chararry is __all__.extended from numpy.core.chararray and etc. + // so we could leave only numpy.chararray and remove ones like numpy.core.chararray and etc. but for now, + // we show all those but in certain order so that numpy.chararray shows up top + // this heuristic still has issue with something like os.path.join since no one import macpath, macpath join shows up high + var sourceDeclarationFullNames = importFullNameMap.Where(kv => kv.Value.Symbol != null) + .GroupBy(kv => kv.Value.Symbol.Definition, LocationInfo.FullComparer) + .Where(FilterSourceDeclarations) + .Select(g => g.First().Key); + + foreach (var fullName in OrderImportNames(sourceDeclarationFullNames)) { + importFullNameMap.Remove(fullName); + yield return fullName; + } + + // put modules that are imported next + foreach (var fullName in OrderImportNames(importFullNameMap.Where(FilterImportedModules).Select(kv => kv.Key))) { + importFullNameMap.Remove(fullName); + yield return fullName; + } + + // put members that are imported next + foreach (var fullName in OrderImportNames(importFullNameMap.Where(FilterImportedMembers).Select(kv => kv.Key))) { + importFullNameMap.Remove(fullName); + yield return fullName; + } + + // put members whose module is imported next + foreach (var fullName in OrderImportNames(importFullNameMap.Where(FilterImportedModuleMembers).Select(kv => kv.Key))) { + importFullNameMap.Remove(fullName); + yield return fullName; + } + + // put things left here. + foreach (var fullName in OrderImportNames(importFullNameMap.Select(kv => kv.Key))) { + yield return fullName; + } + + List OrderImportNames(IEnumerable fullNames) { + return fullNames.OrderBy(n => n, ImportNameComparer.Instance).ToList(); + } + + bool FilterSimpleName(KeyValuePair kv) => kv.Key.IndexOf(".") < 0; + bool FilterImportedMembers(KeyValuePair kv) => !kv.Value.IsModule && kv.Value.MemberImported; + bool FilterImportedModuleMembers(KeyValuePair kv) => !kv.Value.IsModule && kv.Value.ModuleImported; + bool FilterImportedModules(KeyValuePair kv) => kv.Value.IsModule && kv.Value.MemberImported; + + bool FilterSourceDeclarations(IGrouping> group) { + var count = 0; + foreach (var entry in group) { + if (count++ > 0) { + return false; + } + + var value = entry.Value; + if (value.ModuleImported || value.ModuleImported) { + return false; + } + } + + return true; + } + } + + private static async Task AddCandidatesFromIndexAsync(IDocumentAnalysis analysis, + string name, + Dictionary importFullNameMap, + CancellationToken cancellationToken) { + var indexManager = analysis.ExpressionEvaluator.Services.GetService(); + if (indexManager == null) { + // indexing is not supported + return; + } + + var symbolsIncludingName = await indexManager.WorkspaceSymbolsAsync(name, maxLength: int.MaxValue, includeLibraries: true, cancellationToken); + + // we only consider exact matches rather than partial matches + var symbolsWithName = symbolsIncludingName.Where(Include); + + var analyzer = analysis.ExpressionEvaluator.Services.GetService(); + var pathResolver = analysis.Document.Interpreter.ModuleResolution.CurrentPathResolver; + + var modules = ImmutableArray.Empty; + foreach (var symbolAndModuleName in symbolsWithName.Select(s => (symbol: s, moduleName: pathResolver.GetModuleNameByPath(s.DocumentPath)))) { + cancellationToken.ThrowIfCancellationRequested(); + + var key = $"{symbolAndModuleName.moduleName}.{symbolAndModuleName.symbol.Name}"; + var symbol = symbolAndModuleName.symbol; + + importFullNameMap.TryGetValue(key, out var existing); + + // we don't actually know whether this is a module. all we know is it appeared at + // Import statement. but most likely module, so we mark it as module for now. + // later when we check loaded module, if this happen to be loaded, this will get + // updated with more accurate data. + // if there happen to be multiple symbols with same name, we refer to mark it as module + var isModule = symbol.Kind == Indexing.SymbolKind.Module || existing.IsModule; + + // any symbol marked "Module" by indexer is imported. + importFullNameMap[key] = new ImportInfo( + moduleImported: isModule, + memberImported: isModule, + isModule); + } + + bool Include(FlatSymbol symbol) { + // we only suggest symbols that exist in __all__ + // otherwise, we show gigantic list from index + return symbol._existInAllVariable && + symbol.ContainerName == null && + CheckKind(symbol.Kind) && + symbol.Name == name; + } + + bool CheckKind(Indexing.SymbolKind kind) { + switch (kind) { + case Indexing.SymbolKind.Module: + case Indexing.SymbolKind.Namespace: + case Indexing.SymbolKind.Package: + case Indexing.SymbolKind.Class: + case Indexing.SymbolKind.Enum: + case Indexing.SymbolKind.Interface: + case Indexing.SymbolKind.Function: + case Indexing.SymbolKind.Constant: + case Indexing.SymbolKind.Struct: + return true; + default: + return false; + } + } + } + + private CodeAction CreateCodeAction(IDocumentAnalysis analysis, + Node node, + string moduleFullName, + Diagnostic[] diagnostics, + bool locallyInserted, + CancellationToken cancellationToken) { + var insertionPoint = GetInsertionInfo(analysis, node, moduleFullName, locallyInserted, cancellationToken); + if (insertionPoint == null) { + return null; + } + + var insertionText = insertionPoint.Value.InsertionText; + var titleText = locallyInserted ? Resources.ImportLocally.FormatUI(insertionText) : insertionText; + + var sb = new StringBuilder(); + sb.AppendIf(insertionPoint.Value.Range.start == insertionPoint.Value.Range.end, insertionPoint.Value.Indentation); + sb.Append(insertionPoint.Value.AddBlankLine ? insertionText + Environment.NewLine : insertionText); + sb.AppendIf(insertionPoint.Value.Range.start == insertionPoint.Value.Range.end, Environment.NewLine); + + var textEdits = new List(); + textEdits.Add(new TextEdit() { range = insertionPoint.Value.Range, newText = sb.ToString() }); + + if (insertionPoint.Value.AbbreviationOpt != null) { + textEdits.Add(new TextEdit() { range = node.GetSpan(analysis.Ast), newText = insertionPoint.Value.AbbreviationOpt }); + } + + var changes = new Dictionary { { analysis.Document.Uri, textEdits.ToArray() } }; + return new CodeAction() { title = titleText, kind = CodeActionKind.QuickFix, diagnostics = diagnostics, edit = new WorkspaceEdit() { changes = changes } }; + } + + private InsertionInfo? GetInsertionInfo(IDocumentAnalysis analysis, + Node node, + string fullyQualifiedName, + bool locallyInserted, + CancellationToken cancellationToken) { + var (body, indentation) = GetStartingPoint(analysis, node, locallyInserted, cancellationToken); + if (body == null) { + // no insertion point + return null; + } + + var importNodes = body.GetChildNodes().Where(c => c is ImportStatement || c is FromImportStatement).ToList(); + var lastImportNode = importNodes.LastOrDefault(); + + var abbreviation = GetAbbreviationForWellKnownModules(analysis, fullyQualifiedName); + + // first check whether module name is dotted or not + var dotIndex = fullyQualifiedName.LastIndexOf('.'); + if (dotIndex < 0) { + // there can't be existing import since we have the error + return new InsertionInfo(addBlankLine: lastImportNode == null, + GetInsertionText($"import {fullyQualifiedName}", abbreviation), + GetRange(analysis.Ast, body, lastImportNode), + indentation, + abbreviation); + } + + // see whether there is existing from * import * statement. + var fromPart = fullyQualifiedName.Substring(startIndex: 0, dotIndex); + var nameToAdd = fullyQualifiedName.Substring(dotIndex + 1); + foreach (var current in importNodes.Reverse().OfType()) { + if (current.Root.MakeString() == fromPart) { + return new InsertionInfo(addBlankLine: false, + GetInsertionText(current, fromPart, nameToAdd, abbreviation), + current.GetSpan(analysis.Ast), + indentation, + abbreviation); + } + } + + // add new from * import * statement + return new InsertionInfo(addBlankLine: lastImportNode == null, + GetInsertionText($"from {fromPart} import {nameToAdd}", abbreviation), + GetRange(analysis.Ast, body, lastImportNode), + indentation, + abbreviation); + } + + private static string GetAbbreviationForWellKnownModules(IDocumentAnalysis analysis, string fullyQualifiedName) { + if (WellKnownAbbreviationMap.TryGetValue(fullyQualifiedName, out var abbreviation)) { + // for now, use module wide unique name for abbreviation. even though technically we could use + // context based unique name since variable declared in lower scope will hide it and there is no conflict + return UniqueNameGenerator.Generate(analysis, abbreviation); + } + + return null; + } + + private static string GetInsertionText(string insertionText, string abbreviation) => + abbreviation == null ? insertionText : $"{insertionText} as {abbreviation}"; + + private string GetInsertionText(FromImportStatement fromImportStatement, string rootModuleName, string moduleNameToAdd, string abbreviation) { + var imports = fromImportStatement.Names.Select(n => n.Name) + .Concat(new string[] { GetInsertionText(moduleNameToAdd, abbreviation) }) + .OrderBy(n => n).ToList(); + + return $"from {rootModuleName} import {string.Join(", ", imports)}"; + } + + private Range GetRange(PythonAst ast, Statement body, Node lastImportNode) { + var position = GetPosition(ast, body, lastImportNode); + return new Range() { start = position, end = position }; + } + + private Position GetPosition(PythonAst ast, Statement body, Node lastImportNode) { + if (lastImportNode != null) { + var endLocation = lastImportNode.GetEnd(ast); + return new Position { line = endLocation.Line, character = 0 }; + } + + // firstNode must exist in this context + var firstNode = body.GetChildNodes().First(); + return new Position() { line = firstNode.GetStart(ast).Line - 1, character = 0 }; + } + + private (Statement body, string indentation) GetStartingPoint(IDocumentAnalysis analysis, + Node node, + bool locallyInserted, + CancellationToken cancellationToken) { + if (!locallyInserted) { + return (analysis.Ast.Body, string.Empty); + } + + var candidate = GetAncestorsOrThis(analysis.Ast.Body, node, cancellationToken).Where(p => p is FunctionDefinition).LastOrDefault(); + + // for now, only stop at FunctionDefinition. + // we can expand it to more scope if we want but this seems what other tool also provide as well. + // this will return closest scope from given node + switch (candidate) { + case FunctionDefinition functionDefinition: + return (functionDefinition.Body, GetIndentation(analysis.Ast, functionDefinition.Body)); + default: + // no local scope + return default; + } + } + + private string GetIndentation(PythonAst ast, Statement body) { + // first token must exist in current context + var firstToken = body.GetChildNodes().First(); + + // not sure how to handle a case where user is using "tab" instead of "space" + // for indentation. where can one get tab over indentation option? + return new string(' ', firstToken.GetStart(ast).Column - 1); + } + + private List GetAncestorsOrThis(Node root, Node node, CancellationToken cancellationToken) { + var parentChain = new List(); + + // there seems no way to go up the parent chain. always has to go down from the top + while (root != null) { + cancellationToken.ThrowIfCancellationRequested(); + + var temp = root; + root = null; + + // this assumes node is not overlapped and children are ordered from left to right + // in textual position + foreach (var current in temp.GetChildNodes()) { + if (!current.IndexSpan.Contains(node.IndexSpan)) { + continue; + } + + parentChain.Add(current); + root = current; + break; + } + } + + return parentChain; + } + + private void CollectCandidates(ModuleInfo moduleInfo, + string name, + Dictionary importFullNameMap, + CancellationToken cancellationToken) { + if (!moduleInfo.CheckCircularImports()) { + // bail out on circular imports + return; + } + + // add non module (imported) member + AddNonImportedMemberWithName(moduleInfo, name, importFullNameMap); + + // add module (imported) members if it shows up in __all__ + // + // we are doing recursive dig down rather than just going through all modules loaded linearly + // since path to how to get to a module is important. + // for example, "join" is defined in "ntpath" or "macpath" and etc, but users are supposed to + // use it through "os.path" which will automatically point to right module ex, "ntpath" based on + // environment rather than "ntpath" directly. if we just go through module in flat list, then + // we can miss "os.path" since it won't show in the module list. + // for these modules that are supposed to be used with indirect path (imported name of the module), + // we need to dig down to collect those with right path. + foreach (var memberName in GetAllVariables(moduleInfo.Analysis)) { + cancellationToken.ThrowIfCancellationRequested(); + + var pythonModule = moduleInfo.Module.GetMember(memberName) as IPythonModule; + if (pythonModule == null) { + continue; + } + + var fullName = $"{moduleInfo.FullName}.{memberName}"; + if (string.Equals(memberName, name)) { + // nested module are all imported + AddNameParts(fullName, moduleImported: true, memberImported: true, pythonModule, importFullNameMap); + } + + // make sure we dig down modules only if we can use it from imports + // for example, user can do "from numpy import char" to import char [defchararray] module + // but user can not do "from numpy.char import x" since it is not one of known modules to us. + // in contrast, users can do "from os import path" to import path [ntpath] module + // but also can do "from os.path import x" since "os.path" is one of known moudles to us. + var result = AstUtilities.FindImports( + moduleInfo.CurrentFileAnalysis.Document.Interpreter.ModuleResolution.CurrentPathResolver, + moduleInfo.CurrentFileAnalysis.Document.FilePath, + GetRootNames(fullName), + dotCount: 0, + forceAbsolute: true); + + if (result is ImportNotFound) { + continue; + } + + moduleInfo.AddName(memberName); + CollectCandidates(moduleInfo.With(pythonModule), name, importFullNameMap, cancellationToken); + moduleInfo.PopName(); + } + + // pop this module out so we can get to this module from + // different path. + // ex) A -> B -> [C] and A -> D -> [C] + moduleInfo.ForgetModule(); + } + + private IEnumerable GetRootNames(string fullName) { + return fullName.Split('.'); + } + + private void AddNonImportedMemberWithName(ModuleInfo moduleInfo, string name, Dictionary importFullNameMap) { + // for now, skip any protected or private member + if (name.StartsWith("_")) { + return; + } + + var pythonType = moduleInfo.Module.GetMember(name); + if (pythonType == null || pythonType is IPythonModule || pythonType.IsUnknown()) { + return; + } + + // skip any imported member (non module member) unless it is explicitly on __all__ + if (moduleInfo.Analysis.GlobalScope.Imported.TryGetVariable(name, out var importedVariable) && + object.Equals(pythonType, importedVariable.Value) && + GetAllVariables(moduleInfo.Analysis).All(s => !string.Equals(s, name))) { + return; + } + + moduleInfo.AddName(name); + AddNameParts(moduleInfo.FullName, moduleInfo.ModuleImported, importedVariable != null, pythonType, importFullNameMap); + moduleInfo.PopName(); + } + + private static void AddNameParts( + string fullName, bool moduleImported, bool memberImported, IPythonType symbol, Dictionary moduleFullNameMap) { + // one of case this can happen is if module's fullname is "a.b.c" and module "a.b" also import module "a.b.c" as "c" making + // fullname same "a.b.c". in this case, we mark it as "imported" since we refer one explicily shown in "__all__" to show + // higher rank than others + if (moduleFullNameMap.TryGetValue(fullName, out var info)) { + moduleImported |= info.ModuleImported; + } + + moduleFullNameMap[fullName] = new ImportInfo(moduleImported, memberImported, symbol); + } + + private IEnumerable GetAllVariables(IDocumentAnalysis analysis) { + // this is different than StartImportMemberNames since that only returns something when + // all entries are known. for import, we are fine doing best effort + if (analysis.GlobalScope.Variables.TryGetVariable("__all__", out var variable) && + variable?.Value is IPythonCollection collection) { + return collection.Contents + .OfType() + .Select(c => c.GetString()) + .Where(s => !string.IsNullOrEmpty(s)); + } + + return Array.Empty(); + } + + private class ImportNameComparer : IComparer { + public static readonly ImportNameComparer Instance = new ImportNameComparer(); + + private ImportNameComparer() { } + + public int Compare(string x, string y) { + const string underscore = "_"; + + // move "_" to back of the list + if (x.StartsWith(underscore) && y.StartsWith(underscore)) { + return x.CompareTo(y); + } + if (x.StartsWith(underscore)) { + return 1; + } + if (y.StartsWith(underscore)) { + return -1; + } + + return x.CompareTo(y); + } + } + + private struct InsertionInfo { + public readonly bool AddBlankLine; + public readonly string InsertionText; + public readonly Range Range; + public readonly string Indentation; + public readonly string AbbreviationOpt; + + public InsertionInfo(bool addBlankLine, string insertionText, Range range, string indentation, string abbreviationOpt = null) { + AddBlankLine = addBlankLine; + InsertionText = insertionText; + Range = range; + Indentation = indentation; + AbbreviationOpt = abbreviationOpt; + } + } + + private struct Input { + public readonly Node Context; + public readonly string Identifier; + public readonly string ModuleFullNameOpt; + + public Input(Node context, string identifier, string moduleFullNameOpt = null) { + Context = context; + Identifier = identifier; + ModuleFullNameOpt = moduleFullNameOpt; + } + } + + private struct ModuleInfo { + public readonly IDocumentAnalysis CurrentFileAnalysis; + public readonly IPythonModule Module; + public readonly List NameParts; + public readonly bool ModuleImported; + + private readonly HashSet _visited; + + public IDocumentAnalysis Analysis => Module.Analysis; + public string FullName => string.Join('.', NameParts); + + public ModuleInfo(IDocumentAnalysis document) : + this(document, module: null, new List(), moduleImported: false) { + } + + private ModuleInfo(IDocumentAnalysis document, IPythonModule module, List nameParts, bool moduleImported) : + this() { + CurrentFileAnalysis = document; + Module = module; + NameParts = nameParts; + ModuleImported = moduleImported; + + _visited = new HashSet(); + } + + public bool CheckCircularImports() => Module != null && _visited.Add(Module); + public void ForgetModule() => _visited.Remove(Module); + + public void AddName(string memberName) => NameParts.Add(memberName); + public void PopName() => NameParts.RemoveAt(NameParts.Count - 1); + + public ModuleInfo With(IPythonModule module) { + return new ModuleInfo(CurrentFileAnalysis, module, NameParts, moduleImported: true); + } + + public ModuleInfo Reset(IPythonModule module) { + Debug.Assert(_visited.Count == 0); + + NameParts.Clear(); + NameParts.Add(module.Name); + + return new ModuleInfo(CurrentFileAnalysis, module, NameParts, moduleImported: false); + } + } + + [DebuggerDisplay("{Symbol?.Name} Module:{IsModule} ({ModuleImported} {MemberImported})")] + private struct ImportInfo { + // only one that shows up in "__all__" will be imported + // containing module is imported + public readonly bool ModuleImported; + // containing symbol is imported + public readonly bool MemberImported; + + public readonly bool IsModule; + public readonly IPythonType Symbol; + + public ImportInfo(bool moduleImported, bool memberImported, IPythonType symbol) : + this(moduleImported, memberImported, symbol.MemberType == PythonMemberType.Module) { + Symbol = symbol; + } + + public ImportInfo(bool moduleImported, bool memberImported, bool isModule) { + ModuleImported = moduleImported; + MemberImported = memberImported; + IsModule = isModule; + Symbol = null; + } + } + } +} + diff --git a/src/LanguageServer/Impl/Definitions/ICodeActionProvider.cs b/src/LanguageServer/Impl/Definitions/ICodeActionProvider.cs new file mode 100644 index 000000000..fc500bb54 --- /dev/null +++ b/src/LanguageServer/Impl/Definitions/ICodeActionProvider.cs @@ -0,0 +1,42 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Analysis; +using Microsoft.Python.Analysis.Diagnostics; +using Microsoft.Python.Core.Collections; +using Microsoft.Python.LanguageServer.Protocol; + +namespace Microsoft.Python.LanguageServer { + public interface ICodeActionProvider { + /// + /// Returns error code this code action can provide fix for. this error code must be same as ones that are reported to host as diagnostics + /// ex) error code from linter + /// + ImmutableArray FixableDiagnostics { get; } + + /// + /// Returns that can potentially fix given diagnostic + /// + /// of the file where reported + /// that code action is supposed to fix + /// + /// that can fix the given + Task> GetCodeActionsAsync(IDocumentAnalysis analysis, DiagnosticsEntry diagnostic, CancellationToken cancellation); + } +} diff --git a/src/LanguageServer/Impl/Diagnostics/DiagnosticExtensions.cs b/src/LanguageServer/Impl/Diagnostics/DiagnosticExtensions.cs new file mode 100644 index 000000000..6a13be8cf --- /dev/null +++ b/src/LanguageServer/Impl/Diagnostics/DiagnosticExtensions.cs @@ -0,0 +1,45 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using Microsoft.Python.Analysis.Diagnostics; +using Microsoft.Python.LanguageServer.Protocol; +using Microsoft.Python.Parsing; + +namespace Microsoft.Python.LanguageServer.Diagnostics { + internal static class DiagnosticExtensions { + public static Diagnostic ToDiagnostic(this DiagnosticsEntry diagnostic, string source = "Python") { + return new Diagnostic { + range = diagnostic.SourceSpan, + severity = diagnostic.Severity.ToDiagnosticSeverity(), + source = source, + code = diagnostic.ErrorCode, + message = diagnostic.Message, + }; + } + + public static DiagnosticSeverity ToDiagnosticSeverity(this Severity severity) { + switch (severity) { + case Severity.Warning: + return DiagnosticSeverity.Warning; + case Severity.Information: + return DiagnosticSeverity.Information; + case Severity.Hint: + return DiagnosticSeverity.Hint; + default: + return DiagnosticSeverity.Error; + } + } + } +} diff --git a/src/LanguageServer/Impl/Diagnostics/DiagnosticsService.cs b/src/LanguageServer/Impl/Diagnostics/DiagnosticsService.cs index 64924a111..3fd041b16 100644 --- a/src/LanguageServer/Impl/Diagnostics/DiagnosticsService.cs +++ b/src/LanguageServer/Impl/Diagnostics/DiagnosticsService.cs @@ -154,7 +154,7 @@ private void PublishDiagnostics() { var parameters = new PublishDiagnosticsParams { uri = uri, diagnostics = Rdt.GetDocument(uri)?.IsOpen == true - ? FilterBySeverityMap(documentDiagnostics).Select(ToDiagnostic).ToArray() + ? FilterBySeverityMap(documentDiagnostics).Select(d => d.ToDiagnostic()).ToArray() : Array.Empty() }; _clientApp.NotifyWithParameterObjectAsync("textDocument/publishDiagnostics", parameters).DoNotWait(); @@ -168,32 +168,6 @@ private void ClearAllDiagnostics() { } } - private static Diagnostic ToDiagnostic(DiagnosticsEntry e) { - DiagnosticSeverity s; - switch (e.Severity) { - case Severity.Warning: - s = DiagnosticSeverity.Warning; - break; - case Severity.Information: - s = DiagnosticSeverity.Information; - break; - case Severity.Hint: - s = DiagnosticSeverity.Hint; - break; - default: - s = DiagnosticSeverity.Error; - break; - } - - return new Diagnostic { - range = e.SourceSpan, - severity = s, - source = "Python", - code = e.ErrorCode, - message = e.Message, - }; - } - private IEnumerable FilterBySeverityMap(DocumentDiagnostics d) => d.Entries .SelectMany(kvp => kvp.Value) diff --git a/src/LanguageServer/Impl/Implementation/Server.Editor.cs b/src/LanguageServer/Impl/Implementation/Server.Editor.cs index 0898c023c..c593eb69c 100644 --- a/src/LanguageServer/Impl/Implementation/Server.Editor.cs +++ b/src/LanguageServer/Impl/Implementation/Server.Editor.cs @@ -98,7 +98,7 @@ public async Task GotoDeclaration(TextDocumentPositionParams @params, var analysis = await Document.GetAnalysisAsync(uri, Services, CompletionAnalysisTimeout, cancellationToken); var reference = new DeclarationSource(Services).FindDefinition(analysis, @params.position, out _); - return reference != null ? new Location { uri = reference.uri, range = reference.range} : null; + return reference != null ? new Location { uri = reference.uri, range = reference.range } : null; } public Task FindReferences(ReferencesParams @params, CancellationToken cancellationToken) { @@ -112,5 +112,18 @@ public Task Rename(RenameParams @params, CancellationToken cancel _log?.Log(TraceEventType.Verbose, $"Rename in {uri} at {@params.position}"); return new RenameSource(Services).RenameAsync(uri, @params.position, @params.newName, cancellationToken); } + + public async Task CodeAction(CodeActionParams @params, CancellationToken cancellationToken) { + var uri = @params.textDocument.uri; + _log?.Log(TraceEventType.Verbose, $"Code Action in {uri} at {@params.range}"); + + if (@params.context.diagnostics?.Length == 0) { + return Array.Empty(); + } + + var analysis = await Document.GetAnalysisAsync(uri, Services, CompletionAnalysisTimeout, cancellationToken); + var codeActions = await new CodeActionSource(Services).GetCodeActionsAsync(analysis, @params.context.diagnostics, cancellationToken); + return codeActions ?? Array.Empty(); + } } } diff --git a/src/LanguageServer/Impl/Implementation/Server.cs b/src/LanguageServer/Impl/Implementation/Server.cs index 4a553a2f8..6e155e29f 100644 --- a/src/LanguageServer/Impl/Implementation/Server.cs +++ b/src/LanguageServer/Impl/Implementation/Server.cs @@ -98,6 +98,7 @@ public Server(IServiceManager services) { firstTriggerCharacter = "\n", moreTriggerCharacter = new[] { ";", ":" } }, + codeActionProvider = new CodeActionOptions() { codeActionKinds = new string[] { CodeActionKind.QuickFix } }, } }; @@ -160,7 +161,7 @@ public async Task InitializedAsync(InitializedParams @params, CancellationToken initializationOptions?.includeFiles, initializationOptions?.excludeFiles, _services.GetService()); - _indexManager.IndexWorkspace().DoNotWait(); + _indexManager.IndexWorkspace(_interpreter.ModuleResolution.CurrentPathResolver).DoNotWait(); _services.AddService(_indexManager); _disposableBag.Add(_indexManager); @@ -190,12 +191,11 @@ public Task Shutdown() { public void DidChangeConfiguration(DidChangeConfigurationParams @params, CancellationToken cancellationToken) { _disposableBag.ThrowIfDisposed(); switch (@params.settings) { - case ServerSettings settings: { - Settings = settings; - _symbolHierarchyMaxSymbols = Settings.analysis.symbolsHierarchyMaxSymbols; - _completionSource.Options = Settings.completion; - break; - } + case ServerSettings settings: + Settings = settings; + _symbolHierarchyMaxSymbols = Settings.analysis.symbolsHierarchyMaxSymbols; + _completionSource.Options = Settings.completion; + break; default: _log?.Log(TraceEventType.Error, "change configuration notification sent unsupported settings"); break; diff --git a/src/LanguageServer/Impl/Indexing/AllVariableCollector.cs b/src/LanguageServer/Impl/Indexing/AllVariableCollector.cs new file mode 100644 index 000000000..0e1137816 --- /dev/null +++ b/src/LanguageServer/Impl/Indexing/AllVariableCollector.cs @@ -0,0 +1,107 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Microsoft.Python.Parsing.Ast; + +namespace Microsoft.Python.LanguageServer.Indexing { + /// + /// This is a poor man's __all__ values collector. it uses only syntactic information to gather values. + /// + /// unlike the real one that actually binds expressions and + /// uses semantic data to build up __all__ information, this one's purpose is gathering cheap and fast but might be incorrect data + /// until more expensive analysis is done. + /// + internal class AllVariableCollector : PythonWalker { + private const string AllVariableName = "__all__"; + private readonly CancellationToken _cancellationToken; + + /// + /// names assigned to __all__ + /// + public readonly HashSet Names; + + public AllVariableCollector(CancellationToken cancellationToken) { + _cancellationToken = cancellationToken; + Names = new HashSet(); + } + + public override bool Walk(AssignmentStatement node) { + _cancellationToken.ThrowIfCancellationRequested(); + + // make sure we are dealing with __all__ assignment + if (node.Left.OfType().Any(n => n.Name == AllVariableName)) { + AddNames(node.Right as ListExpression); + } + + return base.Walk(node); + } + + public override bool Walk(AugmentedAssignStatement node) { + _cancellationToken.ThrowIfCancellationRequested(); + + if (node.Operator == Parsing.PythonOperator.Add && + node.Left is NameExpression nex && + nex.Name == AllVariableName) { + AddNames(node.Right as ListExpression); + } + + return base.Walk(node); + } + + public override bool Walk(CallExpression node) { + _cancellationToken.ThrowIfCancellationRequested(); + + if (node.Args.Count > 0 && + node.Target is MemberExpression me && + me.Target is NameExpression nex && + nex.Name == AllVariableName) { + var arg = node.Args[0].Expression; + + switch (me.Name) { + case "append": + AddName(arg); + break; + case "extend": + AddNames(arg as ListExpression); + break; + } + } + + return base.Walk(node); + } + + private void AddName(Expression item) { + if (item is ConstantExpression con && + con.Value is string name && + !string.IsNullOrEmpty(name)) { + Names.Add(name); + } + } + + private void AddNames(ListExpression list) { + // only support the form of __all__ = [ ... ] + if (list == null) { + return; + } + + foreach (var item in list.Items) { + AddName(item); + } + } + } +} diff --git a/src/LanguageServer/Impl/Indexing/IIndexManager.cs b/src/LanguageServer/Impl/Indexing/IIndexManager.cs index 9edaa18e0..1c4e7d871 100644 --- a/src/LanguageServer/Impl/Indexing/IIndexManager.cs +++ b/src/LanguageServer/Impl/Indexing/IIndexManager.cs @@ -17,16 +17,18 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Analysis.Core.DependencyResolution; using Microsoft.Python.Analysis.Documents; namespace Microsoft.Python.LanguageServer.Indexing { internal interface IIndexManager : IDisposable { - Task IndexWorkspace(CancellationToken ct = default); + Task IndexWorkspace(PathResolverSnapshot snapshot = null, CancellationToken ct = default); void ProcessNewFile(string path, IDocument doc); void ProcessClosedFile(string path); void ReIndexFile(string path, IDocument doc); void AddPendingDoc(IDocument doc); Task> HierarchicalDocumentSymbolsAsync(string path, CancellationToken cancellationToken = default); Task> WorkspaceSymbolsAsync(string query, int maxLength, CancellationToken cancellationToken = default); + Task> WorkspaceSymbolsAsync(string query, int maxLength, bool includeLibraries, CancellationToken cancellationToken = default); } } diff --git a/src/LanguageServer/Impl/Indexing/IndexManager.cs b/src/LanguageServer/Impl/Indexing/IndexManager.cs index d6ab967f8..8d3595d5b 100644 --- a/src/LanguageServer/Impl/Indexing/IndexManager.cs +++ b/src/LanguageServer/Impl/Indexing/IndexManager.cs @@ -16,9 +16,11 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Python.Analysis.Core.DependencyResolution; using Microsoft.Python.Analysis.Core.Interpreter; using Microsoft.Python.Analysis.Documents; using Microsoft.Python.Core.Diagnostics; @@ -30,13 +32,15 @@ namespace Microsoft.Python.LanguageServer.Indexing { internal class IndexManager : IIndexManager { private const int DefaultReIndexDelay = 350; - private readonly ISymbolIndex _symbolIndex; + private readonly PythonLanguageVersion _version; + private readonly ISymbolIndex _userCodeSymbolIndex; + private readonly ISymbolIndex _libraryCodeSymbolIndex; private readonly IFileSystem _fileSystem; private readonly string _workspaceRootPath; private readonly string[] _includeFiles; private readonly string[] _excludeFiles; private readonly DisposableBag _disposables = new DisposableBag(nameof(IndexManager)); - private readonly ConcurrentDictionary _pendingDocs = new ConcurrentDictionary(new UriDocumentComparer()); + private readonly ConcurrentDictionary _pendingDocs = new ConcurrentDictionary(UriDocumentComparer.Instance); private readonly DisposeToken _disposeToken = DisposeToken.Create(); public IndexManager(IFileSystem fileSystem, PythonLanguageVersion version, string rootPath, string[] includeFiles, @@ -46,35 +50,47 @@ public IndexManager(IFileSystem fileSystem, PythonLanguageVersion version, strin Check.ArgumentNotNull(nameof(excludeFiles), excludeFiles); Check.ArgumentNotNull(nameof(idleTimeService), idleTimeService); + _version = version; _fileSystem = fileSystem; _workspaceRootPath = rootPath; _includeFiles = includeFiles; _excludeFiles = excludeFiles; - _symbolIndex = new SymbolIndex(_fileSystem, version); + _userCodeSymbolIndex = new SymbolIndex(_fileSystem, version); + _libraryCodeSymbolIndex = new SymbolIndex(_fileSystem, version, libraryMode: true); + idleTimeService.Idle += OnIdle; _disposables - .Add(_symbolIndex) + .Add(_userCodeSymbolIndex) + .Add(_libraryCodeSymbolIndex) .Add(() => _disposeToken.TryMarkDisposed()) .Add(() => idleTimeService.Idle -= OnIdle); } public int ReIndexingDelay { get; set; } = DefaultReIndexDelay; - public Task IndexWorkspace(CancellationToken ct = default) { + public Task IndexWorkspace(PathResolverSnapshot snapshot = null, CancellationToken ct = default) { var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, _disposeToken.CancellationToken); var linkedCt = linkedCts.Token; return Task.Run(() => { - foreach (var fileInfo in WorkspaceFiles()) { - linkedCt.ThrowIfCancellationRequested(); - if (ModulePath.IsPythonSourceFile(fileInfo.FullName)) { - _symbolIndex.Parse(fileInfo.FullName); - } - } + var userFiles = WorkspaceFiles(); + CreateIndices(userFiles, _userCodeSymbolIndex, linkedCt); + + // index library files if asked + CreateIndices(LibraryFiles(snapshot).Except(userFiles, FileSystemInfoComparer.Instance), _libraryCodeSymbolIndex, linkedCt); }, linkedCt).ContinueWith(_ => linkedCts.Dispose()); } + private void CreateIndices(IEnumerable files, ISymbolIndex symbolIndex, CancellationToken cancellationToken) { + foreach (var fileInfo in files) { + cancellationToken.ThrowIfCancellationRequested(); + if (ModulePath.IsPythonSourceFile(fileInfo.FullName)) { + symbolIndex.Parse(fileInfo.FullName); + } + } + } + private IEnumerable WorkspaceFiles() { if (string.IsNullOrEmpty(_workspaceRootPath)) { return Enumerable.Empty(); @@ -82,11 +98,20 @@ private IEnumerable WorkspaceFiles() { return _fileSystem.GetDirectoryInfo(_workspaceRootPath).EnumerateFileSystemInfos(_includeFiles, _excludeFiles); } + private IEnumerable LibraryFiles(PathResolverSnapshot snapshot) { + if (snapshot == null) { + return Enumerable.Empty(); + } + + var includeImplicit = !ModulePath.PythonVersionRequiresInitPyFiles(_version.ToVersion()); + return snapshot.GetAllImportableModuleFilePaths(includeImplicit).Select(p => new FileInfoProxy(new FileInfo(p))); + } + public void ProcessClosedFile(string path) { if (IsFileOnWorkspace(path)) { - _symbolIndex.Parse(path); + _userCodeSymbolIndex.Parse(path); } else { - _symbolIndex.Delete(path); + _userCodeSymbolIndex.Delete(path); } } @@ -99,28 +124,43 @@ private bool IsFileOnWorkspace(string path) { } public void ProcessNewFile(string path, IDocument doc) { - _symbolIndex.Add(path, doc); + _userCodeSymbolIndex.Add(path, doc); } public void ReIndexFile(string path, IDocument doc) { - _symbolIndex.ReIndex(path, doc); + _userCodeSymbolIndex.ReIndex(path, doc); } public void Dispose() { _disposables.TryDispose(); } - public Task> HierarchicalDocumentSymbolsAsync(string path, CancellationToken cancellationToken = default) { - return _symbolIndex.HierarchicalDocumentSymbolsAsync(path, cancellationToken); + public async Task> HierarchicalDocumentSymbolsAsync(string path, CancellationToken cancellationToken = default) { + var result = await _userCodeSymbolIndex.HierarchicalDocumentSymbolsAsync(path, cancellationToken); + if (result.Count > 0) { + return result; + } + + return await _libraryCodeSymbolIndex.HierarchicalDocumentSymbolsAsync(path, cancellationToken); } public Task> WorkspaceSymbolsAsync(string query, int maxLength, CancellationToken cancellationToken = default) { - return _symbolIndex.WorkspaceSymbolsAsync(query, maxLength, cancellationToken); + return WorkspaceSymbolsAsync(query, maxLength, includeLibraries: false, cancellationToken); + } + + public async Task> WorkspaceSymbolsAsync(string query, int maxLength, bool includeLibraries, CancellationToken cancellationToken = default) { + var userCodeResult = await _userCodeSymbolIndex.WorkspaceSymbolsAsync(query, maxLength, cancellationToken); + if (includeLibraries == false) { + return userCodeResult; + } + + var libraryCodeResult = await _libraryCodeSymbolIndex.WorkspaceSymbolsAsync(query, maxLength, cancellationToken); + return userCodeResult.Concat(libraryCodeResult).ToList(); } public void AddPendingDoc(IDocument doc) { _pendingDocs.TryAdd(doc, DateTime.Now); - _symbolIndex.MarkAsPending(doc.Uri.AbsolutePath); + _userCodeSymbolIndex.MarkAsPending(doc.Uri.AbsolutePath); } private void OnIdle(object sender, EventArgs _) { @@ -137,10 +177,21 @@ private void ReIndexPendingDocsAsync() { } private class UriDocumentComparer : IEqualityComparer { + public static readonly UriDocumentComparer Instance = new UriDocumentComparer(); + + private UriDocumentComparer() { } public bool Equals(IDocument x, IDocument y) => x.Uri.Equals(y.Uri); public int GetHashCode(IDocument obj) => obj.Uri.GetHashCode(); } + + private class FileSystemInfoComparer : IEqualityComparer { + public static readonly FileSystemInfoComparer Instance = new FileSystemInfoComparer(); + + private FileSystemInfoComparer() { } + public bool Equals(IFileSystemInfo x, IFileSystemInfo y) => x?.FullName == y?.FullName; + public int GetHashCode(IFileSystemInfo obj) => obj.FullName.GetHashCode(); + } } } diff --git a/src/LanguageServer/Impl/Indexing/MostRecentDocumentSymbols.cs b/src/LanguageServer/Impl/Indexing/MostRecentDocumentSymbols.cs index 8a57355a8..a8023c6b2 100644 --- a/src/LanguageServer/Impl/Indexing/MostRecentDocumentSymbols.cs +++ b/src/LanguageServer/Impl/Indexing/MostRecentDocumentSymbols.cs @@ -28,6 +28,7 @@ namespace Microsoft.Python.LanguageServer.Indexing { internal sealed class MostRecentDocumentSymbols : IMostRecentDocumentSymbols { private readonly IIndexParser _indexParser; private readonly string _path; + private readonly bool _library; // Only used to cancel all work when this object gets disposed. private readonly CancellationTokenSource _cts = new CancellationTokenSource(); @@ -37,9 +38,10 @@ internal sealed class MostRecentDocumentSymbols : IMostRecentDocumentSymbols { private TaskCompletionSource> _tcs = new TaskCompletionSource>(); private CancellationTokenSource _workCts; - public MostRecentDocumentSymbols(string path, IIndexParser indexParser) { + public MostRecentDocumentSymbols(string path, IIndexParser indexParser, bool library) { _path = path; _indexParser = indexParser; + _library = library; } public Task> GetSymbolsAsync(CancellationToken ct = default) { @@ -95,7 +97,7 @@ public void MarkAsPending() { public void Dispose() { lock (_lock) { _tcs.TrySetCanceled(); - + try { _workCts?.Dispose(); } catch (ObjectDisposedException) { @@ -134,7 +136,7 @@ private async Task> IndexAsync(IDocument doc, } cancellationToken.ThrowIfCancellationRequested(); - var walker = new SymbolIndexWalker(ast); + var walker = new SymbolIndexWalker(ast, _library, cancellationToken); ast.Walk(walker); return walker.Symbols; } @@ -143,7 +145,7 @@ private async Task> ParseAsync(CancellationTok try { var ast = await _indexParser.ParseAsync(_path, cancellationToken); cancellationToken.ThrowIfCancellationRequested(); - var walker = new SymbolIndexWalker(ast); + var walker = new SymbolIndexWalker(ast, _library, cancellationToken); ast.Walk(walker); return walker.Symbols; } catch (Exception e) when (e is IOException || e is UnauthorizedAccessException) { diff --git a/src/LanguageServer/Impl/Indexing/SymbolIndex.cs b/src/LanguageServer/Impl/Indexing/SymbolIndex.cs index 07628af36..7b97905b7 100644 --- a/src/LanguageServer/Impl/Indexing/SymbolIndex.cs +++ b/src/LanguageServer/Impl/Indexing/SymbolIndex.cs @@ -29,11 +29,13 @@ internal sealed class SymbolIndex : ISymbolIndex { private readonly DisposableBag _disposables = new DisposableBag(nameof(SymbolIndex)); private readonly ConcurrentDictionary _index; private readonly IIndexParser _indexParser; + private readonly bool _libraryMode; - public SymbolIndex(IFileSystem fileSystem, PythonLanguageVersion version) { - var comparer = PathEqualityComparer.Instance; - _index = new ConcurrentDictionary(comparer); + public SymbolIndex(IFileSystem fileSystem, PythonLanguageVersion version, bool libraryMode = false) { + _index = new ConcurrentDictionary(PathEqualityComparer.Instance); _indexParser = new IndexParser(fileSystem, version); + _libraryMode = libraryMode; + _disposables .Add(_indexParser) .Add(() => { @@ -97,7 +99,7 @@ private IEnumerable WorkspaceSymbolsQuery(string path, string query, return DecorateWithParentsName((sym.Children ?? Enumerable.Empty()).ToList(), sym.Name); }); return treeSymbols.Where(sym => sym.symbol.Name.ContainsOrdinal(query, ignoreCase: true)) - .Select(sym => new FlatSymbol(sym.symbol.Name, sym.symbol.Kind, path, sym.symbol.SelectionRange, sym.parentName)); + .Select(sym => new FlatSymbol(sym.symbol.Name, sym.symbol.Kind, path, sym.symbol.SelectionRange, sym.parentName, sym.symbol._existInAllVariable)); } private static IEnumerable<(HierarchicalSymbol symbol, string parentName)> DecorateWithParentsName( @@ -106,7 +108,7 @@ private IEnumerable WorkspaceSymbolsQuery(string path, string query, } private IMostRecentDocumentSymbols MakeMostRecentDocSymbols(string path) { - return new MostRecentDocumentSymbols(path, _indexParser); + return new MostRecentDocumentSymbols(path, _indexParser, _libraryMode); } public void Dispose() { diff --git a/src/LanguageServer/Impl/Indexing/SymbolIndexWalker.cs b/src/LanguageServer/Impl/Indexing/SymbolIndexWalker.cs index e4f68a8c0..05d00859c 100644 --- a/src/LanguageServer/Impl/Indexing/SymbolIndexWalker.cs +++ b/src/LanguageServer/Impl/Indexing/SymbolIndexWalker.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using System.Linq; using System.Text.RegularExpressions; +using System.Threading; using Microsoft.Python.Core; using Microsoft.Python.Parsing.Ast; @@ -25,17 +26,27 @@ internal class SymbolIndexWalker : PythonWalker { private static readonly Regex ConstantLike = new Regex(@"^[\p{Lu}\p{N}_]+$", RegexOptions.Compiled); private readonly PythonAst _ast; + private readonly bool _libraryMode; private readonly SymbolStack _stack = new SymbolStack(); + private readonly HashSet _namesInAllVariable; - public SymbolIndexWalker(PythonAst ast) { + public SymbolIndexWalker(PythonAst ast, bool libraryMode = false, CancellationToken cancellationToken = default) { _ast = ast; + _libraryMode = libraryMode; + + var collector = new AllVariableCollector(cancellationToken); + _ast.Walk(collector); + + _namesInAllVariable = collector.Names; } public IReadOnlyList Symbols => _stack.Root; public override bool Walk(ClassDefinition node) { _stack.Enter(SymbolKind.Class); - node.Body?.Walk(this); + + WalkIfNotLibraryMode(node.Body); + var children = _stack.Exit(); _stack.AddSymbol(new HierarchicalSymbol( @@ -44,7 +55,8 @@ public override bool Walk(ClassDefinition node) { node.GetSpan(_ast), node.NameExpression.GetSpan(_ast), children, - FunctionKind.Class + FunctionKind.Class, + ExistInAllVariable(node.Name) )); return false; @@ -55,55 +67,67 @@ public override bool Walk(FunctionDefinition node) { foreach (var p in node.Parameters) { AddVarSymbol(p.NameExpression); } - node.Body?.Walk(this); + + // don't bother to walk down locals for libraries + // we don't care those for libraries + WalkIfNotLibraryMode(node.Body); + var children = _stack.Exit(); + SymbolKind symbolKind; + string functionKind; + GetKinds(node, out symbolKind, out functionKind); + var span = node.GetSpan(_ast); var ds = new HierarchicalSymbol( node.Name, - SymbolKind.Function, + symbolKind, span, node.IsLambda ? span : node.NameExpression.GetSpan(_ast), children, - FunctionKind.Function + functionKind, + ExistInAllVariable(node.Name) ); + _stack.AddSymbol(ds); + return false; + } + + private void GetKinds(FunctionDefinition node, out SymbolKind symbolKind, out string functionKind) { + symbolKind = SymbolKind.Function; + functionKind = FunctionKind.Function; + if (_stack.Parent == SymbolKind.Class) { - switch (ds.Name) { + switch (node.Name) { case "__init__": - ds.Kind = SymbolKind.Constructor; + symbolKind = SymbolKind.Constructor; break; case var name when DoubleUnderscore.IsMatch(name): - ds.Kind = SymbolKind.Operator; + symbolKind = SymbolKind.Operator; break; default: - ds.Kind = SymbolKind.Method; + symbolKind = SymbolKind.Method; if (node.Decorators != null) { foreach (var dec in node.Decorators.Decorators) { var maybeKind = DecoratorExpressionToKind(dec); if (maybeKind.HasValue) { - ds.Kind = maybeKind.Value.kind; - ds._functionKind = maybeKind.Value.functionKind; + symbolKind = maybeKind.Value.kind; + functionKind = maybeKind.Value.functionKind; break; } } } - break; } } - - _stack.AddSymbol(ds); - - return false; } public override bool Walk(ImportStatement node) { foreach (var (nameNode, nameString) in node.Names.Zip(node.AsNames, (name, asName) => asName != null ? (asName, asName.Name) : ((Node)name, name.MakeString()))) { var span = nameNode.GetSpan(_ast); - _stack.AddSymbol(new HierarchicalSymbol(nameString, SymbolKind.Module, span)); + _stack.AddSymbol(new HierarchicalSymbol(nameString, SymbolKind.Module, span, existInAllVariable: ExistInAllVariable(nameString))); } return false; @@ -116,14 +140,15 @@ public override bool Walk(FromImportStatement node) { foreach (var name in node.Names.Zip(node.AsNames, (name, asName) => asName ?? name)) { var span = name.GetSpan(_ast); - _stack.AddSymbol(new HierarchicalSymbol(name.Name, SymbolKind.Module, span)); + _stack.AddSymbol(new HierarchicalSymbol(name.Name, SymbolKind.Module, span, existInAllVariable: ExistInAllVariable(name.Name))); } return false; } public override bool Walk(AssignmentStatement node) { - node.Right?.Walk(this); + WalkIfNotLibraryMode(node.Right); + foreach (var exp in node.Left) { AddVarSymbolRecursive(exp); } @@ -132,13 +157,15 @@ public override bool Walk(AssignmentStatement node) { } public override bool Walk(NamedExpression node) { - node.Value?.Walk(this); + WalkIfNotLibraryMode(node.Value); + AddVarSymbolRecursive(node.Target); return false; } public override bool Walk(AugmentedAssignStatement node) { - node.Right?.Walk(this); + WalkIfNotLibraryMode(node.Right); + AddVarSymbol(node.Left as NameExpression); return false; } @@ -174,11 +201,19 @@ public override bool Walk(ForStatement node) { } public override bool Walk(ComprehensionFor node) { + if (_libraryMode) { + return false; + } + AddVarSymbolRecursive(node.Left); return base.Walk(node); } public override bool Walk(ListComprehension node) { + if (_libraryMode) { + return false; + } + _stack.Enter(SymbolKind.None); return base.Walk(node); } @@ -186,6 +221,10 @@ public override bool Walk(ListComprehension node) { public override void PostWalk(ListComprehension node) => ExitComprehension(node); public override bool Walk(DictionaryComprehension node) { + if (_libraryMode) { + return false; + } + _stack.Enter(SymbolKind.None); return base.Walk(node); } @@ -193,6 +232,10 @@ public override bool Walk(DictionaryComprehension node) { public override void PostWalk(DictionaryComprehension node) => ExitComprehension(node); public override bool Walk(SetComprehension node) { + if (_libraryMode) { + return false; + } + _stack.Enter(SymbolKind.None); return base.Walk(node); } @@ -200,6 +243,10 @@ public override bool Walk(SetComprehension node) { public override void PostWalk(SetComprehension node) => ExitComprehension(node); public override bool Walk(GeneratorExpression node) { + if (_libraryMode) { + return false; + } + _stack.Enter(SymbolKind.None); return base.Walk(node); } @@ -207,6 +254,10 @@ public override bool Walk(GeneratorExpression node) { public override void PostWalk(GeneratorExpression node) => ExitComprehension(node); private void ExitComprehension(Comprehension node) { + if (_libraryMode) { + return; + } + var children = _stack.Exit(); var span = node.GetSpan(_ast); @@ -237,7 +288,7 @@ private void AddVarSymbol(NameExpression node) { var span = node.GetSpan(_ast); - _stack.AddSymbol(new HierarchicalSymbol(node.Name, kind, span)); + _stack.AddSymbol(new HierarchicalSymbol(node.Name, kind, span, existInAllVariable: ExistInAllVariable(node.Name))); } private void AddVarSymbolRecursive(Expression node) { @@ -281,6 +332,14 @@ private void AddVarSymbolRecursive(Expression node) { return null; } + private void WalkIfNotLibraryMode(Node node) { + if (_libraryMode) { + return; + } + + node?.Walk(this); + } + private bool NameIsProperty(string name) => name == "property" || name == "abstractproperty" @@ -311,6 +370,10 @@ private void WalkAndDeclareAll(IEnumerable nodes) { } } + private bool ExistInAllVariable(string name) { + return _namesInAllVariable.Contains(name); + } + private class SymbolStack { private readonly Stack<(SymbolKind? kind, List symbols)> _symbols; private readonly Stack> _declared = new Stack>(new[] { new HashSet() }); diff --git a/src/LanguageServer/Impl/Indexing/Symbols.cs b/src/LanguageServer/Impl/Indexing/Symbols.cs index 0cdac218d..eecb55725 100644 --- a/src/LanguageServer/Impl/Indexing/Symbols.cs +++ b/src/LanguageServer/Impl/Indexing/Symbols.cs @@ -14,6 +14,7 @@ // permissions and limitations under the License. using System.Collections.Generic; +using System.Diagnostics; using Microsoft.Python.Core.Text; namespace Microsoft.Python.LanguageServer.Indexing { @@ -58,16 +59,18 @@ internal class FunctionKind { } // Analagous to LSP's DocumentSymbol. - internal class HierarchicalSymbol { - public string Name; - public string Detail; - public SymbolKind Kind; - public bool? Deprecated; - public SourceSpan Range; - public SourceSpan SelectionRange; - public IList Children; + [DebuggerDisplay("{Name}, {Kind}")] + internal sealed class HierarchicalSymbol { + public readonly string Name; + public readonly string Detail; + public readonly SymbolKind Kind; + public readonly bool? Deprecated; + public readonly SourceSpan Range; + public readonly SourceSpan SelectionRange; + public readonly IList Children; - public string _functionKind; + public readonly string _functionKind; + public readonly bool _existInAllVariable; public HierarchicalSymbol( string name, @@ -75,7 +78,8 @@ public HierarchicalSymbol( SourceSpan range, SourceSpan? selectionRange = null, IList children = null, - string functionKind = FunctionKind.None + string functionKind = FunctionKind.None, + bool existInAllVariable = false ) { Name = name; Kind = kind; @@ -83,30 +87,36 @@ public HierarchicalSymbol( SelectionRange = selectionRange ?? range; Children = children; _functionKind = functionKind; + _existInAllVariable = existInAllVariable; } } // Analagous to LSP's SymbolInformation. - internal class FlatSymbol { - public string Name; - public SymbolKind Kind; - public bool? Deprecated; - public string DocumentPath; - public SourceSpan Range; - public string ContainerName; + [DebuggerDisplay("{ContainerName}:{Name}, {Kind}")] + internal sealed class FlatSymbol { + public readonly string Name; + public readonly SymbolKind Kind; + public readonly bool? Deprecated; + public readonly string DocumentPath; + public readonly SourceSpan Range; + public readonly string ContainerName; + + public readonly bool _existInAllVariable; public FlatSymbol( string name, SymbolKind kind, string documentPath, SourceSpan range, - string containerName = null + string containerName = null, + bool existInAllVariable = false ) { Name = name; Kind = kind; DocumentPath = documentPath; Range = range; ContainerName = containerName; + _existInAllVariable = existInAllVariable; } } } diff --git a/src/LanguageServer/Impl/LanguageServer.cs b/src/LanguageServer/Impl/LanguageServer.cs index 64bb7f293..da05462b9 100644 --- a/src/LanguageServer/Impl/LanguageServer.cs +++ b/src/LanguageServer/Impl/LanguageServer.cs @@ -215,11 +215,11 @@ public async Task DocumentSymbol(JToken token, CancellationTok } } - //[JsonRpcMethod("textDocument/codeAction")] - //public async Task CodeAction(JToken token, CancellationToken cancellationToken) { - // await _prioritizer.DefaultPriorityAsync(cancellationToken); - // return await _server.CodeAction(ToObject(token), cancellationToken); - //} + [JsonRpcMethod("textDocument/codeAction")] + public async Task CodeAction(JToken token, CancellationToken cancellationToken) { + await _prioritizer.DefaultPriorityAsync(cancellationToken); + return await _server.CodeAction(ToObject(token), cancellationToken); + } //[JsonRpcMethod("textDocument/codeLens")] //public async Task CodeLens(JToken token, CancellationToken cancellationToken) { diff --git a/src/LanguageServer/Impl/Protocol/Classes.cs b/src/LanguageServer/Impl/Protocol/Classes.cs index bdb39e38a..88842f6e6 100644 --- a/src/LanguageServer/Impl/Protocol/Classes.cs +++ b/src/LanguageServer/Impl/Protocol/Classes.cs @@ -363,7 +363,32 @@ public sealed class DefinitionCapabilities { public bool dynamicRegistration; } public DefinitionCapabilities definition; [Serializable] - public sealed class CodeActionCapabilities { public bool dynamicRegistration; } + public sealed class CodeActionCapabilities { + public bool dynamicRegistration; + // + // The client support code action literals as a valid + // response of the `textDocument/codeAction` request. + // + // Since 3.8.0 + // + public class CodeActionLiteralSupport { + // + // The code action kind is support with the following value + // set. + // + public class CodeActionKind { + // + // The code action kind values the client supports. When this + // property exists the client also guarantees that it will + // handle values outside its set gracefully and falls back + // to a default value when unknown. + // + public string[] valueSet; + } + public CodeActionKind codeActionKind; + } + public CodeActionLiteralSupport codeActionLiteralSupport; + } public CodeActionCapabilities codeAction; [Serializable] @@ -412,6 +437,17 @@ public sealed class CodeLensOptions { public bool resolveProvider; } + [Serializable] + public sealed class CodeActionOptions { + // + // CodeActionKinds that this server may return. + // + // The list of kinds may be generic, such as `CodeActionKind.Refactor`, or the server + // may list out every specific kind they provide. + // + public string[] codeActionKinds; + } + [Serializable] public sealed class DocumentOnTypeFormattingOptions { public string firstTriggerCharacter; @@ -456,7 +492,7 @@ public sealed class ServerCapabilities { public bool documentHighlightProvider; public bool documentSymbolProvider; public bool workspaceSymbolProvider; - public bool codeActionProvider; + public CodeActionOptions codeActionProvider; public CodeLensOptions codeLensProvider; public bool documentFormattingProvider; public bool documentRangeFormattingProvider; @@ -741,4 +777,35 @@ public sealed class PublishDiagnosticsParams { [JsonProperty] public Diagnostic[] diagnostics; } + + // + // Summary: + // A class representing a change that can be performed in code. A CodeAction must + // either set "edit" or "command". + // If both are supplied, the edit will be applied first, then the command will be executed. + [Serializable] + public sealed class CodeAction { + // + // Summary: + // Gets or sets the human readable title for this code action. + public string title; + // + // The kind of the code action. + // + // Used to filter code actions. + // + public string kind; + // + // The diagnostics that this code action resolves. + // + public Diagnostic[] diagnostics; + // + // Summary: + // Gets or sets the workspace edit that this code action performs. + public WorkspaceEdit edit; + // + // Summary: + // Gets or sets the command that this code action executes. + public Command command; + } } diff --git a/src/LanguageServer/Impl/Protocol/Diagnostic.cs b/src/LanguageServer/Impl/Protocol/Diagnostic.cs index 68dcac78b..9b1913792 100644 --- a/src/LanguageServer/Impl/Protocol/Diagnostic.cs +++ b/src/LanguageServer/Impl/Protocol/Diagnostic.cs @@ -51,4 +51,76 @@ public enum DiagnosticSeverity : int { Information = 3, Hint = 4 } + + // + // The kind of a code action. + // + // Kinds are a hierarchical list of identifiers separated by `.`, e.g. `"refactor.extract.function"`. + // + // The set of kinds is open and client needs to announce the kinds it supports to the server during + // initialization. + // + // + // A set of predefined code action kinds + // + public static class CodeActionKind { + // + // Empty kind. + // + public const string Empty = ""; + // + // Base kind for quickfix actions: 'quickfix' + // + public const string QuickFix = "quickfix"; + // + // Base kind for refactoring actions: 'refactor' + // + public const string Refactor = "refactor"; + // + // Base kind for refactoring extraction actions: 'refactor.extract' + // + // Example extract actions: + // + // - Extract method + // - Extract function + // - Extract variable + // - Extract interface from class + // - ... + // + public const string RefactorExtract = "refactor.extract"; + // + // Base kind for refactoring inline actions: 'refactor.inline' + // + // Example inline actions: + // + // - Inline function + // - Inline variable + // - Inline constant + // - ... + // + public const string RefactorInline = "refactor.inline"; + // + // Base kind for refactoring rewrite actions: 'refactor.rewrite' + // + // Example rewrite actions: + // + // - Convert JavaScript function to class + // - Add or remove parameter + // - Encapsulate field + // - Make method static + // - Move method to base class + // - ... + // + public const string RefactorRewrite = "refactor.rewrite"; + // + // Base kind for source actions: `source` + // + // Source code actions apply to the entire file. + // + public const string Source = "source"; + // + // Base kind for an organize imports source action: `source.organizeImports` + // + public const string SourceOrganizeImports = "source.organizeImports"; + } } diff --git a/src/LanguageServer/Impl/Protocol/Messages.cs b/src/LanguageServer/Impl/Protocol/Messages.cs index ab253a760..35c0f6f2f 100644 --- a/src/LanguageServer/Impl/Protocol/Messages.cs +++ b/src/LanguageServer/Impl/Protocol/Messages.cs @@ -189,13 +189,17 @@ public struct CodeActionParams { [Serializable] public sealed class CodeActionContext { + // + // An array of diagnostics. + // public Diagnostic[] diagnostics; - - /// - /// The intended version that diagnostic locations apply to. The request may - /// fail if the server cannot map correctly. - /// - public int? _version; + // + // Requested kind of actions to return. + // + // Actions not of this kind are filtered out by the client before being shown. So servers + // can omit computing them. + // + public string[] only; } [Serializable] diff --git a/src/LanguageServer/Impl/Resources.Designer.cs b/src/LanguageServer/Impl/Resources.Designer.cs index b413d0dde..07cdc2c42 100644 --- a/src/LanguageServer/Impl/Resources.Designer.cs +++ b/src/LanguageServer/Impl/Resources.Designer.cs @@ -105,6 +105,15 @@ internal static string Error_InvalidCachePath { } } + /// + /// Looks up a localized string similar to import locally '{0}'. + /// + internal static string ImportLocally { + get { + return ResourceManager.GetString("ImportLocally", resourceCulture); + } + } + /// /// Looks up a localized string similar to Initializing for generic interpreter. /// diff --git a/src/LanguageServer/Impl/Resources.resx b/src/LanguageServer/Impl/Resources.resx index e060c233d..fa819a42d 100644 --- a/src/LanguageServer/Impl/Resources.resx +++ b/src/LanguageServer/Impl/Resources.resx @@ -132,6 +132,9 @@ Specified cache folder does not exist. Switching to default. + + import locally '{0}' + Initializing for generic interpreter diff --git a/src/LanguageServer/Impl/Sources/CodeActionSource.cs b/src/LanguageServer/Impl/Sources/CodeActionSource.cs new file mode 100644 index 000000000..21d787c6e --- /dev/null +++ b/src/LanguageServer/Impl/Sources/CodeActionSource.cs @@ -0,0 +1,78 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + + +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Python.Analysis; +using Microsoft.Python.Analysis.Diagnostics; +using Microsoft.Python.Core; +using Microsoft.Python.Core.Collections; +using Microsoft.Python.LanguageServer.CodeActions; +using Microsoft.Python.LanguageServer.Protocol; +using Microsoft.Python.Parsing.Ast; + +namespace Microsoft.Python.LanguageServer.Sources { + internal sealed partial class CodeActionSource { + private static readonly ImmutableArray _codeActionProviders = + ImmutableArray.Create(MissingImportCodeActionProvider.Instance); + + private readonly IServiceContainer _services; + + public CodeActionSource(IServiceContainer services) { + _services = services; + } + + public async Task GetCodeActionsAsync(IDocumentAnalysis analysis, Diagnostic[] diagnostics, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + + var results = new List(); + + foreach (var diagnostic in GetMatchingDiagnostics(analysis, diagnostics, cancellationToken)) { + foreach (var codeActionProvider in _codeActionProviders) { + if (codeActionProvider.FixableDiagnostics.Any(code => code == diagnostic.ErrorCode)) { + results.AddRange(await codeActionProvider.GetCodeActionsAsync(analysis, diagnostic, cancellationToken)); + } + } + } + + return results.ToArray(); + } + + private IEnumerable GetMatchingDiagnostics(IDocumentAnalysis analysis, Diagnostic[] diagnostics, CancellationToken cancellationToken) { + var diagnosticService = _services.GetService(); + + // we assume diagnostic service has the latest results + if (diagnosticService == null || !diagnosticService.Diagnostics.TryGetValue(analysis.Document.Uri, out var latestDiagnostics)) { + yield break; + } + + foreach (var diagnostic in latestDiagnostics) { + cancellationToken.ThrowIfCancellationRequested(); + + if (diagnostics.Any(d => AreEqual(d, diagnostic))) { + yield return diagnostic; + } + } + + bool AreEqual(Diagnostic diagnostic1, DiagnosticsEntry diagnostic2) { + return diagnostic1.code == diagnostic2.ErrorCode && + diagnostic1.range.ToSourceSpan() == diagnostic2.SourceSpan; + } + } + } +} diff --git a/src/LanguageServer/Impl/Sources/HoverSource.cs b/src/LanguageServer/Impl/Sources/HoverSource.cs index 5aaa2f9af..4f843714e 100644 --- a/src/LanguageServer/Impl/Sources/HoverSource.cs +++ b/src/LanguageServer/Impl/Sources/HoverSource.cs @@ -54,27 +54,27 @@ public Hover GetHover(IDocumentAnalysis analysis, SourceLocation location) { var eval = analysis.ExpressionEvaluator; switch (statement) { case FromImportStatement fi when node is NameExpression nex: { - var contents = HandleFromImport(fi, location, hoverScopeStatement, analysis); - if (contents != null) { - return new Hover { - contents = contents, - range = range - }; - } - - break; + var contents = HandleFromImport(fi, location, hoverScopeStatement, analysis); + if (contents != null) { + return new Hover { + contents = contents, + range = range + }; } - case ImportStatement imp: { - var contents = HandleImport(imp, location, hoverScopeStatement, analysis); - if (contents != null) { - return new Hover { - contents = contents, - range = range - }; - } - break; + break; + } + case ImportStatement imp: { + var contents = HandleImport(imp, location, hoverScopeStatement, analysis); + if (contents != null) { + return new Hover { + contents = contents, + range = range + }; } + + break; + } } IMember value; @@ -96,13 +96,13 @@ public Hover GetHover(IDocumentAnalysis analysis, SourceLocation location) { IVariable variable = null; if (expr is NameExpression nex) { - analysis.ExpressionEvaluator.LookupNameInScopes(nex.Name, out _, out variable, LookupOptions.All); + eval.LookupNameInScopes(nex.Name, out _, out variable, LookupOptions.All); if (IsInvalidClassMember(variable, hoverScopeStatement, location.ToIndex(analysis.Ast))) { return null; } } - value = variable?.Value ?? analysis.ExpressionEvaluator.GetValueFromExpression(expr, LookupOptions.All); + value = variable?.Value ?? eval.GetValueFromExpression(expr, LookupOptions.All); type = value?.GetPythonType(); if (type == null) { return null; @@ -122,7 +122,7 @@ public Hover GetHover(IDocumentAnalysis analysis, SourceLocation location) { // In case of a member expression get the target since if we end up with method // of a generic class, the function will need specific type to determine its return // value correctly. I.e. in x.func() we need to determine type of x (self for func). - var v = analysis.ExpressionEvaluator.GetValueFromExpression(mex.Target); + var v = eval.GetValueFromExpression(mex.Target); self = v?.GetPythonType(); } diff --git a/src/LanguageServer/Impl/Utilities/UniqueNameGenerator.cs b/src/LanguageServer/Impl/Utilities/UniqueNameGenerator.cs new file mode 100644 index 000000000..752c9f569 --- /dev/null +++ b/src/LanguageServer/Impl/Utilities/UniqueNameGenerator.cs @@ -0,0 +1,114 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Python.Analysis; +using Microsoft.Python.Analysis.Analyzer.Expressions; +using Microsoft.Python.Parsing.Ast; +using Microsoft.Python.Parsing.Extensions; + +namespace Microsoft.Python.LanguageServer.Utilities { + /// + /// Generate unique identifier based on given context + /// + internal class UniqueNameGenerator { + private readonly IDocumentAnalysis _analysis; + private readonly ScopeStatement _scope; + private readonly bool _uniqueInModule; + + public static string Generate(IDocumentAnalysis analysis, int position, string name) { + var generator = new UniqueNameGenerator(analysis, position); + return generator.Generate(name); + } + + public static string Generate(IDocumentAnalysis analysis, string name) { + var generator = new UniqueNameGenerator(analysis, position: -1); + return generator.Generate(name); + } + + public UniqueNameGenerator(IDocumentAnalysis analysis, int position) { + _analysis = analysis; + _uniqueInModule = position < 0; + + if (!_uniqueInModule) { + var finder = new ExpressionFinder(analysis.Ast, new FindExpressionOptions() { Names = true }); + finder.Get(position, position, out _, out _, out _scope); + } + } + + public string Generate(string name) { + // for now, there is only 1 new name rule which is just incrementing count at the end. + int count = 0; + Func getNextName = () => { + return $"{name}{++count}"; + }; + + // for now, everything is fixed. and there is no knob to control what to consider when + // creating unique name and how to create new name if there is a conflict + if (_uniqueInModule) { + return GenerateModuleWideUniqueName(name, getNextName); + } else { + return GenerateContextBasedUniqueName(name, getNextName); + } + } + + private string GenerateModuleWideUniqueName(string name, Func getNextName) { + // think of a better way to do this. + var leafScopes = GetLeafScopes(_analysis.Ast.ChildNodesDepthFirst().OfType()); + + while (true) { + if (!leafScopes.Any(s => NameExist(name, s))) { + return name; + } + + name = getNextName(); + } + } + + private HashSet GetLeafScopes(IEnumerable scopes) { + var set = scopes.ToHashSet(); + foreach (var scope in set.ToList()) { + if (scope.Parent != null) { + set.Remove(scope.Parent); + } + } + + return set; + } + + private bool NameExist(string name, ScopeStatement scope) { + var eval = _analysis.ExpressionEvaluator; + using (eval.OpenScope(_analysis.Document, scope)) { + return eval.LookupNameInScopes(name, Analysis.Analyzer.LookupOptions.All) != null; + } + } + + private string GenerateContextBasedUniqueName(string name, Func getNextName) { + var eval = _analysis.ExpressionEvaluator; + using (eval.OpenScope(_analysis.Document, _scope)) { + while (true) { + var member = eval.LookupNameInScopes(name, Analysis.Analyzer.LookupOptions.All); + if (member == null) { + return name; + } + + name = getNextName(); + } + } + } + } +} diff --git a/src/LanguageServer/Test/MissingImportCodeActionTests.cs b/src/LanguageServer/Test/MissingImportCodeActionTests.cs new file mode 100644 index 000000000..afaa77669 --- /dev/null +++ b/src/LanguageServer/Test/MissingImportCodeActionTests.cs @@ -0,0 +1,522 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Python.Analysis; +using Microsoft.Python.Analysis.Analyzer; +using Microsoft.Python.Analysis.Documents; +using Microsoft.Python.Core.Idle; +using Microsoft.Python.Core.IO; +using Microsoft.Python.Core.Services; +using Microsoft.Python.Core.Text; +using Microsoft.Python.LanguageServer.CodeActions; +using Microsoft.Python.LanguageServer.Diagnostics; +using Microsoft.Python.LanguageServer.Indexing; +using Microsoft.Python.LanguageServer.Protocol; +using Microsoft.Python.LanguageServer.Sources; +using Microsoft.Python.LanguageServer.Tests.FluentAssertions; +using Microsoft.Python.Parsing.Ast; +using Microsoft.Python.Parsing.Tests; +using Microsoft.Python.UnitTests.Core; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using TestUtilities; + +namespace Microsoft.Python.LanguageServer.Tests { + [TestClass] + public class MissingImportCodeActionTests : LanguageServerTestBase { + public TestContext TestContext { get; set; } + + [TestInitialize] + public void TestInitialize() + => TestEnvironmentImpl.TestInitialize($"{TestContext.FullyQualifiedTestClassName}.{TestContext.TestName}"); + + [TestCleanup] + public void Cleanup() => TestEnvironmentImpl.TestCleanup(); + + [TestMethod, Priority(0)] + public async Task Missing() { + MarkupUtils.GetSpan(@"[|missingModule|]", out var code, out var span); + + var analysis = await GetAnalysisAsync(code); + var diagnostics = GetDiagnostics(analysis, span.ToSourceSpan(analysis.Ast), MissingImportCodeActionProvider.Instance.FixableDiagnostics); + diagnostics.Should().NotBeEmpty(); + + var codeActions = await new CodeActionSource(analysis.ExpressionEvaluator.Services).GetCodeActionsAsync(analysis, diagnostics, CancellationToken.None); + codeActions.Should().BeEmpty(); + } + + [TestMethod, Priority(0)] + public async Task TopModule() { + const string markup = @"{|insertionSpan:|}{|diagnostic:ntpath|}"; + + var (analysis, codeActions, insertionSpan) = await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + + var codeAction = codeActions.Single(); + var newText = "import ntpath" + Environment.NewLine + Environment.NewLine; + TestCodeAction(analysis.Document.Uri, codeAction, title: "import ntpath", insertionSpan, newText); + } + + [TestMethod, Priority(0), Ignore] + public async Task TopModuleFromFunctionInsertTop() { + const string markup = @"{|insertionSpan:|}def TestMethod(): + {|diagnostic:ntpath|}"; + + var (analysis, codeActions, insertionSpan) = await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + + codeActions.Should().HaveCount(2); + + var codeAction = codeActions.First(); + var newText = "import ntpath" + Environment.NewLine + Environment.NewLine; + TestCodeAction(analysis.Document.Uri, codeAction, title: "import ntpath", insertionSpan, newText); + } + + [TestMethod, Priority(0), Ignore] + public async Task TopModuleLocally() { + const string markup = @"def TestMethod(): +{|insertionSpan:|} {|diagnostic:ntpath|}"; + + var (analysis, codeActions, insertionSpan) = await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + + codeActions.Should().HaveCount(2); + + var codeAction = codeActions[1]; + var newText = " import ntpath" + Environment.NewLine + Environment.NewLine; + TestCodeAction(analysis.Document.Uri, codeAction, title: string.Format(Resources.ImportLocally, "import ntpath"), insertionSpan, newText); + } + + [TestMethod, Priority(0)] + public async Task SubModule() { + await TestCodeActionAsync( + @"{|insertionSpan:|}{|diagnostic:util|}", + title: "from ctypes import util", + newText: "from ctypes import util" + Environment.NewLine + Environment.NewLine); + } + + [TestMethod, Priority(0)] + public async Task SubModuleUpdate() { + await TestCodeActionAsync( + @"{|insertionSpan:from ctypes import util|} +{|diagnostic:test|}", + title: "from ctypes import test, util", + newText: "from ctypes import test, util"); + } + + [TestMethod, Priority(0), Ignore] + public async Task SubModuleUpdateLocally() { + await TestCodeActionAsync( + @"def TestMethod(): + {|insertionSpan:from ctypes import util|} + {|diagnostic:test|}", + title: string.Format(Resources.ImportLocally, "from ctypes import test, util"), + newText: "from ctypes import test, util"); + } + + [TestMethod, Priority(0)] + public async Task SubModuleFromFunctionInsertTop() { + await TestCodeActionAsync( + @"{|insertionSpan:|}def TestMethod(): + from ctypes import util + {|diagnostic:test|}", + title: "from ctypes import test", + newText: "from ctypes import test" + Environment.NewLine + Environment.NewLine); + } + + [TestMethod, Priority(0)] + public async Task AfterExistingImport() { + await TestCodeActionAsync( + @"from os import path +{|insertionSpan:|} +{|diagnostic:util|}", + title: "from ctypes import util", + newText: "from ctypes import util" + Environment.NewLine); + } + + [TestMethod, Priority(0)] + public async Task ReplaceExistingImport() { + await TestCodeActionAsync( + @"from os import path +{|insertionSpan:from ctypes import test|} +import socket + +{|diagnostic:util|}", + title: "from ctypes import test, util", + newText: "from ctypes import test, util"); + } + + [TestMethod, Priority(0), Ignore] + public async Task AfterExistingImportLocally() { + await TestCodeActionAsync( + @"def TestMethod(): + from os import path +{|insertionSpan:|} + {|diagnostic:util|}", + title: string.Format(Resources.ImportLocally, "from ctypes import util"), + newText: " from ctypes import util" + Environment.NewLine); + } + + [TestMethod, Priority(0), Ignore] + public async Task ReplaceExistingImportLocally() { + await TestCodeActionAsync( + @"def TestMethod(): + from os import path + {|insertionSpan:from ctypes import test|} + import socket + + {|diagnostic:util|}", + title: string.Format(Resources.ImportLocally, "from ctypes import test, util"), + newText: "from ctypes import test, util"); + } + + [TestMethod, Priority(0), Ignore] + public async Task CodeActionOrdering() { + MarkupUtils.GetSpan(@"def TestMethod(): + [|test|]", out var code, out var span); + + var analysis = await GetAnalysisAsync(code); + var diagnostics = GetDiagnostics(analysis, span.ToSourceSpan(analysis.Ast), MissingImportCodeActionProvider.Instance.FixableDiagnostics); + diagnostics.Should().NotBeEmpty(); + + var codeActions = await new CodeActionSource(analysis.ExpressionEvaluator.Services).GetCodeActionsAsync(analysis, diagnostics, CancellationToken.None); + + var list = codeActions.Select(c => c.title).ToList(); + var zipList = Enumerable.Range(0, list.Count).Zip(list); + + var locallyImportedPrefix = Resources.ImportLocally.Substring(0, Resources.ImportLocally.IndexOf("'")); + var maxIndexOfTopAddImports = zipList.Where(t => !t.Second.StartsWith(locallyImportedPrefix)).Max(t => t.First); + var minIndexOfLocalAddImports = zipList.Where(t => t.Second.StartsWith(locallyImportedPrefix)).Min(t => t.First); + + maxIndexOfTopAddImports.Should().BeLessThan(minIndexOfLocalAddImports); + } + + [TestMethod, Priority(0)] + public async Task PreserveComment() { + await TestCodeActionAsync( + @"{|insertionSpan:from os import pathconf|} # test + +{|diagnostic:path|}", + title: "from os import path, pathconf", + newText: "from os import path, pathconf"); + } + + [TestMethod, Priority(0)] + public async Task MemberSymbol() { + await TestCodeActionAsync( + @"from os import path +{|insertionSpan:|} +{|diagnostic:socket|}", + title: "from socket import socket", + newText: "from socket import socket" + Environment.NewLine); + } + + [TestMethod, Priority(0)] + public async Task NoMemberSymbol() { + var markup = @"{|insertionSpan:|}{|diagnostic:socket|}"; + + var (analysis, codeActions, insertionSpan) = await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + + codeActions.Select(c => c.title).Should().NotContain("from socket import socket"); + + var title = "import socket"; + var codeAction = codeActions.Single(c => c.title == title); + var newText = "import socket" + Environment.NewLine + Environment.NewLine; + TestCodeAction(analysis.Document.Uri, codeAction, title, insertionSpan, newText); + } + + [TestMethod, Priority(0)] + public async Task SymbolOrdering() { + var markup = @"from os import path +{|insertionSpan:|} +{|diagnostic:socket|}"; + + var (analysis, codeActions, insertionSpan) = await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + + var list = codeActions.Select(c => c.title).ToList(); + var zipList = Enumerable.Range(0, list.Count).Zip(list); + + var maxIndexOfPublicSymbol = zipList.Where(t => !t.Second.StartsWith("from _")).Max(t => t.First); + var minIndexOfPrivateSymbol = zipList.Where(t => t.Second.StartsWith("from _")).Min(t => t.First); + + maxIndexOfPublicSymbol.Should().BeLessThan(minIndexOfPrivateSymbol); + } + + [TestMethod, Priority(0)] + public async Task SymbolOrdering2() { + var markup = @"from os import path +{|insertionSpan:|} +{|diagnostic:join|}"; + + var (analysis, codeActions, insertionSpan) = await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics, enableIndexManager: true); + + var list = codeActions.Select(c => c.title).ToList(); + var zipList = Enumerable.Range(0, list.Count).Zip(list); + + var sourceDeclIndex = zipList.First(t => t.Second == "from posixpath import join").First; + var importedMemberIndex = zipList.First(t => t.Second == "from os.path import join").First; + var restIndex = zipList.First(t => t.Second == "from ntpath import join").First; + + sourceDeclIndex.Should().BeLessThan(importedMemberIndex); + importedMemberIndex.Should().BeLessThan(restIndex); + } + + [TestMethod, Priority(0)] + public async Task SymbolOrdering3() { + var markup = @"{|insertionSpan:|}{|diagnostic:pd|}"; + + MarkupUtils.GetNamedSpans(markup, out var code, out var spans); + + // get main analysis and add mock modules + var analysis = await GetAnalysisAsync(code); + + await GetAnalysisAsync("", analysis.ExpressionEvaluator.Services, modulePath: TestData.GetTestSpecificPath("pandas.py")); + await GetAnalysisAsync("", analysis.ExpressionEvaluator.Services, modulePath: TestData.GetTestSpecificPath("pd.py")); + + // calculate actions + var diagnosticSpan = spans["diagnostic"].First().ToSourceSpan(analysis.Ast); + var diagnostics = GetDiagnostics(analysis, diagnosticSpan, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + var codeActions = await new CodeActionSource(analysis.ExpressionEvaluator.Services).GetCodeActionsAsync(analysis, diagnostics, CancellationToken.None); + + var list = codeActions.Select(c => c.title).ToList(); + var zipList = Enumerable.Range(0, list.Count).Zip(list); + + var pandasIndex = zipList.First(t => t.Second == "import pandas as pd").First; + var pdIndex = zipList.First(t => t.Second == "import pd").First; + + pandasIndex.Should().BeLessThan(pdIndex); + } + + [TestMethod, Priority(0)] + public async Task ModuleNotReachableFromUserDocument() { + await TestCodeActionAsync( + @"{|insertionSpan:|}{|diagnostic:path|}", + title: "from os import path", + newText: "from os import path" + Environment.NewLine + Environment.NewLine, + enableIndexManager: true); + } + + [TestMethod, Priority(0)] + public async Task SuggestAbbreviationForKnownModule() { + await TestCodeActionAsync( + @"{|insertionSpan:|}{|diagnostic:pandas|}", + title: "import pandas as pd", + newText: "import pandas as pd" + Environment.NewLine + Environment.NewLine, + abbreviation: "pd", + relativePaths: "pandas.py"); + } + + [TestMethod, Priority(0)] + public async Task SuggestAbbreviationForKnownModule2() { + await TestCodeActionAsync( + @"{|insertionSpan:|}{|diagnostic:pyplot|}", + title: "from matplotlib import pyplot as plt", + newText: "from matplotlib import pyplot as plt" + Environment.NewLine + Environment.NewLine, + abbreviation: "plt", + relativePaths: @"matplotlib\pyplot.py"); + } + + [TestMethod, Priority(0)] + public async Task SuggestAbbreviationForKnownModule3() { + var markup = @" +{|insertionSpan:from matplotlib import test|} +{|diagnostic:pyplot|}"; + + await TestCodeActionAsync( + markup, + title: "from matplotlib import pyplot as plt, test", + newText: "from matplotlib import pyplot as plt, test", + abbreviation: "plt", + relativePaths: new string[] { @"matplotlib\pyplot.py", @"matplotlib\test.py" }); + } + + [TestMethod, Priority(0)] + public async Task SuggestReverseAbbreviationForKnownModule() { + await TestCodeActionAsync( + @"{|insertionSpan:|}{|diagnostic:pd|}", + title: "import pandas as pd", + newText: "import pandas as pd" + Environment.NewLine + Environment.NewLine, + abbreviation: "pd", + relativePaths: "pandas.py"); + } + + [TestMethod, Priority(0)] + public async Task SuggestReverseAbbreviationForKnownModule2() { + await TestCodeActionAsync( + @"{|insertionSpan:|}{|diagnostic:plt|}", + title: "from matplotlib import pyplot as plt", + newText: "from matplotlib import pyplot as plt" + Environment.NewLine + Environment.NewLine, + abbreviation: "plt", + relativePaths: @"matplotlib\pyplot.py"); + } + + [TestMethod, Priority(0)] + public async Task SuggestReverseAbbreviationForKnownModule3() { + var markup = @" +{|insertionSpan:from matplotlib import test|} +{|diagnostic:plt|}"; + + await TestCodeActionAsync( + markup, + title: "from matplotlib import pyplot as plt, test", + newText: "from matplotlib import pyplot as plt, test", + abbreviation: "plt", + relativePaths: new string[] { @"matplotlib\pyplot.py", @"matplotlib\test.py" }); + } + + [TestMethod, Priority(0)] + public async Task AbbreviationConflict() { + var markup = @"{|insertionSpan:|}pd = 1 + +{|diagnostic:pandas|}"; + + await TestCodeActionAsync( + markup, + title: "import pandas as pd1", + newText: "import pandas as pd1" + Environment.NewLine + Environment.NewLine, + abbreviation: "pd1", + relativePaths: "pandas.py"); + } + + [TestMethod, Priority(0)] + public async Task AbbreviationConflict2() { + var markup = @"{|insertionSpan:|}{|diagnostic:pandas|} + +def Method(): + pd = 1"; + + await TestCodeActionAsync( + markup, + title: "import pandas as pd1", + newText: "import pandas as pd1" + Environment.NewLine + Environment.NewLine, + abbreviation: "pd1", + relativePaths: "pandas.py"); + } + + [TestMethod, Priority(0)] + public async Task ContextBasedSuggestion() { + var markup = + @"from os import path +{|insertionSpan:|} +{|diagnostic:socket|}()"; + + var (analysis, codeActions, insertionSpan) = + await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + + codeActions.Should().NotContain(c => c.title == "import socket"); + + var title = "from socket import socket"; + var newText = "from socket import socket" + Environment.NewLine; + + var codeAction = codeActions.Single(c => c.title == title); + TestCodeAction(analysis.Document.Uri, codeAction, title, insertionSpan, newText); + } + + [TestMethod, Priority(0)] + public async Task ValidToBeUsedInImport() { + await TestCodeActionAsync( + @"from os import path +{|insertionSpan:|} +{|diagnostic:join|}", + title: "from os.path import join", + newText: "from os.path import join" + Environment.NewLine); + } + + private async Task TestCodeActionAsync(string markup, string title, string newText, bool enableIndexManager = false) { + var (analysis, codeActions, insertionSpan) = + await GetAnalysisAndCodeActionsAndSpanAsync(markup, MissingImportCodeActionProvider.Instance.FixableDiagnostics, enableIndexManager); + + var codeAction = codeActions.Single(c => c.title == title); + TestCodeAction(analysis.Document.Uri, codeAction, title, insertionSpan, newText); + } + + private async Task<(IDocumentAnalysis analysis, CodeAction[] diagnostics, SourceSpan insertionSpan)> GetAnalysisAndCodeActionsAndSpanAsync( + string markup, IEnumerable codes, bool enableIndexManager = false) { + MarkupUtils.GetNamedSpans(markup, out var code, out var spans); + + var analysis = await GetAnalysisAsync(code); + + if (enableIndexManager) { + var serviceManager = (IServiceManager)analysis.ExpressionEvaluator.Services; + var indexManager = new IndexManager( + serviceManager.GetService(), + analysis.Document.Interpreter.LanguageVersion, + rootPath: null, + Array.Empty(), + Array.Empty(), + serviceManager.GetService()); + + // make sure index is done + await indexManager.IndexWorkspace(analysis.Document.Interpreter.ModuleResolution.CurrentPathResolver); + + serviceManager.AddService(indexManager); + } + + var insertionSpan = spans["insertionSpan"].First().ToSourceSpan(analysis.Ast); + + var diagnostics = GetDiagnostics(analysis, spans["diagnostic"].First().ToSourceSpan(analysis.Ast), codes); + var codeActions = await new CodeActionSource(analysis.ExpressionEvaluator.Services).GetCodeActionsAsync(analysis, diagnostics, CancellationToken.None); + return (analysis, codeActions.ToArray(), insertionSpan); + } + + private static void TestCodeAction(Uri uri, CodeAction codeAction, string title, Core.Text.Range insertedSpan, string newText) { + codeAction.title.Should().Be(title); + codeAction.edit.changes.Should().HaveCount(1); + + var edit = codeAction.edit.changes[uri]; + edit.Single().range.Should().Be(insertedSpan); + edit.Single().newText.Should().Be(newText); + } + + private static Diagnostic[] GetDiagnostics(IDocumentAnalysis analysis, SourceSpan span, IEnumerable codes) { + var analyzer = analysis.ExpressionEvaluator.Services.GetService(); + return analyzer.LintModule(analysis.Document) + .Where(d => d.SourceSpan == span && codes.Any(e => string.Equals(e, d.ErrorCode))) + .Select(d => d.ToDiagnostic()) + .ToArray(); + } + + private async Task TestCodeActionAsync(string markup, string title, string newText, string abbreviation, params string[] relativePaths) { + MarkupUtils.GetNamedSpans(markup, out var code, out var spans); + + // get main analysis and add mock modules + var analysis = await GetAnalysisAsync(code); + + foreach (var relativePath in relativePaths) { + await GetAnalysisAsync("", analysis.ExpressionEvaluator.Services, modulePath: TestData.GetTestSpecificPath(relativePath)); + } + + // calculate actions + var diagnosticSpan = spans["diagnostic"].First().ToSourceSpan(analysis.Ast); + var diagnostics = GetDiagnostics(analysis, diagnosticSpan, MissingImportCodeActionProvider.Instance.FixableDiagnostics); + var codeActions = await new CodeActionSource(analysis.ExpressionEvaluator.Services).GetCodeActionsAsync(analysis, diagnostics, CancellationToken.None); + + // verify results + var codeAction = codeActions.Single(c => c.title == title); + codeAction.edit.changes.Should().HaveCount(1); + + var edits = codeAction.edit.changes[analysis.Document.Uri]; + edits.Should().HaveCount(2); + + var invocationEdit = edits.Single(e => e.newText == abbreviation); + invocationEdit.range.Should().Be(diagnosticSpan); + + var insertEdit = edits.Single(e => e.newText == newText); + insertEdit.range.Should().Be(spans["insertionSpan"].First().ToSourceSpan(analysis.Ast)); + } + } +} diff --git a/src/LanguageServer/Test/UniqueNameGeneratorTests.cs b/src/LanguageServer/Test/UniqueNameGeneratorTests.cs new file mode 100644 index 000000000..4233ae81f --- /dev/null +++ b/src/LanguageServer/Test/UniqueNameGeneratorTests.cs @@ -0,0 +1,151 @@ +// Copyright(c) Microsoft Corporation +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the License); you may not use +// this file except in compliance with the License. You may obtain a copy of the +// License at http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS +// OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY +// IMPLIED WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// +// See the Apache Version 2.0 License for specific language governing +// permissions and limitations under the License. + +using System.Reflection.PortableExecutable; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Python.Analysis; +using Microsoft.Python.LanguageServer.Tests.FluentAssertions; +using Microsoft.Python.LanguageServer.Utilities; +using Microsoft.Python.Parsing.Tests; +using Microsoft.Python.UnitTests.Core; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using TestUtilities; + +namespace Microsoft.Python.LanguageServer.Tests { + [TestClass] + public class UniqueNameGeneratorTests : LanguageServerTestBase { + public TestContext TestContext { get; set; } + + [TestInitialize] + public void TestInitialize() + => TestEnvironmentImpl.TestInitialize($"{TestContext.FullyQualifiedTestClassName}.{TestContext.TestName}"); + + [TestCleanup] + public void Cleanup() => TestEnvironmentImpl.TestCleanup(); + + [TestMethod, Priority(0)] + public async Task NoConflict() { + MarkupUtils.GetPosition(@"$$", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name"); + Test(analysis, "name", "name"); + } + + [TestMethod, Priority(0)] + public async Task Conflict_TopLevel() { + MarkupUtils.GetPosition(@"$$ + +name = 1", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name1"); + Test(analysis, "name", "name1"); + } + + [TestMethod, Priority(0)] + public async Task Conflict_TopLevel2() { + MarkupUtils.GetPosition(@"$$ + +class name: + pass", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name1"); + Test(analysis, "name", "name1"); + } + + [TestMethod, Priority(0)] + public async Task Conflict_Function() { + MarkupUtils.GetPosition(@"def Test(): + $$ + +name = 1", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name1"); + Test(analysis, "name", "name1"); + } + + [TestMethod, Priority(0)] + public async Task Conflict_Function2() { + MarkupUtils.GetPosition(@"def Test(): + name = 1 + $$ + pass", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name1"); + Test(analysis, "name", "name1"); + } + + [TestMethod, Priority(0)] + public async Task Conflict_Function3() { + MarkupUtils.GetPosition(@"def Test(): + name = 1 + +def Test2(): + $$ + pass", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name"); + Test(analysis, "name", "name1"); + } + + [TestMethod, Priority(0)] + public async Task Conflict_TopLevel3() { + MarkupUtils.GetPosition(@"def Test(): + name = 1 + +$$", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name"); + Test(analysis, "name", "name1"); + } + + [TestMethod, Priority(0)] + public async Task MultipleConflicts() { + MarkupUtils.GetPosition(@" +name1 = 1 + +class name3: + name2 = 1 + +def Test(): + name = 1 + + def name4(): + pass + +$$", out var code, out int position); + + var analysis = await GetAnalysisAsync(code); + Test(analysis, position, "name", "name"); + Test(analysis, "name", "name5"); + } + + private static void Test(IDocumentAnalysis analysis, int position, string name, string expected) { + var actual = UniqueNameGenerator.Generate(analysis, position, name); + actual.Should().Be(expected); + } + + private static void Test(IDocumentAnalysis analysis, string name, string expected) { + Test(analysis, position: -1, name, expected); + } + } +} diff --git a/src/Parsing/Impl/Ast/SourceLocationExtensions.cs b/src/Parsing/Impl/Ast/SourceLocationExtensions.cs index 9377fe70f..157c9b127 100644 --- a/src/Parsing/Impl/Ast/SourceLocationExtensions.cs +++ b/src/Parsing/Impl/Ast/SourceLocationExtensions.cs @@ -19,17 +19,34 @@ namespace Microsoft.Python.Parsing.Ast { public static class SourceLocationExtensions { public static int ToIndex(this SourceLocation location, ILocationConverter lc) => lc.LocationToIndex(location); + + public static SourceLocation ToSourceLocation(this Position position, ILocationConverter lc = null) { + var location = new SourceLocation(position.line + 1, position.character + 1); + if (lc == null) { + return location; + } + + return new SourceLocation(lc.LocationToIndex(location), location.Line, location.Column); + } + } + + public static class RangeExtensions { + public static IndexSpan ToIndexSpan(this Range range, ILocationConverter lc) + => IndexSpan.FromBounds(lc.LocationToIndex(range.start), lc.LocationToIndex(range.end)); + public static SourceSpan ToSourceSpan(this Range range, ILocationConverter lc = null) + => new SourceSpan(range.start.ToSourceLocation(lc), range.end.ToSourceLocation(lc)); } public static class SourceSpanExtensions { public static IndexSpan ToIndexSpan(this SourceSpan span, ILocationConverter lc) => IndexSpan.FromBounds(lc.LocationToIndex(span.Start), lc.LocationToIndex(span.End)); - public static IndexSpan ToIndexSpan(this Range range, ILocationConverter lc) - => IndexSpan.FromBounds(lc.LocationToIndex(range.start), lc.LocationToIndex(range.end)); } public static class IndexSpanExtensions { public static SourceSpan ToSourceSpan(this IndexSpan span, ILocationConverter lc) => lc != null ? new SourceSpan(lc.IndexToLocation(span.Start), lc.IndexToLocation(span.End)) : default; + public static bool Contains(this IndexSpan span, IndexSpan other) { + return span.Start <= other.Start && other.End <= span.End; + } } }