Skip to content

Commit d733047

Browse files
authored
Evaluate first comprehension iterator in outer scope (#1130)
* Evaluate first comprehension iterator in outer scope * Add tests * Update rewriting code path * Add dome more tests * Revert "Workaround for #809 (#810)" This reverts commit 14f34c9. * Guard against invalid comprehension iterators
1 parent 8b85383 commit d733047

7 files changed

Lines changed: 173 additions & 15 deletions

File tree

Src/IronPython/Compiler/Ast/Comprehension.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ public sealed class ListComprehension : Comprehension {
7575
private readonly ComprehensionIterator[] _iterators;
7676

7777
public ListComprehension(Expression item, ComprehensionIterator[] iterators) {
78+
if (iterators is null || iterators.Length < 1) throw new ArgumentException("comprehension with no generators");
79+
if (iterators[0] is not ComprehensionFor) throw new ArgumentException("comprehension with invalid generator");
80+
7881
Item = item;
7982
_iterators = iterators;
8083
Scope = new ComprehensionScope(this);
@@ -121,6 +124,9 @@ public sealed class SetComprehension : Comprehension {
121124
private readonly ComprehensionIterator[] _iterators;
122125

123126
public SetComprehension(Expression item, ComprehensionIterator[] iterators) {
127+
if (iterators is null || iterators.Length < 1) throw new ArgumentException("comprehension with no generators");
128+
if (iterators[0] is not ComprehensionFor) throw new ArgumentException("comprehension with invalid generator");
129+
124130
Item = item;
125131
_iterators = iterators;
126132
Scope = new ComprehensionScope(this);
@@ -167,6 +173,9 @@ public sealed class DictionaryComprehension : Comprehension {
167173
private readonly ComprehensionIterator[] _iterators;
168174

169175
public DictionaryComprehension(Expression key, Expression value, ComprehensionIterator[] iterators) {
176+
if (iterators is null || iterators.Length < 1) throw new ArgumentException("comprehension with no generators");
177+
if (iterators[0] is not ComprehensionFor) throw new ArgumentException("comprehension with invalid generator");
178+
170179
Key = key;
171180
Value = value;
172181
_iterators = iterators;

Src/IronPython/Compiler/Ast/PythonAst.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,7 @@ internal PythonAst MakeLookupCode() {
720720

721721
internal class LookupVisitor : MSAst.ExpressionVisitor {
722722
private readonly MSAst.Expression _globalContext;
723+
private readonly Dictionary<MSAst.Expression, ScopeStatement> _outerComprehensionScopes = new();
723724
private ScopeStatement _curScope;
724725

725726
public LookupVisitor(PythonAst ast, MSAst.Expression globalContext) {
@@ -739,6 +740,12 @@ protected override MSAst.Expression VisitExtension(MSAst.Expression node) {
739740
return PythonAst._globalContext;
740741
}
741742

743+
// outer comprehension iterable is visited in outer comprehension scope
744+
if (_outerComprehensionScopes.TryGetValue(node, out ScopeStatement outerComprehensionScope)) {
745+
_outerComprehensionScopes.Remove(node);
746+
return VisitExtensionInScope(node, outerComprehensionScope);
747+
}
748+
742749
// we need to re-write nested scopes
743750
if (node is ScopeStatement scope) {
744751
return base.VisitExtension(VisitScope(scope));
@@ -815,14 +822,27 @@ private MSAst.Expression VisitComprehension(Comprehension comprehension) {
815822

816823
ScopeStatement prevScope = _curScope;
817824
try {
818-
// rewrite the comprehension in a new scope
825+
// mark the first (outermost) "for" iterator for rewrite in the current scope
826+
_outerComprehensionScopes[((ComprehensionFor)comprehension.Iterators[0]).List] = _curScope;
827+
828+
// rewrite the rest of comprehension in the new scope
819829
_curScope = newScope;
820830

821831
return base.VisitExtension(newComprehension);
822832
} finally {
823833
_curScope = prevScope;
824834
}
825835
}
836+
837+
private MSAst.Expression VisitExtensionInScope(MSAst.Expression node, ScopeStatement scope) {
838+
ScopeStatement prevScope = _curScope;
839+
try {
840+
_curScope = scope;
841+
return VisitExtension(node); // not base.VisitExtension
842+
} finally {
843+
_curScope = prevScope;
844+
}
845+
}
826846
}
827847

828848
#endregion

Src/IronPython/Compiler/Ast/PythonNameBinder.cs

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ public override bool Walk(ClassDefinition node) {
216216
dec.Walk(this);
217217
}
218218
}
219-
219+
220220
PushScope(node);
221221

222222
node.ModuleNameVariable = _globalScope.EnsureGlobalVariable("__name__");
@@ -248,10 +248,30 @@ public override bool Walk(DelStatement node) {
248248
return true;
249249
}
250250

251-
public override bool Walk(ListComprehension node) {
251+
// Comprehensions
252+
253+
private void WalkComprehensionIterators(Comprehension node) {
252254
node.Parent = _currentScope;
255+
256+
// Special walk case: first (outermost) "for" iterator
257+
// See also: PythonAst.LookupVisitor.VisitComprehension(...)
258+
var outermostFor = (ComprehensionFor)node.Iterators[0];
259+
outermostFor.List.Walk(this);
253260
PushScope(node.Scope);
254-
return base.Walk(node);
261+
Walk(outermostFor);
262+
outermostFor.Left.Walk(this);
263+
PostWalk(outermostFor);
264+
265+
// Regular walk cases: remaining iterators/conditionals
266+
for (int i = 1; i < node.Iterators.Count; i++) {
267+
node.Iterators[i].Walk(this);
268+
}
269+
}
270+
271+
public override bool Walk(ListComprehension node) {
272+
WalkComprehensionIterators(node);
273+
node.Item.Walk(this);
274+
return false;
255275
}
256276

257277
public override void PostWalk(ListComprehension node) {
@@ -264,9 +284,9 @@ public override void PostWalk(ListComprehension node) {
264284
}
265285

266286
public override bool Walk(SetComprehension node) {
267-
node.Parent = _currentScope;
268-
PushScope(node.Scope);
269-
return base.Walk(node);
287+
WalkComprehensionIterators(node);
288+
node.Item.Walk(this);
289+
return false;
270290
}
271291

272292
public override void PostWalk(SetComprehension node) {
@@ -279,9 +299,10 @@ public override void PostWalk(SetComprehension node) {
279299
}
280300

281301
public override bool Walk(DictionaryComprehension node) {
282-
node.Parent = _currentScope;
283-
PushScope(node.Scope);
284-
return base.Walk(node);
302+
WalkComprehensionIterators(node);
303+
node.Key.Walk(this);
304+
node.Value.Walk(this);
305+
return false;
285306
}
286307

287308
public override void PostWalk(DictionaryComprehension node) {

Src/StdLib/Lib/pickle.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,8 @@ def __init__(self, value):
177177
MEMOIZE = b'\x94' # store top of the stack in memo
178178
FRAME = b'\x95' # indicate the beginning of a new frame
179179

180-
# _dir is a workaround for https://github.com/IronLanguages/ironpython3/issues/809
181-
_dir = dir()
182-
__all__.extend([x for x in _dir if re.match("[A-Z][A-Z0-9_]+$", x)])
183-
del _dir
180+
__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)])
181+
184182

185183
class _Framer:
186184

Tests/test_dictcomp.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,27 @@ def eval(f, i):
9999
r = {k:eval(lambda i: i+v, k+v) for k in range(2)}
100100
self.assertEqual(r, {0:4, 1:5})
101101

102+
def test_ipy3_gh809(self):
103+
"""https://github.com/IronLanguages/ironpython3/issues/809"""
104+
105+
# iterable is evaluated in the outer scope
106+
self.assertIn('self', {x : None for x in dir()})
107+
108+
# this rule applies recursively to nested comprehensions
109+
self.assertIn('self', {x : None for x in {y : None for y in dir()}})
110+
self.assertIn('self', {x : None for x in {y : None for y in {z : None for z in dir()}}})
111+
112+
# this only applies to the first iterable
113+
# subsequent iterables are evaluated within the comprehension scope
114+
self.assertEqual({(0, 'x') : None}, {(x, y) : None for x in range(1) for y in dir() if not y.startswith('.')}) # (filtering out auxiliary variable staring with a dot, used by CPython)
115+
116+
# also subsequent conditions are evaluated within the comprehension scope
117+
a, b, c, d = range(4)
118+
self.assertTrue(len(dir()) >= 4)
119+
self.assertEqual({}, {None : dir() for x in range(1) if len(dir()) >= 4})
120+
121+
# lambdas create a new scope
122+
self.assertEqual({}, {x : None for x in (lambda: dir())()})
123+
self.assertEqual({None: []}, {None: x() for x in {lambda: dir() : None}})
124+
102125
run_test(__name__)

Tests/test_listcomp.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import unittest
66

7-
from iptest import run_test
7+
from iptest import run_test, is_cli
88

99
class ListCompTest(unittest.TestCase):
1010
def test_positive(self):
@@ -41,4 +41,69 @@ def test_negative(self):
4141
self.assertRaises(NameError, lambda: [(x, z) for x in "iron" if z > x for z in "python" ])
4242
self.assertRaises(NameError, lambda: [(i, j) for i in range(10) if j < 'c' for j in ['a', 'b', 'c'] if i % 3 == 0])
4343

44+
def test_ipy3_gh809(self):
45+
"""https://github.com/IronLanguages/ironpython3/issues/809"""
46+
47+
# iterable is evaluated in the outer scope
48+
self.assertIn('self', [x for x in dir()])
49+
50+
# this rule applies recursively to nested comprehensions
51+
self.assertIn('self', [x for x in [y for y in dir()]])
52+
self.assertIn('self', [x for x in [y for y in [z for z in dir()]]])
53+
54+
# this only applies to the first iterable
55+
# subsequent iterables are evaluated within the comprehension scope
56+
self.assertEqual([(0, 'x')], [(x, y) for x in range(1) for y in dir() if not y.startswith('.')]) # (filtering out auxiliary variable staring with a dot, used by CPython)
57+
58+
# also subsequent conditions are evaluated within the comprehension scope
59+
a, b, c, d = range(4)
60+
self.assertTrue(len(dir()) >= 4)
61+
self.assertEqual([], [dir() for x in range(1) if len(dir()) >= 4])
62+
63+
# subsequent iterables introduce local variables after the first iteration
64+
self.assertEqual([(0, 'x'), (1, 'x'), (1, 'y'), (2, 'x'), (2, 'y')],
65+
[(x, y) for x in range(3)
66+
for y in dir() if not y.startswith('.')])
67+
self.assertEqual([(0, 'x', 'x'), (0, 'x', 'y'),
68+
(1, 'x', 'x'), (1, 'x', 'y'), (1, 'x', 'z'),
69+
(1, 'y', 'x'), (1, 'y', 'y'), (1, 'y', 'z'),
70+
(1, 'z', 'x'), (1, 'z', 'y'), (1, 'z', 'z')],
71+
[(x, y, z) for x in range(2)
72+
for y in dir() if not y.startswith('.')
73+
for z in dir() if not z.startswith('.')])
74+
75+
# lambdas create a new scope
76+
self.assertEqual([], [x for x in (lambda: dir())()])
77+
self.assertEqual([[]], [x() for x in [lambda: dir()]])
78+
79+
# first iterable is captured and subsequent assignments do not change it
80+
self.maxDiff = None
81+
x, y, z = range(1), range(2), range(3)
82+
if is_cli:
83+
_dir = ['x', 'y']
84+
# TODO: should be: _dir = ['x', 'y', 'z']
85+
# See: https://github.com/IronLanguages/ironpython3/issues/1132
86+
else:
87+
_dir = ['.0', 'x', 'y', 'z'] # adds implementation-level variable '.0'
88+
# below, x and first/third y is local, second y and z is from the outer scope
89+
self.assertEqual([(x, y, z, dir()) for x in y for y in z],
90+
[(0, 0, range(3), _dir),
91+
(0, 1, range(3), _dir),
92+
(0, 2, range(3), _dir),
93+
(1, 0, range(3), _dir),
94+
(1, 1, range(3), _dir),
95+
(1, 2, range(3), _dir)])
96+
97+
# mixing scopes
98+
def apply(f, i): return f(i)
99+
x = 2
100+
res = [x for x in apply(lambda i: range(i+x), x)]
101+
self.assertEqual(res, [0, 1, 2, 3])
102+
103+
res = [(x, y) for x in apply(lambda i: range(i+x), x) for y in apply(lambda i: range(i+x//2), x)]
104+
self.assertEqual(res, [(1, 0), (2, 0), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2), (3, 3)])
105+
106+
res = [x for x in [y for y in apply(lambda i: range(i+x), x)]]
107+
self.assertEqual(res, [0, 1, 2, 3])
108+
44109
run_test(__name__)

Tests/test_setcomp.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,26 @@ def eval(f, i):
9999
r = {eval(lambda i: i+v, k+v) for k in range(2)}
100100
self.assertEqual(r, set([4, 5]))
101101

102+
def test_ipy3_gh809(self):
103+
"""https://github.com/IronLanguages/ironpython3/issues/809"""
104+
105+
# iterable is evaluated in the outer scope
106+
self.assertIn('self', {x for x in dir()})
107+
108+
# this rule applies recursively to nested comprehensions
109+
self.assertIn('self', {x for x in {y for y in dir()}})
110+
self.assertIn('self', {x for x in {y for y in {z for z in dir()}}})
111+
112+
# this only applies to the first iterable
113+
# subsequent iterables are evaluated within the comprehension scope
114+
self.assertEqual({(0, 'x')}, {(x, y) for x in range(1) for y in dir() if not y.startswith('.')}) # (filtering out auxiliary variable staring with a dot, used by CPython)
115+
116+
# also subsequent conditions are evaluated within the comprehension scope
117+
a, b, c, d = range(4)
118+
self.assertTrue(len(dir()) >= 4)
119+
self.assertEqual(set(), {dir() for x in range(1) if len(dir()) >= 4})
120+
121+
# lambdas create a new scope
122+
self.assertEqual(set(), {x for x in (lambda: dir())()})
123+
102124
run_test(__name__)

0 commit comments

Comments
 (0)