Skip to content

Commit 5f3480e

Browse files
authored
Merge pull request #1714 from codeflash-ai/testgen-review
feat: per-function test quality review and repair
2 parents 417623f + 2fec18c commit 5f3480e

17 files changed

Lines changed: 1091 additions & 148 deletions

.github/workflows/codeflash-optimize.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ jobs:
3939
- name: ⚡️Codeflash Optimization
4040
id: optimize_code
4141
run: |
42-
uv run codeflash --benchmark
42+
uv run codeflash --benchmark --testgen-review

codeflash/api/aiservice.py

Lines changed: 129 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1616
from codeflash.code_utils.time_utils import humanize_runtime
1717
from codeflash.languages import Language, current_language
18-
from codeflash.languages.current import current_language_support
1918
from codeflash.models.ExperimentMetadata import ExperimentMetadata
2019
from codeflash.models.models import (
2120
AIServiceRefinerRequest,
2221
CodeStringsMarkdown,
22+
FunctionRepairInfo,
2323
OptimizationReviewResult,
2424
OptimizedCandidate,
2525
OptimizedCandidateSource,
26+
TestFileReview,
2627
)
2728
from codeflash.telemetry.posthog_cf import ph
2829
from codeflash.version import __version__ as codeflash_version
@@ -57,13 +58,25 @@ def add_language_metadata(
5758
payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None
5859
) -> None:
5960
"""Add language version and module system metadata to an API payload."""
61+
from codeflash.languages.current import current_language_support
62+
6063
payload["python_version"] = platform.python_version()
6164
default_lang_version = current_language_support().default_language_version
6265
if default_lang_version is not None:
6366
payload["language_version"] = language_version or default_lang_version
6467
if module_system:
6568
payload["module_system"] = module_system
6669

