Skip to content

Commit 9d23a0e

Browse files
authored
Merge pull request #1518 from codeflash-ai/proper-async
refactor: inline async decorators to remove codeflash import dependency
2 parents 8cb7209 + 6c092b5 commit 9d23a0e

5 files changed

Lines changed: 404 additions & 228 deletions

File tree

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 195 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,73 +1497,207 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca
14971497
return False
14981498

14991499

1500-
class AsyncDecoratorImportAdder(cst.CSTTransformer):
1501-
"""Transformer that adds the import for async decorators."""
1500+
ASYNC_HELPER_INLINE_CODE = """import asyncio
1501+
import gc
1502+
import os
1503+
import sqlite3
1504+
import time
1505+
from functools import wraps
1506+
from pathlib import Path
1507+
from tempfile import TemporaryDirectory
1508+
1509+
import dill as pickle
15021510
1503-
def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
1504-
self.mode = mode
1505-
self.has_import = False
1506-
1507-
def _get_decorator_name(self) -> str:
1508-
"""Get the decorator name based on the testing mode."""
1509-
if self.mode == TestingMode.BEHAVIOR:
1510-
return "codeflash_behavior_async"
1511-
if self.mode == TestingMode.CONCURRENCY:
1512-
return "codeflash_concurrency_async"
1513-
return "codeflash_performance_async"
1514-
1515-
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
1516-
# Check if the async decorator import is already present
1517-
if (
1518-
isinstance(node.module, cst.Attribute)
1519-
and isinstance(node.module.value, cst.Attribute)
1520-
and isinstance(node.module.value.value, cst.Name)
1521-
and node.module.value.value.value == "codeflash"
1522-
and node.module.value.attr.value == "code_utils"
1523-
and node.module.attr.value == "codeflash_wrap_decorator"
1524-
and not isinstance(node.names, cst.ImportStar)
1525-
):
1526-
decorator_name = self._get_decorator_name()
1527-
for import_alias in node.names:
1528-
if import_alias.name.value == decorator_name:
1529-
self.has_import = True
1530-
1531-
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
1532-
# If the import is already there, don't add it again
1533-
if self.has_import:
1534-
return updated_node
1535-
1536-
# Choose import based on mode
1537-
decorator_name = self._get_decorator_name()
1538-
1539-
# Parse the import statement into a CST node
1540-
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
1541-
1542-
# Add the import to the module's body
1543-
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
1511+
1512+
def get_run_tmp_file(file_path):
1513+
if not hasattr(get_run_tmp_file, "tmpdir"):
1514+
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
1515+
return Path(get_run_tmp_file.tmpdir.name) / file_path
1516+
1517+
1518+
def extract_test_context_from_env():
1519+
test_module = os.environ["CODEFLASH_TEST_MODULE"]
1520+
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
1521+
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
1522+
if test_module and test_function:
1523+
return (test_module, test_class if test_class else None, test_function)
1524+
raise RuntimeError(
1525+
"Test context environment variables not set - ensure tests are run through codeflash test runner"
1526+
)
1527+
1528+
1529+
def codeflash_behavior_async(func):
1530+
@wraps(func)
1531+
async def async_wrapper(*args, **kwargs):
1532+
loop = asyncio.get_running_loop()
1533+
function_name = func.__name__
1534+
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
1535+
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
1536+
test_module_name, test_class_name, test_name = extract_test_context_from_env()
1537+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
1538+
if not hasattr(async_wrapper, "index"):
1539+
async_wrapper.index = {}
1540+
if test_id in async_wrapper.index:
1541+
async_wrapper.index[test_id] += 1
1542+
else:
1543+
async_wrapper.index[test_id] = 0
1544+
codeflash_test_index = async_wrapper.index[test_id]
1545+
invocation_id = f"{line_id}_{codeflash_test_index}"
1546+
class_prefix = (test_class_name + ".") if test_class_name else ""
1547+
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
1548+
print(f"!$######{test_stdout_tag}######$!")
1549+
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
1550+
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
1551+
codeflash_con = sqlite3.connect(db_path)
1552+
codeflash_cur = codeflash_con.cursor()
1553+
codeflash_cur.execute(
1554+
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
1555+
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
1556+
"runtime INTEGER, return_value BLOB, verification_type TEXT)"
1557+
)
1558+
exception = None
1559+
counter = loop.time()
1560+
gc.disable()
1561+
try:
1562+
ret = func(*args, **kwargs)
1563+
counter = loop.time()
1564+
return_value = await ret
1565+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1566+
except Exception as e:
1567+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1568+
exception = e
1569+
finally:
1570+
gc.enable()
1571+
print(f"!######{test_stdout_tag}######!")
1572+
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
1573+
codeflash_cur.execute(
1574+
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
1575+
(
1576+
test_module_name,
1577+
test_class_name,
1578+
test_name,
1579+
function_name,
1580+
loop_index,
1581+
invocation_id,
1582+
codeflash_duration,
1583+
pickled_return_value,
1584+
"function_call",
1585+
),
1586+
)
1587+
codeflash_con.commit()
1588+
codeflash_con.close()
1589+
if exception:
1590+
raise exception
1591+
return return_value
1592+
return async_wrapper
1593+
1594+
1595+
def codeflash_performance_async(func):
1596+
@wraps(func)
1597+
async def async_wrapper(*args, **kwargs):
1598+
loop = asyncio.get_running_loop()
1599+
function_name = func.__name__
1600+
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
1601+
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
1602+
test_module_name, test_class_name, test_name = extract_test_context_from_env()
1603+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
1604+
if not hasattr(async_wrapper, "index"):
1605+
async_wrapper.index = {}
1606+
if test_id in async_wrapper.index:
1607+
async_wrapper.index[test_id] += 1
1608+
else:
1609+
async_wrapper.index[test_id] = 0
1610+
codeflash_test_index = async_wrapper.index[test_id]
1611+
invocation_id = f"{line_id}_{codeflash_test_index}"
1612+
class_prefix = (test_class_name + ".") if test_class_name else ""
1613+
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
1614+
print(f"!$######{test_stdout_tag}######$!")
1615+
exception = None
1616+
counter = loop.time()
1617+
gc.disable()
1618+
try:
1619+
ret = func(*args, **kwargs)
1620+
counter = loop.time()
1621+
return_value = await ret
1622+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1623+
except Exception as e:
1624+
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1625+
exception = e
1626+
finally:
1627+
gc.enable()
1628+
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
1629+
if exception:
1630+
raise exception
1631+
return return_value
1632+
return async_wrapper
1633+
1634+
1635+
def codeflash_concurrency_async(func):
1636+
@wraps(func)
1637+
async def async_wrapper(*args, **kwargs):
1638+
function_name = func.__name__
1639+
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
1640+
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
1641+
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
1642+
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
1643+
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
1644+
gc.disable()
1645+
try:
1646+
seq_start = time.perf_counter_ns()
1647+
for _ in range(concurrency_factor):
1648+
result = await func(*args, **kwargs)
1649+
sequential_time = time.perf_counter_ns() - seq_start
1650+
finally:
1651+
gc.enable()
1652+
gc.disable()
1653+
try:
1654+
conc_start = time.perf_counter_ns()
1655+
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
1656+
await asyncio.gather(*tasks)
1657+
concurrent_time = time.perf_counter_ns() - conc_start
1658+
finally:
1659+
gc.enable()
1660+
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
1661+
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
1662+
return result
1663+
return async_wrapper
1664+
"""
1665+
1666+
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
1667+
1668+
1669+
def get_decorator_name_for_mode(mode: TestingMode) -> str:
1670+
if mode == TestingMode.BEHAVIOR:
1671+
return "codeflash_behavior_async"
1672+
if mode == TestingMode.CONCURRENCY:
1673+
return "codeflash_concurrency_async"
1674+
return "codeflash_performance_async"
1675+
1676+
1677+
def write_async_helper_file(target_dir: Path) -> Path:
1678+
"""Write the async decorator helper file to the target directory."""
1679+
helper_path = target_dir / ASYNC_HELPER_FILENAME
1680+
if not helper_path.exists():
1681+
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
1682+
return helper_path
15441683

