Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions treesearch/minimal_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@
from treesearch.utils.response import strip_markdown_fences
from utils.log import _ROOT_LOGGER
from utils.path import mkdir
from utils.workspace_context import (
find_project_root,
find_workspace_dir,
get_workspace_context_block,
)

logger = _ROOT_LOGGER.getChild("nodeAgent")


class MinimalAgent:
"""A minimal agent class that only contains what's needed for processing nodes"""


def __init__(
self,
task_desc: str,
Expand All @@ -44,6 +50,14 @@ def __init__(
self.evaluation_metrics = evaluation_metrics
self.stage_name = stage_name
self._out_dir = mkdir(Path(cfg.out_dir))

# Dynamically discover project root and workspace directories
# (these can differ between Docker, pip, and local runs)
self._project_root = find_project_root()
self._workspace_dir = find_workspace_dir()
logger.info(f"Project root: {self._project_root}")
logger.info(f"Workspace dir: {self._workspace_dir}")

logger.info("Agent initialized!")

# Setup MCP connections for documentation search
Expand Down Expand Up @@ -77,8 +91,15 @@ def _prompt_environment(self):
]
pkg_str = ", ".join([f"`{p}`" for p in pkgs])

# Dynamically build workspace context so the LLM knows where files are
workspace_block = get_workspace_context_block(
project_root=self._project_root,
workspace_dir=self._workspace_dir,
)

env_prompt = {
"Installed Packages": f"Your solution can use the following machine learning packages: {pkg_str}. You MUST use these libraries as much as possible instead of implementing from scratch."
"Installed Packages": f"Your solution can use the following machine learning packages: {pkg_str}. You MUST use these libraries as much as possible instead of implementing from scratch.",
"Workspace Context": workspace_block,
}
return env_prompt

Expand All @@ -94,12 +115,20 @@ def _prompt_impl_guideline(self):
"4. Environment Setup:",
" - Create working directory: `working_dir = os.path.join(os.getcwd(), 'working'); os.makedirs(working_dir, exist_ok=True)`",
f" - Complete execution within {humanize.naturaldelta(self.cfg.exec.timeout)}",
"5. Data Tracking:",
"5. Data Loading:",
" - Use the `dataloader` package for all standard datasets:",
" ```python",
" from dataloader.loaders.registry import _run_loader",
" df = _run_loader('MovieLens100K') # Downloads & caches automatically",
" ```",
" - For custom data files: place them in the workspace directory and reference by filename",
" (code runs inside the workspace, so just use the filename directly)",
"6. Data Tracking:",
" - Track all relevant data points (e.g., metrics, losses)",
"6. Evaluation:",
"7. Evaluation:",
f" - Metrics: {', '.join(self.evaluation_metrics) if self.evaluation_metrics else 'Choose appropriate metrics'}",
" - Print metrics during execution for monitoring",
"7. API Verification (CRITICAL):",
"8. API Verification (CRITICAL):",
" - Check constructor signatures before use",
" - Verify object attributes exist (e.g., SplitData structure)",
" - Use only public APIs (no underscore-prefixed methods)",
Expand Down Expand Up @@ -287,6 +316,12 @@ def _new_node(self, plan: str, code: str, parent: Optional[Node] = None):

async def plan_and_code_query(self, prompt, retries=3) -> tuple[str, str]:
"""Generate a natural language plan + code in the same LLM call and split them apart."""
# Build workspace context block for the system prompt
workspace_block = get_workspace_context_block(
project_root=self._project_root,
workspace_dir=self._workspace_dir,
)

plan_and_code_result = (
await Query(tool_budget=40)
.with_mcp(self._mcp_docs)
Expand All @@ -305,7 +340,10 @@ async def plan_and_code_query(self, prompt, retries=3) -> tuple[str, str]:
"- Data structures and return types\n"
"\n"
"In 'nl_text', include '## Documentation Verified' section listing all verified methods.\n"
"Search for examples and Verify critical details in documentation."
"Search for examples and Verify critical details in documentation.\n"
"\n"
"IMPORTANT β€” File system awareness:\n"
f"{workspace_block}"
)
.run(prompt, PlanAndCode)
)
Expand All @@ -316,35 +354,52 @@ async def plan_and_code_query(self, prompt, retries=3) -> tuple[str, str]:

async def _select_datasets(self) -> list[str]:
"""Select appropriate datasets for the research task using LLM."""
# Check for any custom/local data files the user might be referencing
workspace_block = get_workspace_context_block(
project_root=self._project_root,
workspace_dir=self._workspace_dir,
)

