Skip to content

Commit f91278c

Browse files
authored
Merge pull request #1641 from codeflash-ai/fix-importlens-bugs
fix: importlens optimization bugs
2 parents 5a73fe2 + c32abd1 commit f91278c

16 files changed

Lines changed: 776 additions & 219 deletions

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def process_file_context(
368368
try:
369369
all_names = primary_qualified_names | secondary_qualified_names
370370
code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, all_names)
371-
code_context = parse_code_and_prune_cst(
371+
pruned_module = parse_code_and_prune_cst(
372372
code_without_unused_defs,
373373
code_context_type,
374374
primary_qualified_names,
@@ -379,11 +379,13 @@ def process_file_context(
379379
logger.debug(f"Error while getting read-only code: {e}")
380380
return None
381381

382-
if code_context.strip():
383-
if code_context_type != CodeContextType.HASHING:
382+
if pruned_module.code.strip():
383+
if code_context_type == CodeContextType.HASHING:
384+
code_context = ast.unparse(ast.parse(pruned_module.code))
385+
else:
384386
code_context = add_needed_imports_from_module(
385387
src_module_code=original_code,
386-
dst_module_code=code_context,
388+
dst_module_code=pruned_module,
387389
src_path=file_path,
388390
dst_path=file_path,
389391
project_root=project_root_path,
@@ -1280,8 +1282,8 @@ def parse_code_and_prune_cst(
12801282
target_functions: set[str],
12811283
helpers_of_helper_functions: set[str] = set(), # noqa: B006
12821284
remove_docstrings: bool = False,
1283-
) -> str:
1284-
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
1285+
) -> cst.Module:
1286+
"""Parse and filter the code CST, returning the pruned Module."""
12851287
module = cst.parse_module(code)
12861288
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
12871289

@@ -1317,11 +1319,8 @@ def parse_code_and_prune_cst(
13171319
if not found_target:
13181320
raise ValueError("No target functions found in the provided code")
13191321
if filtered_node and isinstance(filtered_node, cst.Module):
1320-
code = str(filtered_node.code)
1321-
if code_context_type == CodeContextType.HASHING:
1322-
code = ast.unparse(ast.parse(code)) # Makes it standard
1323-
return code
1324-
return ""
1322+
return filtered_node
1323+
raise ValueError("Pruning produced no module")
13251324

13261325

13271326
def prune_cst(

codeflash/languages/python/static_analysis/code_extractor.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
684684

685685
def add_needed_imports_from_module(
686686
src_module_code: str,
687-
dst_module_code: str,
687+
dst_module_code: str | cst.Module,
688688
src_path: Path,
689689
dst_path: Path,
690690
project_root: Path,
@@ -696,6 +696,8 @@ def add_needed_imports_from_module(
696696
if not helper_functions_fqn:
697697
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}
698698

699+
dst_code_fallback = dst_module_code if isinstance(dst_module_code, str) else dst_module_code.code
700+
699701
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
700702
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
701703

@@ -715,15 +717,19 @@ def add_needed_imports_from_module(
715717
cst.parse_module(src_module_code).visit(gatherer)
716718
except Exception as e:
717719
logger.error(f"Error parsing source module code: {e}")
718-
return dst_module_code
720+
return dst_code_fallback
719721

720722
dotted_import_collector = DottedImportCollector()
721-
try:
722-
parsed_dst_module = cst.parse_module(dst_module_code)
723+
if isinstance(dst_module_code, cst.Module):
724+
parsed_dst_module = dst_module_code
723725
parsed_dst_module.visit(dotted_import_collector)
724-
except cst.ParserSyntaxError as e:
725-
logger.exception(f"Syntax error in destination module code: {e}")
726-
return dst_module_code # Return the original code if there's a syntax error
726+
else:
727+
try:
728+
parsed_dst_module = cst.parse_module(dst_module_code)
729+
parsed_dst_module.visit(dotted_import_collector)
730+
except cst.ParserSyntaxError as e:
731+
logger.exception(f"Syntax error in destination module code: {e}")
732+
return dst_code_fallback
727733

728734
try:
729735
for mod in gatherer.module_imports:
@@ -768,7 +774,7 @@ def add_needed_imports_from_module(
768774
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
769775
except Exception as e:
770776
logger.exception(f"Error adding imports to destination module code: {e}")
771-
return dst_module_code
777+
return dst_code_fallback
772778

773779
for mod, asname in gatherer.module_aliases.items():
774780
if not asname:
@@ -796,7 +802,7 @@ def add_needed_imports_from_module(
796802
return transformed_module.code.lstrip("\n")
797803
except Exception as e:
798804
logger.exception(f"Error adding imports to destination module code: {e}")
799-
return dst_module_code
805+
return dst_code_fallback
800806

801807

802808
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:

codeflash/optimization/function_optimizer.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,7 @@ def process_single_candidate(
10231023
eval_ctx: CandidateEvaluationContext,
10241024
exp_type: str,
10251025
function_references: str,
1026+
normalized_original: str,
10261027
) -> BestOptimization | None:
10271028
"""Process a single optimization candidate.
10281029
@@ -1033,8 +1034,24 @@ def process_single_candidate(
10331034
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
10341035
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
10351036

1036-
logger.info(f"h3|Optimization candidate {candidate_index}/{total_candidates}:")
10371037
candidate = candidate_node.candidate
1038+
1039+
normalized_code = normalize_code(candidate.source_code.flat.strip())
1040+
1041+
if normalized_code == normalized_original:
1042+
logger.info(f"h3|Candidate {candidate_index}/{total_candidates}: Identical to original code, skipping.")
1043+
console.rule()
1044+
return None
1045+
1046+
if normalized_code in eval_ctx.ast_code_to_id:
1047+
logger.info(
1048+
f"h3|Candidate {candidate_index}/{total_candidates}: Duplicate of a previous candidate, skipping."
1049+
)
1050+
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context)
1051+
console.rule()
1052+
return None
1053+
1054+
logger.info(f"h3|Optimization candidate {candidate_index}/{total_candidates}:")
10381055
# Use correct extension based on language
10391056
ext = self.language_support.file_extensions[0]
10401057
code_print(
@@ -1062,13 +1079,6 @@ def process_single_candidate(
10621079
)
10631080
return None
10641081

1065-
# Check for duplicate candidates
1066-
normalized_code = normalize_code(candidate.source_code.flat.strip())
1067-
if normalized_code in eval_ctx.ast_code_to_id:
1068-
logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.")
1069-
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context)
1070-
return None
1071-
10721082
eval_ctx.register_new_candidate(normalized_code, candidate, code_context)
10731083

10741084
# Run the optimized candidate
@@ -1238,6 +1248,7 @@ def determine_best_candidate(
12381248
self.future_adaptive_optimizations,
12391249
)
12401250
candidate_index = 0
1251+
normalized_original = normalize_code(code_context.read_writable_code.flat.strip())
12411252

12421253
# Process candidates using queue-based approach
12431254
while not processor.is_done():
@@ -1259,6 +1270,7 @@ def determine_best_candidate(
12591270
eval_ctx=eval_ctx,
12601271
exp_type=exp_type,
12611272
function_references=function_references,
1273+
normalized_original=normalized_original,
12621274
)
12631275
except KeyboardInterrupt as e:
12641276
logger.exception(f"Optimization interrupted: {e}")

0 commit comments

Comments
 (0)