15451684

15461685
def add_async_decorator_to_function(
1547-
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1686+
source_path: Path,
1687+
function: FunctionToOptimize,
1688+
mode: TestingMode = TestingMode.BEHAVIOR,
1689+
project_root: Path | None = None,
15481690
) -> bool:
15491691
"""Add async decorator to an async function definition and write back to file.
15501692
1551-
Args:
1552-
----
1553-
source_path: Path to the source file to modify in-place.
1554-
function: The FunctionToOptimize object representing the target async function.
1555-
mode: The testing mode to determine which decorator to apply.
1556-
1557-
Returns:
1558-
-------
1559-
Boolean indicating whether the decorator was successfully added.
1693+
Writes a helper file containing the decorator implementation to project_root (or source directory
1694+
as fallback) and adds a standard import + decorator to the source file.
15601695
15611696
"""
15621697
if not function.is_async:
15631698
return False
15641699

15651700
try:
1566-
# Read source code
15671701
with source_path.open(encoding="utf8") as f:
15681702
source_code = f.read()
15691703

@@ -1573,10 +1707,14 @@ def add_async_decorator_to_function(
15731707
decorator_transformer = AsyncDecoratorAdder(function, mode)
15741708
module = module.visit(decorator_transformer)
15751709

1576-
# Add the import if decorator was added
15771710
if decorator_transformer.added_decorator:
1578-
import_transformer = AsyncDecoratorImportAdder(mode)
1579-
module = module.visit(import_transformer)
1711+
# Write the helper file to project_root (on sys.path) or source dir as fallback
1712+
helper_dir = project_root if project_root is not None else source_path.parent
1713+
write_async_helper_file(helper_dir)
1714+
# Add the import via CST so sort_imports can place it correctly
1715+
decorator_name = get_decorator_name_for_mode(mode)
1716+
import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}")
1717+
module = module.with_changes(body=[import_node, *list(module.body)])
15801718