prompt: Prompt = {
"Instruction:": (
f"You are a recommender system researcher selecting datasets for a research task.\n\n"
f"Research task:\n{self.task_desc}\n\n"
"Instructions:\n"
"1. Check if the research task specifies any datasets\n"
"2. If specified, select those datasets; otherwise choose appropriate ones from the list below\n"
"3. Return only a list of dataset identifiers\n\n"
f"Available datasets:\n{get_datasets_table()}"
"2. Check the 'Workspace Context' below β€” if a data file matching the task is found, "
"the user likely wants to use that file (e.g., if they mention 'movielens.csv' and it exists on disk)\n"
"3. If the task doesn't specify a dataset, choose appropriate ones from the available datasets below\n"
"4. Return ONLY dataset identifiers from the list below, OR the filename if a local file is to be used\n\n"
f"Workspace Context:\n{workspace_block}\n\n"
f"Available datasets (use these identifiers if no local file matches):\n{get_datasets_table()}"
)
}
result = (
await Query()
.with_mcp(self._mcp_docs)
.with_system(
"Search OmniRec documentation for dataset characteristics and usage patterns if needed."
"Search OmniRec documentation for dataset characteristics and usage patterns if needed. "
"If the user mentioned a dataset by name, check if there's a matching file on disk in the workspace context."
)
.run(prompt, SelectDatasets)
)
return result.selected_datasets

async def _set_code_requirements(self):
logger.info("Engineering code requirements...")
workspace_block = get_workspace_context_block(
project_root=self._project_root,
workspace_dir=self._workspace_dir,
)
requirements_prompt = f"""
You are an expert recommender systems researcher defining experiment requirements.

Research task: {self.task_desc}
Selected datasets: {self.selected_datasets}

Files available in workspace:
{workspace_block}

Generate requirements that specify critical aspects of the experiment that must be fulfilled.

PRINCIPLES:
Expand All @@ -356,11 +411,12 @@ async def _set_code_requirements(self):
2. Abstraction: State objectives and constraints at an appropriate level
- Avoid excessive implementation details (exact formulas, nested conditional logic, code-level instructions)
- Include critical technical specifications where they matter (framework to use, specific datasets, evaluation metrics, split ratios)
- If there are local data files, include a requirement about using the correct file path

3. Atomicity: Each requirement should test one distinct aspect of the experiment

4. Coverage: Include requirements for all essential aspects:
- Data loading and preprocessing
- Data loading and preprocessing (use correct paths for local files or dataloader for remote datasets)
- Experimental methodology (data splitting, reproducibility requirements)
- Model/algorithm selection and configuration β€” ALWAYS include a requirement that OmniRec must be used for all recommender system functionality; raw backend libraries (Lenskit, RecBole, etc.) must not be called directly
- Training procedures
Expand Down
203 changes: 203 additions & 0 deletions utils/workspace_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""
Workspace context utilities.

Provides the LLM planner with awareness of what custom data files the user
placed in the workspace directory. Since the workspace can be in different
locations (Docker: /app/out/workspace, local: ./out/workspace),
this module dynamically discovers the correct paths.
"""
import os
from pathlib import Path
from typing import Optional

from utils.log import _ROOT_LOGGER

logger = _ROOT_LOGGER.getChild("workspace_context")


def find_project_root() -> Path:
"""
Find the project root directory by looking for sentinel files
(pyproject.toml, setup.py, .git).

Returns:
Path: The absolute path to the project root.
"""
cwd = Path.cwd().resolve()
for parent in [cwd] + list(cwd.parents):
sentinels = [
parent / "pyproject.toml",
parent / "setup.py",
parent / "setup.cfg",
parent / ".git",
parent / "main.py",
]
for sentinel in sentinels:
if sentinel.exists():
logger.debug(f"Found project root at {parent}")
return parent

logger.warning(f"Could not determine project root, falling back to {cwd}")
return cwd


def find_workspace_dir() -> Optional[Path]:
"""
Find the current workspace directory by checking common locations.

The workspace is where code gets executed (Interpreter writes runfile.py here).
It can be:
- ./out/workspace (when running locally via `uv run main.py`)
- /app/out/workspace (when running inside Docker)
- A custom path set via ARL_out_dir env var

Returns:
Optional[Path]: Path to the workspace directory, or None if not found.
"""
env_out_dir = os.environ.get("ARL_out_dir")
if env_out_dir:
candidate = Path(env_out_dir).resolve() / "workspace"
if candidate.exists():
return candidate

