Skip to content

Commit 8a07c5e

Browse files
committed
perf: optimize context extraction pipeline (~2x speedup)
Eliminate redundant CST traversals in code context extraction by caching dependency data, skipping unnecessary transforms, and removing MetadataWrapper.
1 parent 2c4989c commit 8a07c5e

3 files changed

Lines changed: 46 additions & 19 deletions

File tree

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,10 @@ def extract_all_contexts_from_files(
395395
except ValueError:
396396
relative_path = file_path
397397

398-
cleaned = remove_unused_definitions_by_function_names(original_module, hoh_names)
398+
# Collect definitions + dependencies once (expensive CST traversal), reuse for mark pass
399+
base_defs = collect_top_level_defs_with_dependencies(original_module)
400+
hoh_defs = mark_defs_for_functions(base_defs, hoh_names)
401+
cleaned = remove_unused_definitions_by_function_names(original_module, hoh_names, defs_with_usages=hoh_defs)
399402

400403
# Pre-compute source imports once for this file
401404
src_gathered = gather_source_imports(original_module, file_path, project_root_path)

codeflash/languages/python/context/unused_definition_remover.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,6 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
165165
class DependencyCollector(cst.CSTVisitor):
166166
"""Collects dependencies between definitions using the visitor pattern with depth tracking."""
167167

168-
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
169-
170168
def __init__(self, definitions: dict[str, UsageInfo]) -> None:
171169
super().__init__()
172170
self.definitions = definitions
@@ -179,6 +177,8 @@ def __init__(self, definitions: dict[str, UsageInfo]) -> None:
179177
# Track if we're processing a top-level variable
180178
self.processing_variable = False
181179
self.current_variable_names = set()
180+
# Track Name nodes that are the .attr part of Attribute nodes (by id)
181+
self.attr_name_ids: set[int] = set()
182182

183183
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
184184
function_name = node.name.value
@@ -281,6 +281,12 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> None:
281281
self.processing_variable = False
282282
self.current_variable_names.clear()
283283

284+
def visit_Attribute(self, node: cst.Attribute) -> None:
    """Record the identity of this attribute's ``.attr`` child in ``attr_name_ids``.

    The stored value is ``id(node.attr)``, so a later membership test with
    ``id(some_name_node)`` tells whether that node is the attribute-name part
    of an ``Attribute`` rather than its base.
    """
    attr_identity = id(node.attr)
    self.attr_name_ids.add(attr_identity)
286+
287+
def leave_Attribute(self, original_node: cst.Attribute) -> None:
    """Drop the ``.attr`` child's identity from ``attr_name_ids``.

    ``discard`` is used, so a missing entry is not an error.
    """
    # Removing the id here pairs with the add done on visit, so the set only
    # holds ids for attributes whose visit has started but not yet finished.
    finished_attr_identity = id(original_node.attr)
    self.attr_name_ids.discard(finished_attr_identity)
289+
284290
def visit_Name(self, node: cst.Name) -> None:
285291
name = node.value
286292

@@ -296,15 +302,11 @@ def visit_Name(self, node: cst.Name) -> None:
296302
# Skip if this Name is the .attr part of an Attribute (e.g., 'x' in 'self.x')
297303
# We only want to track the base/value of attribute access, not the attribute name itself
298304
if self.class_depth > 0:
299-
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
300-
if parent is not None and isinstance(parent, cst.Attribute):
301-
# Check if this Name is the .attr (property name), not the .value (base)
302-
# If it's the .attr, skip it - attribute names aren't references to definitions
303-
if parent.attr is node:
304-
return
305-
# If it's the .value (base), only skip if it's self/cls
306-
if name in ("self", "cls"):
307-
return
305+
if id(node) in self.attr_name_ids:
306+
return
307+
# If it's the .value (base), only skip if it's self/cls
308+
if name in ("self", "cls"):
309+
return
308310
self.definitions[self.current_top_level_name].dependencies.add(name)
309311

310312

@@ -409,17 +411,16 @@ def remove_unused_definitions_recursively(
409411

410412

411413
def collect_top_level_defs_with_dependencies(code: Union[str, cst.Module]) -> dict[str, UsageInfo]:
412-
"""Collect all top level definitions and their inter-definition dependencies (expensive CST traversal).
414+
"""Collect all top level definitions and their inter-definition dependencies via CST traversal.
413415
414416
Returns a definitions dict with dependencies populated but no usage marks set.
415417
This result can be reused across multiple mark_defs_for_functions calls to avoid
416-
repeating the expensive MetadataWrapper + DependencyCollector traversal.
418+
repeating the DependencyCollector traversal.
417419
"""
418420
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
419421
definitions = collect_top_level_definitions(module)
420-
wrapper = cst.MetadataWrapper(module)
421422
dependency_collector = DependencyCollector(definitions)
422-
wrapper.visit(dependency_collector)
423+
module.visit(dependency_collector)
423424
return definitions
424425

425426

codeflash/languages/python/static_analysis/code_extractor.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,31 @@ def leave_ImportFrom(
426426
return updated_node
427427

428428

429+
def _has_aliased_future_imports(module: cst.Module) -> bool:
    """Report whether *module* contains ``from __future__ import X as Y``.

    Only statements directly in ``module.body`` are inspected; nested scopes
    are not walked.

    Args:
        module: Parsed libcst module to scan.

    Returns:
        True if any top-level ``from __future__ import ...`` statement has at
        least one imported name with an ``as`` alias, otherwise False.
    """
    for stmt in module.body:
        if not isinstance(stmt, cst.SimpleStatementLine):
            continue
        for small_stmt in stmt.body:
            if not isinstance(small_stmt, cst.ImportFrom):
                continue
            imported_from = small_stmt.module
            # Only a plain Name can spell "__future__": a dotted module is an
            # Attribute whose .value is another node (never a str), and a
            # purely relative import has module None. A single-class
            # isinstance also avoids the runtime ``A | B`` union form, which
            # is only accepted by isinstance on Python 3.10+.
            if not (isinstance(imported_from, cst.Name) and imported_from.value == "__future__"):
                continue
            imported_names = small_stmt.names
            # ``names`` is not a sequence for ``import *``; there is nothing
            # to alias in that case.
            if isinstance(imported_names, (list, tuple)) and any(
                alias.asname is not None for alias in imported_names
            ):
                return True
    return False
444+
445+
446+
def _strip_future_aliases(module: cst.Module) -> cst.Module:
    """Run ``FutureAliasedImportTransformer`` over *module* only when needed.

    Returns *module* itself, untouched, when ``_has_aliased_future_imports``
    finds nothing to rewrite; otherwise returns the result of visiting the
    module with a fresh ``FutureAliasedImportTransformer``.
    """
    if not _has_aliased_future_imports(module):
        # Nothing to rewrite: skip the transformer pass entirely.
        return module
    transformer = FutureAliasedImportTransformer()
    return module.visit(transformer)
450+
451+
429452
def delete___future___aliased_imports(module_code: str) -> str:
430-
return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code
453+
return _strip_future_aliases(cst.parse_module(module_code)).code
431454

432455

433456
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
@@ -555,9 +578,9 @@ def gather_source_imports(
555578
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
556579
try:
557580
if isinstance(src_module_code, cst.Module):
558-
src_module = src_module_code.visit(FutureAliasedImportTransformer())
581+
src_module = _strip_future_aliases(src_module_code)
559582
else:
560-
src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer())
583+
src_module = _strip_future_aliases(cst.parse_module(src_module_code))
561584

562585
has_module_level_imports = any(
563586
isinstance(s, (cst.Import, cst.ImportFrom))

0 commit comments

Comments
 (0)