Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 237a18c

Browse files
authored
__all__ evaluation (microsoft#863)
Fixes microsoft#620. Fixes microsoft#619. This adds support for concatenating lists using `+`, doing `+=` on `__all__`, and calling `append` and `extend` on `__all__`. If something goes wrong (an unsupported operation on `__all__` or some unsupported value), then the old behavior continues to be used. I don't track uses of `__all__` indirectly (i.e. passing `__all__` to something that modifies it), only direct actions. If `__all__` is in a more complicated lvar (like `__all__, foo = ...`), then it is ignored. This can be improved later on when we fix up our multiple assignment issues. This works well for Django models (see microsoft#620), but `numpy`'s import cycles prevent this from having an effect, so the old behavior will be used. ~Tests are WIP.~ I'll need to rebase/merge master when the refs per gets merged.
1 parent 838c778 commit 237a18c

File tree

4 files changed

+290
-1
lines changed

4 files changed

+290
-1
lines changed

src/Analysis/Ast/Impl/Analyzer/Evaluation/ExpressionEval.Operators.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
// See the Apache Version 2.0 License for specific language governing
1414
// permissions and limitations under the License.
1515

16+
using System;
17+
using System.Collections.Generic;
1618
using Microsoft.Python.Analysis.Modules;
1719
using Microsoft.Python.Analysis.Types;
20+
using Microsoft.Python.Analysis.Types.Collections;
1821
using Microsoft.Python.Analysis.Values;
1922
using Microsoft.Python.Parsing;
2023
using Microsoft.Python.Parsing.Ast;
@@ -123,6 +126,16 @@ private IMember GetValueFromBinaryOp(Expression expr) {
123126
return left;
124127
}
125128

129+
if (binop.Operator == PythonOperator.Add
130+
&& left.GetPythonType()?.TypeId == BuiltinTypeId.List
131+
&& right.GetPythonType()?.TypeId == BuiltinTypeId.List) {
132+
133+
var leftVar = GetValueFromExpression(binop.Left) as IPythonCollection;
134+
var rightVar = GetValueFromExpression(binop.Right) as IPythonCollection;
135+
136+
return PythonCollectionType.CreateConcatenatedList(Module.Interpreter, GetLoc(expr), leftVar?.Contents, rightVar?.Contents);
137+
}
138+
126139
return left.IsUnknown() ? right : left;
127140
}
128141
}

src/Analysis/Ast/Impl/Analyzer/ModuleWalker.cs

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
// See the Apache Version 2.0 License for specific language governing
1414
// permissions and limitations under the License.
1515

