Skip to content

Commit 0e91155

Browse files
committed
perf: optimize context extraction with caching and import pre-filtering (~42% faster)
Cache Jedi refs_by_parent across calls, reuse the READ_ONLY (RO) result for TESTGEN when both contexts prune to identical code, pre-filter imports by AST-collected referenced names, and skip RemoveImportsVisitor when dst has no pre-existing imports.
1 parent 8a07c5e commit 0e91155

3 files changed

Lines changed: 132 additions & 57 deletions

File tree

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 71 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,11 @@ def get_code_optimization_context(
118118

119119
# Get FunctionSource representation of helpers of FTO
120120
fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}}
121+
jedi_refs_cache: dict[Path, dict[str, list]] | None = None
121122
if call_graph is not None:
122123
helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input)
123124
else:
124-
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
125+
helpers_of_fto_dict, helpers_of_fto_list, jedi_refs_cache = get_function_sources_from_jedi(
125126
fto_input, project_root_path, jedi_project=jedi_project
126127
)
127128

@@ -141,8 +142,8 @@ def get_code_optimization_context(
141142
for qualified_names in helpers_of_fto_qualified_names_dict.values():
142143
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})
143144

144-
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
145-
helpers_of_fto_qualified_names_dict, project_root_path, jedi_project=jedi_project
145+
helpers_of_helpers_dict, helpers_of_helpers_list, _ = get_function_sources_from_jedi(
146+
helpers_of_fto_qualified_names_dict, project_root_path, jedi_project=jedi_project, refs_cache=jedi_refs_cache
146147
)
147148

