Skip to content

Commit e1cc110

Browse files
authored
Avoid naming collision during synthesis (#543)
* Adding rename transformer * Also rename local symbol references * More linting
1 parent 7b360e6 commit e1cc110

2 files changed

Lines changed: 214 additions & 2 deletions

File tree

effectful/handlers/llm/evaluation.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import ast
22
import builtins
33
import collections.abc
4+
import copy
45
import inspect
6+
import keyword
57
import linecache
8+
import random
9+
import string
610
import sys
711
import types
812
import typing
@@ -493,6 +497,54 @@ class definitions with proper inheritance, typed attributes, and method stubs.
493497
return nodes
494498

495499

500+
def _generate_unique_name(existing_names: set[str]) -> str:
    """Return a fresh ``_synth_``-prefixed identifier absent from *existing_names*.

    Each attempt appends eight random characters, drawn from the lowercase
    ASCII letters and the digits, to the ``_synth_`` prefix (for example
    ``_synth_a3f7b2k9``).  Attempts repeat until the candidate is not in
    *existing_names*, is a valid identifier, and is not a Python keyword.

    The set passed in is only read; callers that want to reserve the
    returned name must add it to their own set.
    """
    alphabet = string.ascii_lowercase + string.digits
    while True:
        candidate = "_synth_" + "".join(random.choices(alphabet, k=8))
        # Guard clauses, checked in the same order as the conditions of a
        # short-circuiting ``and`` chain: membership first, then validity.
        if candidate in existing_names:
            continue
        if not candidate.isidentifier() or keyword.iskeyword(candidate):
            continue
        return candidate
515+
516+
517+
class _RenameTransformer(ast.NodeTransformer):
    """Rename function definitions and their references in an AST, in place.

    Given a mapping ``{old_name: new_name}``, rewrites:
    - ``FunctionDef.name`` / ``AsyncFunctionDef.name`` for matching definitions
    - ``ast.Name.id`` for matching names, in any expression context

    Names bound by plain assignment are ``ast.Name`` nodes as well, so a
    local variable that happens to share a renamed function's name is renamed
    together with all of its uses and stays self-consistent.

    Parameters are different: they are stored as ``ast.arg`` nodes, which
    this transformer never rewrites.  Inside a function or lambda whose
    parameter shadows a renamed name, references to that name therefore
    denote the parameter and are left untouched; rewriting them would
    silently redirect them to the renamed function.

    Definitions are matched at any nesting depth.  Attribute accesses
    (``obj.name``), import aliases, ``global``/``nonlocal`` statements and
    keyword-argument names are not rewritten.
    """

    def __init__(self, rename_map: dict[str, str]):
        self.rename_map = rename_map
        # Keys of ``rename_map`` currently hidden by a parameter of an
        # enclosing function or lambda; these must not be rewritten.
        self._shadowed: frozenset[str] = frozenset()

    def _renamed(self, name: str) -> str:
        """Return the replacement for *name*, or *name* itself if it is kept."""
        if name in self._shadowed:
            return name
        return self.rename_map.get(name, name)

    def _hidden_by_parameters(self, args: ast.arguments) -> set[str]:
        """Return the rename-map keys that *args* rebinds as parameters."""
        params = [*args.posonlyargs, *args.args, *args.kwonlyargs]
        if args.vararg is not None:
            params.append(args.vararg)
        if args.kwarg is not None:
            params.append(args.kwarg)
        return {param.arg for param in params} & self.rename_map.keys()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        # The function's own name is bound in the enclosing scope, so it is
        # resolved before this function's parameters come into play.
        node.name = self._renamed(node.name)

        hidden = self._hidden_by_parameters(node.args)
        if not hidden:
            self.generic_visit(node)
            return node

        # Decorators, defaults, annotations and type parameters are
        # evaluated outside the function body: visit them with the
        # enclosing scope's shadowing still in effect.
        node.decorator_list = [self.visit(dec) for dec in node.decorator_list]
        self.visit(node.args)
        if node.returns is not None:
            node.returns = self.visit(node.returns)
        for type_param in getattr(node, "type_params", ()):
            self.visit(type_param)

        # Only the body sees the parameters.
        outer = self._shadowed
        self._shadowed = outer | hidden
        try:
            node.body = [self.visit(stmt) for stmt in node.body]
        finally:
            self._shadowed = outer
        return node

    visit_AsyncFunctionDef = visit_FunctionDef  # type: ignore[assignment]

    def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda:
        hidden = self._hidden_by_parameters(node.args)
        if not hidden:
            self.generic_visit(node)
            return node

        # Default values are evaluated in the enclosing scope.
        self.visit(node.args)
        outer = self._shadowed
        self._shadowed = outer | hidden
        try:
            node.body = self.visit(node.body)
        finally:
            self._shadowed = outer
        return node

    def visit_Name(self, node: ast.Name) -> ast.Name:
        node.id = self._renamed(node.id)
        return node
546+
547+
496548
def mypy_type_check(
497549
module: ast.Module,
498550
ctx: typing.Mapping[str, Any],
@@ -505,6 +557,9 @@ def mypy_type_check(
505557
appends the module body, then a postlude that assigns the last function to a
506558
variable annotated with Callable[expected_params, expected_return]. Runs mypy
507559
on the combined source; raises TypeError with the mypy report on failure.
560+
561+
If the synthesized function name clashes with a name already in the context,
562+
the function is renamed to a unique random identifier for type-checking only.
508563
"""
509564
if not module.body:
510565
raise TypeError("mypy_type_check: module.body is empty")
@@ -527,6 +582,43 @@ def mypy_type_check(
527582
stubs = collect_runtime_type_stubs(ctx)
528583
variables = collect_variable_declarations(ctx)
529584

585+
# Collect names already declared in the type-checking preamble
586+
# (variable declarations and class stubs) that could collide with
587+
# function definitions in the synthesized module.
588+
declared_names = {
589+
stmt.target.id
590+
for stmt in variables
591+
if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)
592+
} | {stmt.name for stmt in stubs if isinstance(stmt, ast.ClassDef)}
593+
594+
# Find all function names in the synthesized module that collide
595+
synthesized_func_names = {
596+
stmt.name
597+
for stmt in module.body
598+
if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
599+
}
600+
colliding_names = synthesized_func_names & declared_names
601+
602+
if colliding_names:
603+
# Build a rename map for every colliding function name
604+
all_reserved = declared_names | synthesized_func_names
605+
rename_map: dict[str, str] = {}
606+
for name in colliding_names:
607+
unique = _generate_unique_name(all_reserved)
608+
rename_map[name] = unique
609+
all_reserved.add(unique)
610+
611+
# Deep-copy the module body so we don't mutate the caller's AST,
612+
# then rename definitions and all references to them.
613+
module_body = copy.deepcopy(list(module.body))
614+
stub_module_body = ast.Module(body=module_body, type_ignores=[])
615+
_RenameTransformer(rename_map).visit(stub_module_body)
616+
module_body = stub_module_body.body
617+
tc_func_name = rename_map.get(func_name, func_name)
618+
else:
619+
module_body = list(module.body)
620+
tc_func_name = func_name
621+
530622
param_types = expected_params
531623
expected_callable_type: type = typing.cast(
532624
type,
@@ -539,15 +631,15 @@ def mypy_type_check(
539631
postlude = ast.AnnAssign(
540632
target=ast.Name(id="_synthesized_check", ctx=ast.Store()),
541633
annotation=expected_callable_ast,
542-
value=ast.Name(id=func_name, ctx=ast.Load()),
634+
value=ast.Name(id=tc_func_name, ctx=ast.Load()),
543635
simple=1,
544636
)
545637
full_body = (
546638
baseline_imports
547639
+ list(imports)
548640
+ list(stubs)
549641
+ list(variables)
550-
+ list(module.body)
642+
+ module_body
551643
+ [postlude]
552644
)
553645
stub_module = ast.Module(body=full_body, type_ignores=[])

tests/test_handlers_llm_type_checking.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import inspect
5+
import textwrap
56
import types
67
import typing
78
from collections import ChainMap
@@ -1267,3 +1268,122 @@ class MyErr(Exception):
12671268
ctx = ChainMap({"MyErr": MyErr}, get_context())
12681269
with pytest.raises(TypeError):
12691270
mypy_type_check(module, ctx, [], MyErr)
1271+
1272+
1273+
class TestMypyTypeCheckNameCollision:
    """Checks for mypy_type_check when a synthesized function's name is
    already taken by a variable declaration or class stub from the context.

    NOTE(review): the seemingly unused locals in each test are the fixtures;
    presumably get_context() picks up the calling frame's variables, so
    their names must stay exactly as written. Verify against get_context.
    """

    def test_single_function_collides_with_variable(self):
        """A lone function sharing its name with a context variable type-checks."""
        count_char = lambda s: s.count("a")  # noqa: E731, F841

        source = textwrap.dedent("""\
            def count_char(s: str) -> int:
                return s.count('a')
        """)
        module = ast.parse(source)
        ctx = get_context()
        # No exception expected: the clash is resolved by renaming.
        mypy_type_check(module, ctx, [str], int)

    def test_colliding_function_still_detects_type_errors(self):
        """Renaming must not hide a genuine type error in the function body."""
        count_char = lambda s: s.count("a")  # noqa: E731, F841

        source = textwrap.dedent("""\
            def count_char(s: str) -> int:
                return s  # wrong return type
        """)
        module = ast.parse(source)
        ctx = get_context()
        with pytest.raises(TypeError):
            mypy_type_check(module, ctx, [str], int)

    def test_no_collision_passes_normally(self):
        """Without any clash the type-check behaves as it did before."""
        x = 42  # noqa: F841

        source = textwrap.dedent("""\
            def some_unique_func(s: str) -> int:
                return len(s)
        """)
        module = ast.parse(source)
        ctx = get_context()
        mypy_type_check(module, ctx, [str], int)

    def test_multiple_functions_one_collides(self):
        """Helper plus main function, where only the main one clashes."""
        process = "some_value"  # noqa: F841

        source = textwrap.dedent("""\
            def helper(x: int) -> str:
                return str(x)
            def process(items: list[int]) -> list[str]:
                return [helper(i) for i in items]
        """)
        module = ast.parse(source)
        ctx = get_context()
        mypy_type_check(module, ctx, [list[int]], list[str])

    def test_multiple_functions_both_collide(self):
        """Helper and main function both clash with context variables."""
        helper = lambda: None  # noqa: E731, F841
        compute = 123  # noqa: F841

        source = textwrap.dedent("""\
            def helper(x: int) -> str:
                return str(x)
            def compute(n: int) -> str:
                return helper(n)
        """)
        module = ast.parse(source)
        ctx = get_context()
        mypy_type_check(module, ctx, [int], str)

    def test_collision_with_class_stub(self):
        """The function's name clashes with a class supplied through the context."""

        class MyModel:
            value: int

        # The synthesized code defines a function that is also called MyModel.
        source = textwrap.dedent("""\
            def MyModel(x: int) -> int:
                return x * 2
        """)
        module = ast.parse(source)
        ctx = ChainMap({"MyModel": MyModel}, get_context())
        mypy_type_check(module, ctx, [int], int)

    def test_collision_does_not_mutate_original_ast(self):
        """The caller's module AST keeps its original function name."""
        count_char = lambda s: s.count("a")  # noqa: E731, F841

        source = textwrap.dedent("""\
            def count_char(s: str) -> int:
                return s.count('a')
        """)
        module = ast.parse(source)
        original_name = module.body[-1].name

        ctx = get_context()
        mypy_type_check(module, ctx, [str], int)

        # The rename must have been applied to a copy only.
        assert module.body[-1].name == original_name

    def test_helper_reference_updated_after_rename(self):
        """Calls to a renamed helper from another function are rewritten too,
        so the code handed to mypy remains valid."""
        validate = True  # noqa: F841 — collides with helper name

        source = textwrap.dedent("""\
            def validate(x: int) -> bool:
                return x > 0
            def run(x: int) -> bool:
                return validate(x)
        """)
        module = ast.parse(source)
        ctx = get_context()
        mypy_type_check(module, ctx, [int], bool)

0 commit comments

Comments
 (0)