16+
using System;
17+
using System.Collections.Generic;
1618
using System.Diagnostics;
1719
using System.Linq;
1820
using Microsoft.Python.Analysis.Analyzer.Evaluation;
1921
using Microsoft.Python.Analysis.Documents;
2022
using Microsoft.Python.Analysis.Modules;
2123
using Microsoft.Python.Analysis.Types;
24+
using Microsoft.Python.Analysis.Types.Collections;
2225
using Microsoft.Python.Analysis.Values;
2326
using Microsoft.Python.Core;
2427
using Microsoft.Python.Core.Collections;
@@ -33,6 +36,7 @@ internal class ModuleWalker : AnalysisWalker {
3336

3437
// A hack to use __all__ export in the most simple case.
3538
private int _allReferencesCount;
39+
private bool _allIsUsable = true;
3640

3741
public ModuleWalker(IServiceContainer services, IPythonModule module, PythonAst ast)
3842
: base(new ExpressionEval(services, module, ast)) {
@@ -47,6 +51,103 @@ public override bool Walk(NameExpression node) {
4751
return base.Walk(node);
4852
}
4953

54+
public override bool Walk(AugmentedAssignStatement node) {
55+
HandleAugmentedAllAssign(node);
56+
return base.Walk(node);
57+
}
58+
59+
public override bool Walk(CallExpression node) {
60+
HandleAllAppendExtend(node);
61+
return base.Walk(node);
62+
}
63+
64+
private void HandleAugmentedAllAssign(AugmentedAssignStatement node) {
65+
if (!IsHandleableAll(node.Left)) {
66+
return;
67+
}
68+
69+
if (node.Right is ErrorExpression) {
70+
return;
71+
}
72+
73+
if (node.Operator != Parsing.PythonOperator.Add) {
74+
_allIsUsable = false;
75+
return;
76+
}
77+
78+
var rightVar = Eval.GetValueFromExpression(node.Right);
79+
var rightContents = (rightVar as IPythonCollection)?.Contents;
80+
81+
if (rightContents == null) {
82+
_allIsUsable = false;
83+
return;
84+
}
85+
86+
ExtendAll(node.Left, rightContents);
87+
}
88+
89+
private void HandleAllAppendExtend(CallExpression node) {
90+
if (!(node.Target is MemberExpression me)) {
91+
return;
92+
}
93+
94+
if (!IsHandleableAll(me.Target)) {
95+
return;
96+
}
97+
98+
if (node.Args.Count == 0) {
99+
return;
100+
}
101+
102+
IReadOnlyList<IMember> contents = null;
103+
var v = Eval.GetValueFromExpression(node.Args[0].Expression);
104+
if (v == null) {
105+
_allIsUsable = false;
106+
return;
107+
}
108+
109+
switch (me.Name) {
110+
case "append":
111+
contents = new List<IMember>() { v };
112+
break;
113+
case "extend":
114+
contents = (v as IPythonCollection)?.Contents;
115+
break;
116+
}
117+
118+
if (contents == null) {
119+
_allIsUsable = false;
120+
return;
121+
}
122+
123+
ExtendAll(node, contents);
124+
}
125+
126+
private void ExtendAll(Node declNode, IReadOnlyList<IMember> values) {
127+
Eval.LookupNameInScopes(AllVariableName, out var scope, LookupOptions.Normal);
128+
if (scope == null) {
129+
return;
130+
}
131+
132+
var loc = Eval.GetLoc(declNode);
133+
134+
var allContents = (scope.Variables[AllVariableName].Value as IPythonCollection)?.Contents;
135+
136+
var list = PythonCollectionType.CreateConcatenatedList(Module.Interpreter, loc, allContents, values);
137+
var source = list.IsGeneric() ? VariableSource.Generic : VariableSource.Declaration;
138+
139+
Eval.DeclareVariable(AllVariableName, list, source, loc);
140+
}
141+
142+
private bool IsHandleableAll(Node node) {
143+
// TODO: handle more complicated lvars
144+
if (!(node is NameExpression ne)) {
145+
return false;
146+
}
147+
148+
return Eval.CurrentScope == Eval.GlobalScope && ne.Name == AllVariableName;
149+
}
150+
50151
public override bool Walk(PythonAst node) {
51152
Check.InvalidOperation(() => Ast == node, "walking wrong AST");
52153

@@ -98,7 +199,7 @@ public void Complete() {
98199
SymbolTable.ReplacedByStubs.Clear();
99200
MergeStub();
100201

101-
if (_allReferencesCount == 1 && GlobalScope.Variables.TryGetVariable(AllVariableName, out var variable) && variable?.Value is IPythonCollection collection) {
202+
if (_allIsUsable && _allReferencesCount >= 1 && GlobalScope.Variables.TryGetVariable(AllVariableName, out var variable) && variable?.Value is IPythonCollection collection) {
102203
ExportedMemberNames = collection.Contents
103204
.OfType<IPythonConstant>()
104205
.Select(c => c.GetString())

src/Analysis/Ast/Impl/Types/Collections/PythonCollectionType.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using System.Linq;
1919
using Microsoft.Python.Analysis.Values;
2020
using Microsoft.Python.Analysis.Values.Collections;
21+
using Microsoft.Python.Core;
2122

2223
namespace Microsoft.Python.Analysis.Types.Collections {
2324
/// <summary>
@@ -99,6 +100,11 @@ public static IPythonCollection CreateList(IPythonInterpreter interpreter, Locat
99100
return new PythonCollection(collectionType, location, contents, flatten);
100101
}
101102

103+
public static IPythonCollection CreateConcatenatedList(IPythonInterpreter interpreter, LocationInfo location, params IReadOnlyList<IMember>[] manyContents) {
104+
var contents = manyContents?.ExcludeDefault().SelectMany().ToList() ?? new List<IMember>();
105+
return CreateList(interpreter, location, contents);
106+
}
107+
102108
public static IPythonCollection CreateTuple(IPythonInterpreter interpreter, LocationInfo location, IReadOnlyList<IMember> contents) {
103109
var collectionType = new PythonCollectionType(null, BuiltinTypeId.Tuple, interpreter, false);
104110
return new PythonCollection(collectionType, location, contents);

src/LanguageServer/Test/ImportsTests.cs

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,5 +459,174 @@ public async Task FromImport_ModuleAffectsPackage(string appCodeImport) {
459459
comps = cs.GetCompletions(analysis, new SourceLocation(2, 21));
460460
comps.Should().HaveLabels("X");
461461
}
462+
463+
[TestMethod, Priority(0)]
464+
public async Task AllSimple() {
465+
var module1Code = @"
466+
class A:
467+
def foo(self):
468+
pass
469+
pass
470+
471+
class B:
472+
def bar(self):
473+
pass
474+
pass
475+
476+
__all__ = ['A']
477+
";
478+
479+
var appCode = @"
480+
from module1 import *
481+
482+
A().
483+
B().
484+
";
485+
486+
var module1Uri = TestData.GetTestSpecificUri("module1.py");
487+
var appUri = TestData.GetTestSpecificUri("app.py");
488+
489+
var root = Path.GetDirectoryName(appUri.AbsolutePath);
490+
await CreateServicesAsync(root, PythonVersions.LatestAvailable3X);
491+
var rdt = Services.GetService<IRunningDocumentTable>();
492+
var analyzer = Services.GetService<IPythonAnalyzer>();
493+
494+
rdt.OpenDocument(module1Uri, module1Code);
495+
496+
var app = rdt.OpenDocument(appUri, appCode);
497+
await analyzer.WaitForCompleteAnalysisAsync();
498+
var analysis = await app.GetAnalysisAsync(-1);
499+
500+
var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion);
501+
var comps = cs.GetCompletions(analysis, new SourceLocation(4, 5));
502+
comps.Should().HaveLabels("foo");
503+
504+
comps = cs.GetCompletions(analysis, new SourceLocation(5, 5));
505+
comps.Should().NotContainLabels("bar");
506+
}
507+
508+
[DataRow(@"
509+
other = ['B']
510+
__all__ = ['A'] + other")]
511+
[DataRow(@"
512+
other = ['B']
513+
__all__ = ['A']
514+
__all__ += other")]
515+
[DataRow(@"
516+
other = ['B']
517+
__all__ = ['A']
518+
__all__.extend(other)")]
519+
[DataRow(@"
520+
__all__ = ['A']
521+
__all__.append('B')")]
522+
[DataTestMethod, Priority(0)]
523+
public async Task AllComplex(string allCode) {
524+
var module1Code = @"
525+
class A:
526+
def foo(self):
527+
pass
528+
pass
529+
530+
class B:
531+
def bar(self):
532+
pass
533+
pass
534+
535+
class C:
536+
def baz(self):
537+
pass
538+
pass
539+
" + allCode;
540+
541+
var appCode = @"
542+
from module1 import *
543+
544+
A().
545+
B().
546+
C().
547+
";
548+
549+
var module1Uri = TestData.GetTestSpecificUri("module1.py");
550+
var appUri = TestData.GetTestSpecificUri("app.py");
551+
552+
var root = Path.GetDirectoryName(appUri.AbsolutePath);
553+
await CreateServicesAsync(root, PythonVersions.LatestAvailable3X);
554+
var rdt = Services.GetService<IRunningDocumentTable>();
555+
var analyzer = Services.GetService<IPythonAnalyzer>();
556+
557+
rdt.OpenDocument(module1Uri, module1Code);
558+
559+
var app = rdt.OpenDocument(appUri, appCode);
560+
await analyzer.WaitForCompleteAnalysisAsync();
561+
var analysis = await app.GetAnalysisAsync(-1);
562+
563+
var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion);
564+
var comps = cs.GetCompletions(analysis, new SourceLocation(4, 5));
565+
comps.Should().HaveLabels("foo");
566+
567+
comps = cs.GetCompletions(analysis, new SourceLocation(5, 5));
568+
comps.Should().HaveLabels("bar");
569+
570+
comps = cs.GetCompletions(analysis, new SourceLocation(6, 5));
571+
comps.Should().NotContainLabels("baz");
572+
}
573+
574+
[DataRow(@"
575+
__all__ = ['A']
576+
__all__.something(A)")]
577+
[DataRow(@"
578+
__all__ = ['A']
579+
__all__ *= ['B']")]
580+
[DataRow(@"
581+
__all__ = ['A']
582+
__all__ += 1234")]
583+
[DataRow(@"
584+
__all__ = ['A']
585+
__all__.extend(123)")]
586+
[DataRow(@"
587+
__all__ = ['A']
588+
__all__.extend(nothing)")]
589+
[DataTestMethod, Priority(0)]
590+
public async Task AllUnsupported(string allCode) {
591+
var module1Code = @"
592+
class A:
593+
def foo(self):
594+
pass
595+
pass
596+
597+
class B:
598+
def bar(self):
599+
pass
600+
pass
601+
" + allCode;
602+
603+
var appCode = @"
604+
from module1 import *
605+
606+
A().
607+
B().
608+
";
609+
610+
var module1Uri = TestData.GetTestSpecificUri("module1.py");
611+
var appUri = TestData.GetTestSpecificUri("app.py");
612+
613+
var root = Path.GetDirectoryName(appUri.AbsolutePath);
614+
await CreateServicesAsync(root, PythonVersions.LatestAvailable3X);
615+
var rdt = Services.GetService<IRunningDocumentTable>();
616+
var analyzer = Services.GetService<IPythonAnalyzer>();
617+
618+
rdt.OpenDocument(module1Uri, module1Code);
619+
620+
var app = rdt.OpenDocument(appUri, appCode);
621+
await analyzer.WaitForCompleteAnalysisAsync();
622+
var analysis = await app.GetAnalysisAsync(-1);
623+
624+
var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion);
625+
var comps = cs.GetCompletions(analysis, new SourceLocation(4, 5));
626+
comps.Should().HaveLabels("foo");
627+
628+
comps = cs.GetCompletions(analysis, new SourceLocation(5, 5));
629+
comps.Should().HaveLabels("bar");
630+
}
462631
}
463632
}

0 commit comments

Comments
 (0)