70+
@staticmethod
71+
def log_error_response(response: requests.Response, action: str, ph_event: str) -> None:
72+
"""Log and report an API error response."""
73+
try:
74+
error = response.json()["error"]
75+
except Exception:
76+
error = response.text
77+
logger.error(f"Error {action}: {response.status_code} - {error}")
78+
ph(ph_event, {"response_status_code": response.status_code, "error": error})
79+
6780
def get_aiservice_base_url(self) -> str:
6881
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
6982
logger.info("Using local AI Service at http://localhost:8000")
@@ -95,14 +108,6 @@ def make_ai_service_request(
95108
------
96109
requests.exceptions.RequestException: If the request fails
97110
98-
"""
99-
"""Make an API request to the given endpoint on the AI service.
100-
101-
:param endpoint: The endpoint to call, e.g., "/optimize".
102-
:param method: The HTTP method to use ('GET' or 'POST').
103-
:param payload: Optional JSON payload to include in the POST request body.
104-
:param timeout: The timeout for the request.
105-
:return: The response object from the API.
106111
"""
107112
url = f"{self.base_url}/ai{endpoint}"
108113
if method.upper() == "POST":
@@ -213,12 +218,7 @@ def optimize_code(
213218
logger.info(f"!lsp|Received {len(optimizations_json)} optimization candidates.")
214219
console.rule()
215220
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE, language)
216-
try:
217-
error = response.json()["error"]
218-
except Exception:
219-
error = response.text
220-
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
221-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
221+
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
222222
console.rule()
223223
return []
224224

@@ -285,12 +285,7 @@ def get_jit_rewritten_code(self, source_code: str, trace_id: str) -> list[Optimi
285285
end_time = time.perf_counter()
286286
logger.debug(f"!lsp|Generating jit rewritten code took {end_time - start_time:.2f} seconds.")
287287
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.JIT_REWRITE)
288-
try:
289-
error = response.json()["error"]
290-
except Exception:
291-
error = response.text
292-
logger.error(f"Error generating jit rewritten candidate: {response.status_code} - {error}")
293-
ph("cli-jit-rewrite-error-response", {"response_status_code": response.status_code, "error": error})
288+
self.log_error_response(response, "generating jit rewritten candidate", "cli-jit-rewrite-error-response")
294289
console.rule()
295290
return []
296291

@@ -362,12 +357,7 @@ def optimize_python_code_line_profiler(
362357
logger.info(f"!lsp|Received {len(optimizations_json)} line profiler optimization candidates.")
363358
console.rule()
364359
return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP)
365-
try:
366-
error = response.json()["error"]
367-
except Exception:
368-
error = response.text
369-
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
370-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
360+
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
371361
console.rule()
372362
return []
373363

@@ -395,12 +385,7 @@ def adaptive_optimize(self, request: AIServiceAdaptiveOptimizeRequest) -> Optimi
395385

396386
return valid_candidates[0]
397387

398-
try:
399-
error = response.json()["error"]
400-
except Exception:
401-
error = response.text
402-
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
403-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
388+
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
404389
return None
405390

406391
def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
@@ -456,12 +441,7 @@ def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> li
456441

457442
return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
458443

459-
try:
460-
error = response.json()["error"]
461-
except Exception:
462-
error = response.text
463-
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
464-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
444+
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
465445
console.rule()
466446
return []
467447

@@ -508,12 +488,7 @@ def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate
508488

509489
return valid_candidates[0]
510490

511-
try:
512-
error = response.json()["error"]
513-
except Exception:
514-
error = response.text
515-
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
516-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
491+
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
517492
console.rule()
518493
return None
519494

@@ -608,12 +583,7 @@ def get_new_explanation(
608583
explanation: str = response.json()["explanation"]
609584
console.rule()
610585
return explanation
611-
try:
612-
error = response.json()["error"]
613-
except Exception:
614-
error = response.text
615-
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
616-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
586+
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
617587
console.rule()
618588
return ""
619589

@@ -660,12 +630,7 @@ def generate_ranking(
660630
ranking: list[int] = response.json()["ranking"]
661631
console.rule()
662632
return ranking
663-
try:
664-
error = response.json()["error"]
665-
except Exception:
666-
error = response.text
667-
logger.error(f"Error generating ranking: {response.status_code} - {error}")
668-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
633+
self.log_error_response(response, "generating ranking", "cli-optimize-error-response")
669634
console.rule()
670635
return None
671636

@@ -726,7 +691,7 @@ def generate_regression_tests(
726691
language_version: str | None = None,
727692
module_system: str | None = None,
728693
is_numerical_code: bool | None = None,
729-
) -> tuple[str, str, str] | None:
694+
) -> tuple[str, str, str, str | None] | None:
730695
"""Generate regression tests for the given function by making a request to the Django endpoint.
731696
732697
Parameters
@@ -749,6 +714,8 @@ def generate_regression_tests(
749714
750715
"""
751716
# Validate test framework based on language
717+
from codeflash.languages.current import current_language_support
718+
752719
lang_support = current_language_support()
753720
valid_frameworks = lang_support.valid_test_frameworks
754721
assert test_framework in valid_frameworks, (
@@ -779,6 +746,8 @@ def generate_regression_tests(
779746
try:
780747
response = self.make_ai_service_request("/testgen", payload=payload, timeout=self.timeout)
781748
except requests.exceptions.RequestException as e:
749+
from codeflash.telemetry.posthog_cf import ph
750+
782751
logger.exception(f"Error generating tests: {e}")
783752
ph("cli-testgen-error-caught", {"error": str(e)})
784753
return None
@@ -792,17 +761,111 @@ def generate_regression_tests(
792761
response_json["generated_tests"],
793762
response_json["instrumented_behavior_tests"],
794763
response_json["instrumented_perf_tests"],
764+
response_json.get("raw_generated_tests"),
795765
)
766+
self.log_error_response(response, "generating tests", "cli-testgen-error-response")
767+
return None
768+
769+
def review_generated_tests(
770+
self,
771+
tests: list[dict[str, Any]],
772+
function_source_code: str,
773+
function_name: str,
774+
trace_id: str,
775+
coverage_summary: str = "",
776+
coverage_details: dict[str, Any] | None = None,
777+
language: str = "python",
778+
) -> list[TestFileReview]:
779+
payload: dict[str, Any] = {
780+
"tests": tests,
781+
"function_source_code": function_source_code,
782+
"function_name": function_name,
783+
"trace_id": trace_id,
784+
"language": language,
785+
"codeflash_version": codeflash_version,
786+
"call_sequence": self.get_next_sequence(),
787+
}
788+
if coverage_summary:
789+
payload["coverage_summary"] = coverage_summary
790+
if coverage_details:
791+
payload["coverage_details"] = coverage_details
792+
self.add_language_metadata(payload)
796793
try:
797-
error = response.json()["error"]
798-
logger.error(f"Error generating tests: {response.status_code} - {error}")
799-
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
800-
return None
801-
except Exception:
802-
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
803-
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
794+
response = self.make_ai_service_request("/testgen_review", payload=payload, timeout=self.timeout)
795+
except requests.exceptions.RequestException as e:
796+
logger.exception(f"Error reviewing generated tests: {e}")
797+
ph("cli-testgen-review-error-caught", {"error": str(e)})
798+
return []
799+
800+
if response.status_code == 200:
801+
data = response.json()
802+
return [
803+
TestFileReview(
804+
test_index=r["test_index"],
805+
functions_to_repair=[
806+
FunctionRepairInfo(function_name=f["function_name"], reason=f.get("reason", ""))
807+
for f in r.get("functions", [])
808+
],
809+
)
810+
for r in data.get("reviews", [])
811+
]
812+
self.log_error_response(response, "reviewing generated tests", "cli-testgen-review-error-response")
813+
return []
814+
815+
def repair_generated_tests(
816+
self,
817+
test_source: str,
818+
functions_to_repair: list[FunctionRepairInfo],
819+
function_source_code: str,
820+
function_to_optimize: FunctionToOptimize,
821+
helper_function_names: list[str],
822+
module_path: Path,
823+
test_module_path: Path,
824+
test_framework: str,
825+
test_timeout: int,
826+
trace_id: str,
827+
language: str = "python",
828+
coverage_details: dict[str, Any] | None = None,
829+
previous_repair_errors: dict[str, str] | None = None,
830+
module_source_code: str = "",
831+
) -> tuple[str, str, str] | None:
832+
payload: dict[str, Any] = {
833+
"test_source": test_source,
834+
"functions_to_repair": [
835+
{"function_name": f.function_name, "reason": f.reason} for f in functions_to_repair
836+
],
837+
"function_source_code": function_source_code,
838+
"function_to_optimize": function_to_optimize,
839+
"helper_function_names": helper_function_names,
840+
"module_path": module_path,
841+
"test_module_path": test_module_path,
842+
"test_framework": test_framework,
843+
"test_timeout": test_timeout,
844+
"trace_id": trace_id,
845+
"language": language,
846+
"codeflash_version": codeflash_version,
847+
"call_sequence": self.get_next_sequence(),
848+
}
849+
if module_source_code:
850+
payload["module_source_code"] = module_source_code
851+
if coverage_details:
852+
payload["coverage_details"] = coverage_details
853+
if previous_repair_errors:
854+
payload["previous_repair_errors"] = previous_repair_errors
855+
self.add_language_metadata(payload)
856+
try:
857+
response = self.make_ai_service_request("/testgen_repair", payload=payload, timeout=self.timeout)
858+
except requests.exceptions.RequestException as e:
859+
logger.exception(f"Error repairing generated tests: {e}")
860+
ph("cli-testgen-repair-error-caught", {"error": str(e)})
804861
return None
805862

863+
if response.status_code == 200:
864+
data = response.json()
865+
return (data["generated_tests"], data["instrumented_behavior_tests"], data["instrumented_perf_tests"])
866+
self.log_error_response(response, "repairing generated tests", "cli-testgen-repair-error-response")
867+
return None
868+
806869
def get_optimization_review(
807870
self,
808871
original_code: dict[Path, str],
@@ -874,12 +937,7 @@ def get_optimization_review(
874937
return OptimizationReviewResult(
875938
review=cast("str", data["review"]), explanation=cast("str", data.get("review_explanation", ""))
876939
)
877-
try:
878-
error = cast("str", response.json()["error"])
879-
except Exception:
880-
error = response.text
881-
logger.error(f"Error generating optimization review: {response.status_code} - {error}")
882-
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
940+
self.log_error_response(response, "generating optimization review", "cli-optimize-error-response")
883941
console.rule()
884942
return OptimizationReviewResult(review="", explanation="")
885943

codeflash/cli_cmds/cli.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ def parse_args() -> Namespace:
106106
)
107107
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
108108
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
109+
parser.add_argument(
110+
"--testgen-review", default=False, action="store_true", help="Enable AI review and repair of generated tests"
111+
)
112+
parser.add_argument(
113+
"--testgen-review-turns", type=int, default=None, help="Number of review/repair cycles (default: 2)"
114+
)
109115
parser.add_argument(
110116
"--async",
111117
default=False,
@@ -201,12 +207,6 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
201207
if env_utils.is_ci():
202208
args.no_pr = True
203209

204-
if getattr(args, "async", False):
205-
logger.warning(
206-
"The --async flag is deprecated and will be removed in a future version. "
207-
"Async function optimization is now enabled by default."
208-
)
209-
210210
return args
211211

212212

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
COVERAGE_THRESHOLD = 60.0
2222
MIN_TESTCASE_PASSED_THRESHOLD = 6
2323
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
24+
MAX_TEST_REPAIR_CYCLES = 2
2425
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
2526

2627
# pytest loop stability

codeflash/languages/python/context/code_context_extractor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def get_function_sources_from_jedi(
439439
fully_qualified_name=fqn,
440440
only_function_name=func_name,
441441
source_code=definition.get_line_code(),
442+
definition_type=definition.type,
442443
)
443444
file_path_to_function_source[definition_path].add(function_source)
444445
function_source_list.append(function_source)

0 commit comments

Comments
 (0)