candidates = [
Path.cwd() / "out" / "workspace",
Path.cwd().resolve() / "out" / "workspace",
Path("/app/out/workspace"),
]

if Path("/app").exists():
candidates.append(Path("/app/out/workspace"))

for candidate in candidates:
if candidate.exists():
logger.debug(f"Found workspace at {candidate}")
return candidate

logger.debug("No workspace directory found yet (will be created during execution)")
return None


def scan_workspace_for_data_files(workspace_dir: Optional[Path] = None) -> str:
"""
Scan ONLY the workspace directory for user-placed data files.

Standard datasets should be loaded via the dataloader package.
Only custom/user-provided files (CSV, Parquet, JSON, etc.) placed
directly in the workspace are listed here.
Internal files (.pkl, .py, .log) are ignored automatically.

Args:
workspace_dir: The workspace directory to scan. If None, auto-detect.

Returns:
str: A formatted block describing available files, or empty string if none found.
"""
if workspace_dir is None:
workspace_dir = find_workspace_dir()

if workspace_dir is None or not workspace_dir.exists():
return ""

relevant_extensions = {
".csv": "CSV data file",
".tsv": "TSV data file",
".parquet": "Parquet data file",
".json": "JSON data file",
".jsonl": "JSONL data file",
}

found_files: list[tuple[Path, str]] = []

# Only scan the workspace directory itself (non-recursive)
# to avoid picking up runfile.py, working/, etc.
for item in workspace_dir.iterdir():
if item.is_file() and item.suffix.lower() in relevant_extensions:
found_files.append((item, relevant_extensions.get(item.suffix.lower(), "Data file")))

if not found_files:
return ""

lines = ["## Custom data files in workspace", ""]
for fpath, desc in sorted(found_files, key=lambda x: x[0].name):
size_str = _format_size(fpath.stat().st_size)
lines.append(f"- `{fpath.name}` β€” {desc} ({size_str})")
lines.append("")
lines.append("To use a custom file in your code, reference it by its filename")
lines.append("(the code runs inside the workspace directory):")
lines.append("```python")
lines.append("df = pd.read_csv('filename.csv')")
lines.append("```")
lines.append("")

return "\n".join(lines)


def _format_size(size_bytes: int) -> str:
"""Format byte size into human-readable string."""
for unit in ["B", "KB", "MB", "GB"]:
if size_bytes < 1024:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024
return f"{size_bytes:.1f} TB"


def get_workspace_context_block(
project_root: Optional[Path] = None,
workspace_dir: Optional[Path] = None,
) -> str:
"""
Get a formatted context block about the workspace that can be injected
into LLM prompts.

This tells the planner:
- Where the project root and workspace are
- How to use the dataloader package for standard datasets
- What custom data files the user placed in the workspace

Args:
project_root: Override for project root (auto-detected if None).
workspace_dir: Override for workspace dir (auto-detected if None).

Returns:
str: A context block to inject into prompts.
"""
if project_root is None:
project_root = find_project_root()
if workspace_dir is None:
workspace_dir = find_workspace_dir()

parts = [
"## Workspace & File Context",
"",
f"- **Project root:** `{project_root}`",
f"- **Code execution workspace:** `{workspace_dir if workspace_dir else 'Not yet created (will be ./out/workspace)'}`",
"",
"### How to access data files",
"",
"When your code runs, the working directory is the code execution workspace listed above.",
"",
"You have two options to load datasets:",
"",
"1. **Use the `dataloader` package** (recommended for standard datasets):",
" ```python",
" from dataloader.loaders.registry import _run_loader",
" df = _run_loader('MovieLens100K') # Downloads & caches automatically",
" ```",
" Available datasets include: MovieLens100K, MovieLens1M, MovieLens10M, MovieLens20M,",
" MovieLens25M, MovieLensLatest, MovieLensLatestSmall, MovieLens1BSynthetic,",
" Amazon2014*, Amazon2018*, Amazon2023*, Yelp2023, Gowalla, BeerAdvocate, etc.",
"",
"2. **Use custom files placed in the workspace** (listed below):",
"",
]

# Only scan the workspace for user-placed custom files
custom_files_block = scan_workspace_for_data_files(workspace_dir)
if custom_files_block:
parts.append(custom_files_block)
else:
parts.append(" *(No custom data files found in workspace.)*")
parts.append("")

return "\n".join(parts)