148149
# Extract all code contexts in a single pass (one CST parse per file)
@@ -312,12 +313,15 @@ def extract_all_contexts_from_files(
312313
logger.debug(f"Error while getting read-writable code: {e}")
313314

314315
# READ_ONLY
316+
fto_ro_code_result: str | None = None
317+
fto_ro_pruned_code_str: str | None = None
315318
try:
316319
ro_pruned = parse_code_and_prune_cst(
317320
all_cleaned, CodeContextType.READ_ONLY, fto_names, hoh_names, remove_docstrings=False
318321
)
319-
if ro_pruned.code.strip():
320-
ro_code = add_needed_imports_from_module(
322+
fto_ro_pruned_code_str = ro_pruned.code.strip()
323+
if fto_ro_pruned_code_str:
324+
fto_ro_code_result = add_needed_imports_from_module(
321325
src_module_code=original_module,
322326
dst_module_code=ro_pruned,
323327
src_path=file_path,
@@ -326,7 +330,7 @@ def extract_all_contexts_from_files(
326330
helper_functions=all_helper_functions,
327331
gathered_imports=src_gathered,
328332
)
329-
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
333+
ro.code_strings.append(CodeString(code=fto_ro_code_result, file_path=relative_path))
330334
except ValueError as e:
331335
logger.debug(f"Error while getting read-only code: {e}")
332336

@@ -341,21 +345,25 @@ def extract_all_contexts_from_files(
341345
except ValueError as e:
342346
logger.debug(f"Error while getting hashing code: {e}")
343347

344-
# TESTGEN
348+
# TESTGEN -- reuse RO result when pruned code is identical
345349
try:
346350
testgen_pruned = parse_code_and_prune_cst(
347351
all_cleaned, CodeContextType.TESTGEN, fto_names, hoh_names, remove_docstrings=False
348352
)
349-
if testgen_pruned.code.strip():
350-
testgen_code = add_needed_imports_from_module(
351-
src_module_code=original_module,
352-
dst_module_code=testgen_pruned,
353-
src_path=file_path,
354-
dst_path=file_path,
355-
project_root=project_root_path,
356-
helper_functions=all_helper_functions,
357-
gathered_imports=src_gathered,
358-
)
353+
fto_testgen_pruned_code_str = testgen_pruned.code.strip()
354+
if fto_testgen_pruned_code_str:
355+
if fto_ro_code_result is not None and fto_testgen_pruned_code_str == fto_ro_pruned_code_str:
356+
testgen_code = fto_ro_code_result
357+
else:
358+
testgen_code = add_needed_imports_from_module(
359+
src_module_code=original_module,
360+
dst_module_code=testgen_pruned,
361+
src_path=file_path,
362+
dst_path=file_path,
363+
project_root=project_root_path,
364+
helper_functions=all_helper_functions,
365+
gathered_imports=src_gathered,
366+
)
359367
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
360368
except ValueError as e:
361369
logger.debug(f"Error while getting testgen code: {e}")
@@ -404,12 +412,15 @@ def extract_all_contexts_from_files(
404412
src_gathered = gather_source_imports(original_module, file_path, project_root_path)
405413

406414
# READ_ONLY
415+
ro_code_result: str | None = None
416+
ro_pruned_code_str: str | None = None
407417
try:
408418
ro_pruned = parse_code_and_prune_cst(
409419
cleaned, CodeContextType.READ_ONLY, set(), hoh_names, remove_docstrings=False
410420
)
411-
if ro_pruned.code.strip():
412-
ro_code = add_needed_imports_from_module(
421+
ro_pruned_code_str = ro_pruned.code.strip()
422+
if ro_pruned_code_str:
423+
ro_code_result = add_needed_imports_from_module(
413424
src_module_code=original_module,
414425
dst_module_code=ro_pruned,
415426
src_path=file_path,
@@ -418,7 +429,7 @@ def extract_all_contexts_from_files(
418429
helper_functions=helper_functions,
419430
gathered_imports=src_gathered,
420431
)
421-
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
432+
ro.code_strings.append(CodeString(code=ro_code_result, file_path=relative_path))
422433
except ValueError as e:
423434
logger.debug(f"Error while getting read-only code: {e}")
424435

@@ -433,21 +444,25 @@ def extract_all_contexts_from_files(
433444
except ValueError as e:
434445
logger.debug(f"Error while getting hashing code: {e}")
435446

436-
# TESTGEN
447+
# TESTGEN -- reuse RO result when pruned code is identical (common for HoH-only files)
437448
try:
438449
testgen_pruned = parse_code_and_prune_cst(
439450
cleaned, CodeContextType.TESTGEN, set(), hoh_names, remove_docstrings=False
440451
)
441-
if testgen_pruned.code.strip():
442-
testgen_code = add_needed_imports_from_module(
443-
src_module_code=original_module,
444-
dst_module_code=testgen_pruned,
445-
src_path=file_path,
446-
dst_path=file_path,
447-
project_root=project_root_path,
448-
helper_functions=helper_functions,
449-
gathered_imports=src_gathered,
450-
)
452+
testgen_pruned_code_str = testgen_pruned.code.strip()
453+
if testgen_pruned_code_str:
454+
if ro_code_result is not None and testgen_pruned_code_str == ro_pruned_code_str:
455+
testgen_code = ro_code_result
456+
else:
457+
testgen_code = add_needed_imports_from_module(
458+
src_module_code=original_module,
459+
dst_module_code=testgen_pruned,
460+
src_path=file_path,
461+
dst_path=file_path,
462+
project_root=project_root_path,
463+
helper_functions=helper_functions,
464+
gathered_imports=src_gathered,
465+
)
451466
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
452467
except ValueError as e:
453468
logger.debug(f"Error while getting testgen code: {e}")
@@ -546,33 +561,39 @@ def get_function_sources_from_jedi(
546561
project_root_path: Path,
547562
*,
548563
jedi_project: object | None = None,
549-
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
564+
refs_cache: dict[Path, dict[str, list]] | None = None,
565+
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource], dict[Path, dict[str, list]]]:
550566
import jedi
551567

552568
project = jedi_project if jedi_project is not None else get_jedi_project(str(project_root_path))
553569
file_path_to_function_source = defaultdict(set)
554570
function_source_list: list[FunctionSource] = []
571+
new_refs_cache: dict[Path, dict[str, list]] = {} if refs_cache is None else dict(refs_cache)
555572
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
556-
script = jedi.Script(path=file_path, project=project)
557-
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
573+
if file_path in new_refs_cache:
574+
refs_by_parent = new_refs_cache[file_path]
575+
else:
576+
script = jedi.Script(path=file_path, project=project)
577+
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
558578

559-
# Pre-group references by their parent function's qualified name for O(1) lookup
560-
refs_by_parent: dict[str, list[Name]] = defaultdict(list)
561-
for ref in file_refs:
562-
if not ref.full_name:
563-
continue
564-
try:
565-
parent = ref.parent()
566-
if parent is None or parent.type != "function":
579+
# Pre-group references by their parent function's qualified name for O(1) lookup
580+
refs_by_parent: dict[str, list[Name]] = defaultdict(list)
581+
for ref in file_refs:
582+
if not ref.full_name:
567583
continue
568-
parent_qn = get_qualified_name(parent.module_name, parent.full_name)
569-
# Exclude self-references (recursive calls) — the ref's own qualified name matches the parent
570-
ref_qn = get_qualified_name(ref.module_name, ref.full_name)
571-
if ref_qn == parent_qn:
584+
try:
585+
parent = ref.parent()
586+
if parent is None or parent.type != "function":
587+
continue
588+
parent_qn = get_qualified_name(parent.module_name, parent.full_name)
589+
# Exclude self-references (recursive calls) — the ref's own qualified name matches the parent
590+
ref_qn = get_qualified_name(ref.module_name, ref.full_name)
591+
if ref_qn == parent_qn:
592+
continue
593+
refs_by_parent[parent_qn].append(ref)
594+
except (AttributeError, ValueError):
572595
continue
573-
refs_by_parent[parent_qn].append(ref)
574-
except (AttributeError, ValueError):
575-
continue
596+
new_refs_cache[file_path] = dict(refs_by_parent)
576597

577598
for qualified_function_name in qualified_function_names:
578599
names = refs_by_parent.get(qualified_function_name, [])
@@ -623,7 +644,7 @@ def get_function_sources_from_jedi(
623644
file_path_to_function_source[definition_path].add(function_source)
624645
function_source_list.append(function_source)
625646

626-
return file_path_to_function_source, function_source_list
647+
return file_path_to_function_source, function_source_list, new_refs_cache
627648

628649

629650
def _parse_and_collect_imports(code_context: CodeStringsMarkdown) -> tuple[ast.Module, dict[str, str]] | None:

codeflash/languages/python/static_analysis/code_extractor.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,39 @@ def gather_source_imports(
619619
return None
620620

621621

622+
def _collect_dst_referenced_names(dst_code: str) -> tuple[set[str], bool]:
623+
"""Collect all names referenced in destination code for import pre-filtering.
624+
625+
Uses ast (not libcst) for speed. Collects Name nodes and base names of Attribute chains,
626+
plus names inside string annotations.
627+
628+
Returns (names, has_imports) where has_imports indicates whether the dst has any
629+
pre-existing import statements.
630+
"""
631+
try:
632+
tree = ast.parse(dst_code)
633+
except SyntaxError:
634+
return set(), False
635+
names: set[str] = set()
636+
has_imports = False
637+
for node in ast.walk(tree):
638+
if isinstance(node, ast.Name):
639+
names.add(node.id)
640+
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
641+
names.add(node.value.id)
642+
elif isinstance(node, (ast.Import, ast.ImportFrom)):
643+
has_imports = True
644+
elif isinstance(node, ast.Constant) and isinstance(node.value, str):
645+
try:
646+
inner = ast.parse(node.value, mode="eval")
647+
for inner_node in ast.walk(inner):
648+
if isinstance(inner_node, ast.Name):
649+
names.add(inner_node.id)
650+
except SyntaxError:
651+
pass
652+
return names, has_imports
653+
654+
622655
def add_needed_imports_from_module(
623656
src_module_code: str | cst.Module,
624657
dst_module_code: str | cst.Module,
@@ -669,13 +702,20 @@ def add_needed_imports_from_module(
669702

670703
parsed_dst_module.visit(dotted_import_collector)
671704

705+
# Pre-filter: collect names referenced in destination code to avoid adding unused imports.
706+
# This keeps the intermediate module small so RemoveImportsVisitor's scope analysis is cheap.
707+
dst_code_str = parsed_dst_module.code if isinstance(parsed_dst_module, cst.Module) else dst_code_fallback
708+
dst_referenced_names, dst_has_imports = _collect_dst_referenced_names(dst_code_str)
709+
672710
try:
673711
for mod in gatherer.module_imports:
674712
# Skip __future__ imports as they cannot be imported directly
675713
# __future__ imports should only be imported with specific objects i.e from __future__ import annotations
676714
if mod == "__future__":
677715
continue
678-
if mod not in dotted_import_collector.imports:
716+
# For `import foo.bar`, the bound name is `foo`
717+
bound_name = mod.split(".")[0]
718+
if bound_name in dst_referenced_names and mod not in dotted_import_collector.imports:
679719
AddImportsVisitor.add_needed_import(dst_context, mod)
680720
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
681721
aliased_objects = set()
@@ -701,13 +741,18 @@ def add_needed_imports_from_module(
701741

702742
for symbol in resolved_symbols:
703743
if (
704-
f"{mod}.{symbol}" not in helper_functions_fqn
744+
symbol in dst_referenced_names
745+
and f"{mod}.{symbol}" not in helper_functions_fqn
705746
and f"{mod}.{symbol}" not in dotted_import_collector.imports
706747
):
707748
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
708749
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
709750
else:
710-
if f"{mod}.{obj}" not in dotted_import_collector.imports:
751+
# For `from foo import bar`, the bound name is `bar`
752+
# Always include __future__ imports -- they affect parsing behavior, not naming
753+
if (
754+
mod == "__future__" or obj in dst_referenced_names
755+
) and f"{mod}.{obj}" not in dotted_import_collector.imports:
711756
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
712757
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
713758
except Exception as e:
@@ -717,7 +762,8 @@ def add_needed_imports_from_module(
717762
for mod, asname in gatherer.module_aliases.items():
718763
if not asname:
719764
continue
720-
if f"{mod}.{asname}" not in dotted_import_collector.imports:
765+
# For `import foo as bar`, the bound name is `bar`
766+
if asname in dst_referenced_names and f"{mod}.{asname}" not in dotted_import_collector.imports:
721767
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
722768
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
723769

@@ -729,14 +775,22 @@ def add_needed_imports_from_module(
729775
if not alias_pair[0] or not alias_pair[1]:
730776
continue
731777

732-
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
778+
# For `from foo import bar as baz`, the bound name is `baz`
779+
if (
780+
alias_pair[1] in dst_referenced_names
781+
and f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports
782+
):
733783
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
734784
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
735785

736786
try:
737787
add_imports_visitor = AddImportsVisitor(dst_context)
738788
transformed_module = add_imports_visitor.transform_module(parsed_dst_module)
739-
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
789+
# Skip RemoveImportsVisitor when the dst had no pre-existing imports.
790+
# In that case, the only imports are those just added by AddImportsVisitor,
791+
# which are already pre-filtered to names referenced in the dst code.
792+
if dst_has_imports:
793+
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
740794
return transformed_module.code.lstrip("\n")
741795
except Exception as e:
742796
logger.exception(f"Error adding imports to destination module code: {e}")

codeflash/languages/python/support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def find_helper_functions(self, function: FunctionToOptimize, project_root: Path
351351
from codeflash.languages.python.context.code_context_extractor import get_function_sources_from_jedi
352352

353353
try:
354-
_dict, sources = get_function_sources_from_jedi(
354+
_dict, sources, _ = get_function_sources_from_jedi(
355355
{function.file_path: {function.qualified_name}}, project_root
356356
)
357357
except Exception as e:

0 commit comments

Comments
 (0)