15811719
modified_code = sort_imports(code=module.code, float_to_top=True)
15821720
except Exception as e:

codeflash/optimization/function_optimizer.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1897,6 +1897,7 @@ def setup_and_establish_baseline(
18971897
if self.args.override_fixtures:
18981898
restore_conftest(original_conftest_content)
18991899
cleanup_paths(paths_to_cleanup)
1900+
self.cleanup_async_helper_file()
19001901
return Failure(baseline_result.failure())
19011902

19021903
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
@@ -1908,6 +1909,7 @@ def setup_and_establish_baseline(
19081909
if self.args.override_fixtures:
19091910
restore_conftest(original_conftest_content)
19101911
cleanup_paths(paths_to_cleanup)
1912+
self.cleanup_async_helper_file()
19111913
return Failure("The threshold for test confidence was not met.")
19121914

19131915
return Success(
@@ -2279,6 +2281,13 @@ def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None
22792281
self.write_code_and_helpers(
22802282
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
22812283
)
2284+
self.cleanup_async_helper_file()
2285+
2286+
def cleanup_async_helper_file(self) -> None:
2287+
from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME
2288+
2289+
helper_path = self.project_root / ASYNC_HELPER_FILENAME
2290+
helper_path.unlink(missing_ok=True)
22822291

22832292
def establish_original_code_baseline(
22842293
self,
@@ -2296,7 +2305,10 @@ def establish_original_code_baseline(
22962305
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
22972306

22982307
success = add_async_decorator_to_function(
2299-
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
2308+
self.function_to_optimize.file_path,
2309+
self.function_to_optimize,
2310+
TestingMode.BEHAVIOR,
2311+
project_root=self.project_root,
23002312
)
23012313

23022314
# Instrument codeflash capture
@@ -2361,7 +2373,10 @@ def establish_original_code_baseline(
23612373
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
23622374

23632375
add_async_decorator_to_function(
2364-
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
2376+
self.function_to_optimize.file_path,
2377+
self.function_to_optimize,
2378+
TestingMode.PERFORMANCE,
2379+
project_root=self.project_root,
23652380
)
23662381

23672382
try:
@@ -2535,7 +2550,10 @@ def run_optimized_candidate(
25352550
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
25362551

25372552
add_async_decorator_to_function(
2538-
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
2553+
self.function_to_optimize.file_path,
2554+
self.function_to_optimize,
2555+
TestingMode.BEHAVIOR,
2556+
project_root=self.project_root,
25392557
)
25402558

25412559
try:
@@ -2611,7 +2629,10 @@ def run_optimized_candidate(
26112629
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
26122630

26132631
add_async_decorator_to_function(
2614-
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
2632+
self.function_to_optimize.file_path,
2633+
self.function_to_optimize,
2634+
TestingMode.PERFORMANCE,
2635+
project_root=self.project_root,
26152636
)
26162637

26172638
try:
@@ -2974,7 +2995,10 @@ def run_concurrency_benchmark(
29742995
try:
29752996
# Add concurrency decorator to the source function
29762997
add_async_decorator_to_function(
2977-
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
2998+
self.function_to_optimize.file_path,
2999+
self.function_to_optimize,
3000+
TestingMode.CONCURRENCY,
3001+
project_root=self.project_root,
29783002
)
29793003

29803004
# Run the concurrency benchmark tests

tests/scripts/end_to_end_test_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool:
1313
CoverageExpectation(
1414
function_name="retry_with_backoff",
1515
expected_coverage=100.0,
16-
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
16+
expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
1717
)
1818
],
1919
)

0 commit comments

Comments
 (0)