diff --git a/config.py b/config.py index 72cd81a..776b0cc 100644 --- a/config.py +++ b/config.py @@ -18,6 +18,8 @@ class ExecConfig(BaseSettings): enable_type_checking: bool = True max_type_check_attempts: int = 3 keep_only_relevant_files: bool = False + # maximum number of automatic debug/fix attempts per node + max_debug_retries_per_node: int = 2 class CodeConfig(BaseSettings): diff --git a/treesearch/node.py b/treesearch/node.py index 37a38e4..e1bdc4d 100644 --- a/treesearch/node.py +++ b/treesearch/node.py @@ -62,6 +62,9 @@ class Node(NodeMixin): type_check_passed: bool = field(default=False) type_check_results: list[TypeCheckResult] = field(default_factory=list) + # ---- debug retry limiting ---- + debug_attempts: int = field(default=0) + @property def name(self) -> str: short_id = f"{self.id[:4]}...{self.id[-4:]}" @@ -107,6 +110,8 @@ def __setstate__(self, state): """Set state during unpickling""" # Ensure all required attributes are present self.__dict__.update(state) + if "debug_attempts" not in self.__dict__: + self.debug_attempts = 0 def absorb_exec_result(self, exec_result: ExecutionResult): """Absorb the result of executing the code from this node.""" diff --git a/treesearch/search.py b/treesearch/search.py index 69b9bc8..012445f 100644 --- a/treesearch/search.py +++ b/treesearch/search.py @@ -69,26 +69,42 @@ def best_buggy_node(self): return buggy_nodes[0] def select_next_node(self) -> Node: - if ( - len(self.buggy_nodes) > 0 - and random.random() < self._config.treesearch.debug_prob - or len(self.good_nodes) == 0 - ): + # Exclude nodes that exhausted their debug attempts + limit = self._config.exec.max_debug_retries_per_node + eligible = [n for n in self.all_nodes if n.debug_attempts < limit] + + if not eligible: + logger.info("No eligible nodes available after applying debug-attempts filter.") + # fallback: return best good node if available, else best buggy node + if len(self.good_nodes) > 0: + return self.best_good_node + return self.best_buggy_node + + buggy = [n for n in eligible if n.is_buggy] + good = [n for n in eligible if not n.is_buggy] + + # Prefer debugging buggy nodes with probability debug_prob or when no good nodes exist + if (len(buggy) > 0 and random.random() < self._config.treesearch.debug_prob) or len(good) == 0: if random.random() < self._config.treesearch.epsilon: logger.info("Selecting random buggy node for debugging...") - nodes = self.buggy_nodes - weights = [1 / (len(n.children) + 1) for n in nodes] - return random.choices(nodes, weights=weights, k=1)[0] + weights = [1 / (len(n.children) + 1) for n in buggy] + return random.choices(buggy, weights=weights, k=1)[0] else: logger.info("Selecting best buggy node for debugging...") - return max(self.buggy_nodes, key=lambda n: n.score.score * (1 / (len(n.children) + 1))) + return max(buggy, key=lambda n: n.score.score * (1 / (len(n.children) + 1))) + + # Otherwise select a good node to improve + if len(good) == 0: + # no good nodes available; fallback + if len(self.good_nodes) > 0: + return self.best_good_node + return self.best_buggy_node if random.random() < self._config.treesearch.epsilon: - nodes = self.good_nodes - weights = [1 / (len(n.children) + 1) for n in nodes] - return random.choices(nodes, weights=weights, k=1)[0] + weights = [1 / (len(n.children) + 1) for n in good] + return random.choices(good, weights=weights, k=1)[0] else: - return max(self.good_nodes, key=lambda n: n.score.score * (1 / (len(n.children) + 1))) + return max(good, key=lambda n: n.score.score * (1 / (len(n.children) + 1))) async def run(self): logger.info("Starting tree search...") @@ -106,11 +122,32 @@ async def run(self): logger.info( f"Treesearch iteration {i + 1}/{self._config.treesearch.max_iterations}" ) + parent_node = self.select_next_node() - if parent_node.is_buggy: + if parent_node is None: + logger.info("select_next_node returned None — ending search early.") + break + + # Check if we can still debug this node or if we've hit the retry limit + max_debug_retries = self._config.exec.max_debug_retries_per_node + can_debug = ( + parent_node.is_buggy + and parent_node.debug_attempts < max_debug_retries + ) + + if can_debug: + logger.info( + f"Debugging node {parent_node.id[:8]}... " + f"(attempt {parent_node.debug_attempts + 1}/{max_debug_retries})" + ) child_node = await self._minimal_agent._debug(parent_node) + parent_node.debug_attempts += 1 else: + if parent_node.is_buggy and parent_node.debug_attempts >= max_debug_retries: + logger.info( + f"Node {parent_node.id[:8]} has reached max debug retries" + ) child_node = await self._minimal_agent._improve(parent_node) await self.exec_node(child_node)