@@ -1497,73 +1497,207 @@ def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Ca
14971497 return False
14981498
14991499
1500- class AsyncDecoratorImportAdder (cst .CSTTransformer ):
1501- """Transformer that adds the import for async decorators."""
1500+ ASYNC_HELPER_INLINE_CODE = """import asyncio
1501+ import gc
1502+ import os
1503+ import sqlite3
1504+ import time
1505+ from functools import wraps
1506+ from pathlib import Path
1507+ from tempfile import TemporaryDirectory
1508+
1509+ import dill as pickle
15021510
1503- def __init__ (self , mode : TestingMode = TestingMode .BEHAVIOR ) -> None :
1504- self .mode = mode
1505- self .has_import = False
1506-
1507- def _get_decorator_name (self ) -> str :
1508- """Get the decorator name based on the testing mode."""
1509- if self .mode == TestingMode .BEHAVIOR :
1510- return "codeflash_behavior_async"
1511- if self .mode == TestingMode .CONCURRENCY :
1512- return "codeflash_concurrency_async"
1513- return "codeflash_performance_async"
1514-
1515- def visit_ImportFrom (self , node : cst .ImportFrom ) -> None :
1516- # Check if the async decorator import is already present
1517- if (
1518- isinstance (node .module , cst .Attribute )
1519- and isinstance (node .module .value , cst .Attribute )
1520- and isinstance (node .module .value .value , cst .Name )
1521- and node .module .value .value .value == "codeflash"
1522- and node .module .value .attr .value == "code_utils"
1523- and node .module .attr .value == "codeflash_wrap_decorator"
1524- and not isinstance (node .names , cst .ImportStar )
1525- ):
1526- decorator_name = self ._get_decorator_name ()
1527- for import_alias in node .names :
1528- if import_alias .name .value == decorator_name :
1529- self .has_import = True
1530-
1531- def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module :
1532- # If the import is already there, don't add it again
1533- if self .has_import :
1534- return updated_node
1535-
1536- # Choose import based on mode
1537- decorator_name = self ._get_decorator_name ()
1538-
1539- # Parse the import statement into a CST node
1540- import_node = cst .parse_statement (f"from codeflash.code_utils.codeflash_wrap_decorator import { decorator_name } " )
1541-
1542- # Add the import to the module's body
1543- return updated_node .with_changes (body = [import_node , * list (updated_node .body )])
1511+
1512+ def get_run_tmp_file(file_path):
1513+ if not hasattr(get_run_tmp_file, "tmpdir"):
1514+ get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
1515+ return Path(get_run_tmp_file.tmpdir.name) / file_path
1516+
1517+
1518+ def extract_test_context_from_env():
1519+ test_module = os.environ["CODEFLASH_TEST_MODULE"]
1520+ test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
1521+ test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
1522+ if test_module and test_function:
1523+ return (test_module, test_class if test_class else None, test_function)
1524+ raise RuntimeError(
1525+ "Test context environment variables not set - ensure tests are run through codeflash test runner"
1526+ )
1527+
1528+
1529+ def codeflash_behavior_async(func):
1530+ @wraps(func)
1531+ async def async_wrapper(*args, **kwargs):
1532+ loop = asyncio.get_running_loop()
1533+ function_name = func.__name__
1534+ line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
1535+ loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
1536+ test_module_name, test_class_name, test_name = extract_test_context_from_env()
1537+ test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
1538+ if not hasattr(async_wrapper, "index"):
1539+ async_wrapper.index = {}
1540+ if test_id in async_wrapper.index:
1541+ async_wrapper.index[test_id] += 1
1542+ else:
1543+ async_wrapper.index[test_id] = 0
1544+ codeflash_test_index = async_wrapper.index[test_id]
1545+ invocation_id = f"{line_id}_{codeflash_test_index}"
1546+ class_prefix = (test_class_name + ".") if test_class_name else ""
1547+ test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
1548+ print(f"!$######{test_stdout_tag}######$!")
1549+ iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
1550+ db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
1551+ codeflash_con = sqlite3.connect(db_path)
1552+ codeflash_cur = codeflash_con.cursor()
1553+ codeflash_cur.execute(
1554+ "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
1555+ "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
1556+ "runtime INTEGER, return_value BLOB, verification_type TEXT)"
1557+ )
1558+ exception = None
1559+ counter = loop.time()
1560+ gc.disable()
1561+ try:
1562+ ret = func(*args, **kwargs)
1563+ counter = loop.time()
1564+ return_value = await ret
1565+ codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1566+ except Exception as e:
1567+ codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1568+ exception = e
1569+ finally:
1570+ gc.enable()
1571+ print(f"!######{test_stdout_tag}######!")
1572+ pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
1573+ codeflash_cur.execute(
1574+ "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
1575+ (
1576+ test_module_name,
1577+ test_class_name,
1578+ test_name,
1579+ function_name,
1580+ loop_index,
1581+ invocation_id,
1582+ codeflash_duration,
1583+ pickled_return_value,
1584+ "function_call",
1585+ ),
1586+ )
1587+ codeflash_con.commit()
1588+ codeflash_con.close()
1589+ if exception:
1590+ raise exception
1591+ return return_value
1592+ return async_wrapper
1593+
1594+
1595+ def codeflash_performance_async(func):
1596+ @wraps(func)
1597+ async def async_wrapper(*args, **kwargs):
1598+ loop = asyncio.get_running_loop()
1599+ function_name = func.__name__
1600+ line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
1601+ loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
1602+ test_module_name, test_class_name, test_name = extract_test_context_from_env()
1603+ test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
1604+ if not hasattr(async_wrapper, "index"):
1605+ async_wrapper.index = {}
1606+ if test_id in async_wrapper.index:
1607+ async_wrapper.index[test_id] += 1
1608+ else:
1609+ async_wrapper.index[test_id] = 0
1610+ codeflash_test_index = async_wrapper.index[test_id]
1611+ invocation_id = f"{line_id}_{codeflash_test_index}"
1612+ class_prefix = (test_class_name + ".") if test_class_name else ""
1613+ test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
1614+ print(f"!$######{test_stdout_tag}######$!")
1615+ exception = None
1616+ counter = loop.time()
1617+ gc.disable()
1618+ try:
1619+ ret = func(*args, **kwargs)
1620+ counter = loop.time()
1621+ return_value = await ret
1622+ codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1623+ except Exception as e:
1624+ codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
1625+ exception = e
1626+ finally:
1627+ gc.enable()
1628+ print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
1629+ if exception:
1630+ raise exception
1631+ return return_value
1632+ return async_wrapper
1633+
1634+
1635+ def codeflash_concurrency_async(func):
1636+ @wraps(func)
1637+ async def async_wrapper(*args, **kwargs):
1638+ function_name = func.__name__
1639+ concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
1640+ test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
1641+ test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
1642+ test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
1643+ loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
1644+ gc.disable()
1645+ try:
1646+ seq_start = time.perf_counter_ns()
1647+ for _ in range(concurrency_factor):
1648+ result = await func(*args, **kwargs)
1649+ sequential_time = time.perf_counter_ns() - seq_start
1650+ finally:
1651+ gc.enable()
1652+ gc.disable()
1653+ try:
1654+ conc_start = time.perf_counter_ns()
1655+ tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
1656+ await asyncio.gather(*tasks)
1657+ concurrent_time = time.perf_counter_ns() - conc_start
1658+ finally:
1659+ gc.enable()
1660+ tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
1661+ print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
1662+ return result
1663+ return async_wrapper
1664+ """
1665+
1666+ ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
1667+
1668+
1669+ def get_decorator_name_for_mode (mode : TestingMode ) -> str :
1670+ if mode == TestingMode .BEHAVIOR :
1671+ return "codeflash_behavior_async"
1672+ if mode == TestingMode .CONCURRENCY :
1673+ return "codeflash_concurrency_async"
1674+ return "codeflash_performance_async"
1675+
1676+
1677+ def write_async_helper_file (target_dir : Path ) -> Path :
1678+ """Write the async decorator helper file to the target directory."""
1679+ helper_path = target_dir / ASYNC_HELPER_FILENAME
1680+ if not helper_path .exists ():
1681+ helper_path .write_text (ASYNC_HELPER_INLINE_CODE , "utf-8" )
1682+ return helper_path
15441683
15451684
15461685def add_async_decorator_to_function (
1547- source_path : Path , function : FunctionToOptimize , mode : TestingMode = TestingMode .BEHAVIOR
1686+ source_path : Path ,
1687+ function : FunctionToOptimize ,
1688+ mode : TestingMode = TestingMode .BEHAVIOR ,
1689+ project_root : Path | None = None ,
15481690) -> bool :
15491691 """Add async decorator to an async function definition and write back to file.
15501692
1551- Args:
1552- ----
1553- source_path: Path to the source file to modify in-place.
1554- function: The FunctionToOptimize object representing the target async function.
1555- mode: The testing mode to determine which decorator to apply.
1556-
1557- Returns:
1558- -------
1559- Boolean indicating whether the decorator was successfully added.
1693+ Writes a helper file containing the decorator implementation to project_root (or source directory
1694+ as fallback) and adds a standard import + decorator to the source file.
15601695
15611696 """
15621697 if not function .is_async :
15631698 return False
15641699
15651700 try :
1566- # Read source code
15671701 with source_path .open (encoding = "utf8" ) as f :
15681702 source_code = f .read ()
15691703
@@ -1573,10 +1707,14 @@ def add_async_decorator_to_function(
15731707 decorator_transformer = AsyncDecoratorAdder (function , mode )
15741708 module = module .visit (decorator_transformer )
15751709
1576- # Add the import if decorator was added
15771710 if decorator_transformer .added_decorator :
1578- import_transformer = AsyncDecoratorImportAdder (mode )
1579- module = module .visit (import_transformer )
1711+ # Write the helper file to project_root (on sys.path) or source dir as fallback
1712+ helper_dir = project_root if project_root is not None else source_path .parent
1713+ write_async_helper_file (helper_dir )
1714+ # Add the import via CST so sort_imports can place it correctly
1715+ decorator_name = get_decorator_name_for_mode (mode )
1716+ import_node = cst .parse_statement (f"from codeflash_async_wrapper import { decorator_name } " )
1717+ module = module .with_changes (body = [import_node , * list (module .body )])
15801718
15811719 modified_code = sort_imports (code = module .code , float_to_top = True )
15821720 except Exception as e :
0 commit comments