Skip to content

Commit 6d354f9

Browse files
committed
pr generator
1 parent df79dc5 commit 6d354f9

13 files changed

Lines changed: 642 additions & 57 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ __pycache__/
33
dist/
44
.venv/
55
.claude/
6+
.cursorrules
7+
AGENTS.md

ptq/agent.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,19 @@ def _stamp_worklog_header(
225225
)
226226

227227

228+
_AGENT_CONFIG_EXCLUDES = [".cursorrules", "AGENTS.md", ".claude/"]
229+
230+
231+
def _exclude_agent_configs(backend: Backend, worktree_path: str) -> None:
232+
exclude_file = f"{worktree_path}/.git/info/exclude"
233+
backend.run(f"mkdir -p $(dirname {exclude_file})", check=False)
234+
for pattern in _AGENT_CONFIG_EXCLUDES:
235+
backend.run(
236+
f"grep -qxF '{pattern}' {exclude_file} 2>/dev/null || echo '{pattern}' >> {exclude_file}",
237+
check=False,
238+
)
239+
240+
228241
def launch_agent(
229242
backend: Backend,
230243
*,
@@ -246,7 +259,7 @@ def launch_agent(
246259

247260
if existing_job_id:
248261
job_id = existing_job_id
249-
run_number = increment_run(job_id, agent_type=agent_type)
262+
run_number = increment_run(job_id, agent_type=agent_type, model=model)
250263
label = f"issue #{issue_number}" if issue_number else "adhoc"
251264
console.print(f"[bold]Job {job_id}[/bold] — {label} (run {run_number})")
252265
existing = job_id
@@ -259,7 +272,7 @@ def launch_agent(
259272
existing = find_existing_job(issue_number, machine=machine, local=local)
260273
if existing:
261274
job_id = existing
262-
run_number = increment_run(job_id, agent_type=agent_type)
275+
run_number = increment_run(job_id, agent_type=agent_type, model=model)
263276
console.print(
264277
f"[bold]Job {job_id}[/bold] — issue #{issue_number} (run {run_number})"
265278
)
@@ -285,6 +298,7 @@ def launch_agent(
285298
workspace=workspace,
286299
run_number=run_number,
287300
agent_type=agent_type,
301+
model=model,
288302
)
289303

290304
deploy_scripts(backend)
@@ -304,6 +318,7 @@ def launch_agent(
304318
else:
305319
console.print("Reusing existing worktree.")
306320

321+
_exclude_agent_configs(backend, worktree_path)
307322
with _timed("agent workspace setup"):
308323
agent.setup_workspace(backend, worktree_path, job_dir, workspace)
309324

ptq/apply.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,4 @@ def apply_diff(job_id: str, pytorch_path: Path) -> None:
8585
console.print(
8686
f"\n[bold green]Diff applied to {pytorch_path} on branch {branch_name}[/bold green]"
8787
)
88-
console.print("\n[bold]Copy & paste:[/bold]\n")
89-
if issue_number is not None:
90-
console.print(
91-
f"cd {pytorch_path} && git add -p && "
92-
f"git commit -m 'Fix #{issue_number}' && "
93-
f"gh pr create --title 'Fix #{issue_number}' "
94-
f"--body 'Fixes #{issue_number}'"
95-
)
96-
else:
97-
console.print(
98-
f"cd {pytorch_path} && git add -p && "
99-
f"git commit -m 'Fix from {job_id}' && "
100-
f"gh pr create --title 'Fix from {job_id}'"
101-
)
88+
console.print(f"\nTo create a PR, run: [bold]ptq pr {job_id}[/bold]")

ptq/cli.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,13 +403,19 @@ def list_jobs() -> None:
403403
)
404404
console.print("[dim] ptq peek JOB_ID # check progress[/dim]")
405405
console.print("[dim] ptq results JOB_ID # fetch results[/dim]")
406+
console.print(
407+
"[dim] ptq pr JOB_ID # create GitHub PR[/dim]"
408+
)
406409
console.print("[dim] ptq kill JOB_ID # stop agent[/dim]")
407410
console.print(
408411
"[dim] ptq clean JOB_ID # remove job entirely[/dim]"
409412
)
410413
console.print(
411414
"[dim] ptq clean MACHINE # bulk clean stopped jobs[/dim]"
412415
)
416+
console.print(
417+
"[dim] ptq web # start web dashboard[/dim]"
418+
)
413419

414420

415421
@app.command()
@@ -520,6 +526,27 @@ def status(
520526
console.print(f"\n last log: [dim]{tail.stdout.strip()[:120]}[/dim]")
521527

522528

529+
@app.command()
530+
def pr(
531+
job_id: Annotated[str, typer.Argument(help="Job ID or issue number.")],
532+
title: Annotated[str | None, typer.Option(help="PR title override.")] = None,
533+
draft: Annotated[bool, typer.Option(help="Create as draft PR.")] = False,
534+
) -> None:
535+
"""Create a GitHub PR from a job's worktree changes."""
536+
from ptq.job import resolve_job_id
537+
from ptq.pr import create_pr
538+
539+
job_id = resolve_job_id(job_id)
540+
console.print(f"[bold]Creating PR for {job_id}[/bold]")
541+
result = create_pr(
542+
job_id,
543+
title=title,
544+
draft=draft,
545+
log=lambda msg: console.print(f" [dim]{msg}[/dim]"),
546+
)
547+
console.print(f"\n[bold green]PR created:[/bold green] {result.url}")
548+
549+
523550
@app.command()
524551
def web(
525552
port: Annotated[int, typer.Option(help="Port to listen on.")] = 8000,

ptq/config.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from __future__ import annotations
22

3+
import json
4+
import logging
5+
import re
6+
import subprocess
37
from dataclasses import dataclass, field
48
from pathlib import Path
59

610
import tomllib
711

12+
log = logging.getLogger("ptq.config")
13+
814
CONFIG_PATH = Path.home() / ".ptq" / "config.toml"
915

1016
_DEFAULT_TOML = """\
@@ -79,9 +85,77 @@ def _parse(data: dict) -> Config:
7985
)
8086

8187

88+
def discover_ssh_hosts() -> list[str]:
89+
ssh_config = Path.home() / ".ssh" / "config"
90+
if not ssh_config.exists():
91+
return []
92+
hosts: list[str] = []
93+
for line in ssh_config.read_text().splitlines():
94+
line = line.strip()
95+
if line.lower().startswith("host ") and "*" not in line:
96+
hosts.extend(line.split()[1:])
97+
return hosts
98+
99+
82100
def load_config(path: Path | None = None) -> Config:
83101
path = path or CONFIG_PATH
84102
if not path.exists():
85103
path.parent.mkdir(parents=True, exist_ok=True)
86104
path.write_text(_DEFAULT_TOML)
87105
return _parse(tomllib.loads(path.read_text()))
106+
107+
108+
_DISCOVER_CMDS: dict[str, list[str]] = {
109+
"claude": ["claude", "-p", "x", "--model", "__invalid__", "--max-turns", "0"],
110+
"codex": ["codex", "exec", "x", "--model", "__invalid__"],
111+
"cursor": ["agent", "-p", "x", "--model", "__invalid__", "--force"],
112+
}
113+
114+
_AVAILABLE_RE = re.compile(r"Available models?:\s*(.+)", re.IGNORECASE)
115+
116+
_MODEL_CACHE_FILES: dict[str, tuple[Path, str]] = {
117+
"codex": (Path.home() / ".codex" / "models_cache.json", "slug"),
118+
}
119+
120+
_discovered_cache: dict[str, list[str]] = {}
121+
122+
123+
def _read_cache_file(agent_name: str) -> list[str]:
124+
entry = _MODEL_CACHE_FILES.get(agent_name)
125+
if not entry:
126+
return []
127+
path, key = entry
128+
if not path.exists():
129+
return []
130+
try:
131+
data = json.loads(path.read_text())
132+
return [
133+
m[key] for m in data.get("models", []) if isinstance(m, dict) and key in m
134+
]
135+
except (json.JSONDecodeError, KeyError):
136+
return []
137+
138+
139+
def discover_models(agent_name: str) -> list[str]:
140+
if agent_name in _discovered_cache:
141+
return _discovered_cache[agent_name]
142+
143+
models: list[str] = []
144+
145+
cmd = _DISCOVER_CMDS.get(agent_name)
146+
if cmd:
147+
try:
148+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
149+
match = _AVAILABLE_RE.search(result.stdout + result.stderr)
150+
if match:
151+
models = [m.strip() for m in match.group(1).split(",") if m.strip()]
152+
except (subprocess.TimeoutExpired, FileNotFoundError):
153+
pass
154+
155+
if not models:
156+
models = _read_cache_file(agent_name)
157+
158+
_discovered_cache[agent_name] = models
159+
if models:
160+
log.debug("discovered %d models for %s", len(models), agent_name)
161+
return models

ptq/job.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,15 @@ def register_job(
3737
workspace: str | None = None,
3838
run_number: int = 1,
3939
agent_type: str = "claude",
40+
model: str = "opus",
4041
) -> None:
4142
db = load_jobs_db()
42-
entry: dict = {"issue": issue_number, "runs": run_number, "agent": agent_type}
43+
entry: dict = {
44+
"issue": issue_number,
45+
"runs": run_number,
46+
"agent": agent_type,
47+
"model": model,
48+
}
4349
if local:
4450
entry["local"] = True
4551
entry["workspace"] = workspace or "~/.ptq_workspace"
@@ -50,14 +56,18 @@ def register_job(
5056
save_jobs_db(db)
5157

5258

53-
def increment_run(job_id: str, agent_type: str | None = None) -> int:
59+
def increment_run(
60+
job_id: str, agent_type: str | None = None, model: str | None = None
61+
) -> int:
5462
db = load_jobs_db()
5563
entry = db[job_id]
5664
run_number = entry.get("runs", 0) + 1
5765
entry["runs"] = run_number
5866
entry.pop("pid", None)
5967
if agent_type:
6068
entry["agent"] = agent_type
69+
if model:
70+
entry["model"] = model
6171
save_jobs_db(db)
6272
return run_number
6373

ptq/pr.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING
6+
7+
from ptq.job import get_job
8+
from ptq.ssh import backend_for_job
9+
10+
if TYPE_CHECKING:
11+
from ptq.ssh import Backend
12+
13+
14+
@dataclass
15+
class PRResult:
16+
url: str
17+
branch: str
18+
19+
20+
def _read_file(backend: Backend, path: str) -> str:
21+
result = backend.run(f"cat {path}", check=False)
22+
if result.returncode == 0:
23+
return result.stdout.strip()
24+
return ""
25+
26+
27+
def _build_pr_body(report: str, worklog: str, issue_number: int | None) -> str:
28+
parts: list[str] = []
29+
if report:
30+
parts.append(report)
31+
if issue_number is not None:
32+
parts.append(f"\n\nFixes #{issue_number}")
33+
if worklog:
34+
parts.append(
35+
f"\n\n<details>\n<summary>Worklog</summary>\n\n{worklog}\n\n</details>"
36+
)
37+
return "\n".join(parts) if parts else "Automated fix from ptq."
38+
39+
40+
_HTTPS_TO_SSH = {
41+
"https://github.com/": "git@github.com:",
42+
}
43+
44+
45+
def _ensure_ssh_remote(
46+
backend: Backend, worktree: str, _log: Callable[[str], None]
47+
) -> None:
48+
result = backend.run(f"cd {worktree} && git remote get-url origin", check=False)
49+
url = result.stdout.strip()
50+
for https_prefix, ssh_prefix in _HTTPS_TO_SSH.items():
51+
if url.startswith(https_prefix):
52+
ssh_url = url.replace(https_prefix, ssh_prefix)
53+
if not ssh_url.endswith(".git"):
54+
ssh_url += ".git"
55+
_log(f"Switching origin to SSH: {ssh_url}")
56+
backend.run(f"cd {worktree} && git remote set-url origin '{ssh_url}'")
57+
return
58+
59+
60+
def create_pr(
61+
job_id: str,
62+
*,
63+
title: str | None = None,
64+
draft: bool = False,
65+
log: Callable[[str], None] | None = None,
66+
) -> PRResult:
67+
_log = log or (lambda _: None)
68+
job = get_job(job_id)
69+
backend = backend_for_job(job_id)
70+
ws = backend.workspace
71+
job_dir = f"{ws}/jobs/{job_id}"
72+
worktree = f"{job_dir}/pytorch"
73+
issue_number = job.get("issue")
74+
75+
branch = f"ptq/{issue_number}" if issue_number is not None else f"ptq/{job_id}"
76+
pr_title = title or (
77+
f"Fix #{issue_number}" if issue_number is not None else f"Fix from {job_id}"
78+
)
79+
80+
_log(f"Branch: {branch}")
81+
_log(f"Title: {pr_title}")
82+
83+
report = _read_file(backend, f"{job_dir}/report.md")
84+
worklog = _read_file(backend, f"{job_dir}/worklog.md")
85+
body = _build_pr_body(report, worklog, issue_number)
86+
_log(
87+
f"PR body: report.md {'found' if report else 'missing'}, worklog.md {'found' if worklog else 'missing'}"
88+
)
89+
90+
_SCRUB_PATHS = ".claude/ .cursorrules AGENTS.md"
91+
92+
_log("Staging changes...")
93+
backend.run(f"cd {worktree} && git add -A")
94+
95+
commit_msg = pr_title.replace("'", "'\\''")
96+
_log(f"Creating branch {branch} (single commit)...")
97+
base = backend.run(
98+
f"cd {worktree} && git merge-base HEAD origin/main", check=False
99+
).stdout.strip()
100+
backend.run(f"cd {worktree} && git checkout -B '{branch}'")
101+
if base:
102+
backend.run(f"cd {worktree} && git reset --soft {base}")
103+
backend.run(f"cd {worktree} && git reset HEAD -- {_SCRUB_PATHS}", check=False)
104+
backend.run(
105+
f"cd {worktree} && git commit -m '{commit_msg}' --allow-empty",
106+
check=False,
107+
)
108+
109+
_ensure_ssh_remote(backend, worktree, _log)
110+
_log("Pushing and creating PR...")
111+
body_escaped = body.replace("'", "'\\''")
112+
result = backend.run(
113+
f"cd {worktree} && git push -u origin '{branch}' --force && "
114+
f"gh pr create "
115+
f"--title '{commit_msg}' "
116+
f"--body '{body_escaped}' "
117+
f"--head '{branch}'"
118+
f"{' --draft' if draft else ''}",
119+
check=False,
120+
)
121+
122+
url = ""
123+
for line in result.stdout.strip().splitlines():
124+
if line.startswith("http"):
125+
url = line.strip()
126+
break
127+
128+
if not url and result.returncode != 0:
129+
stderr = result.stderr.strip() if result.stderr else ""
130+
if "already exists" in stderr:
131+
list_result = backend.run(
132+
f"cd {worktree} && gh pr list --head '{branch}' --json url --jq '.[0].url'",
133+
check=False,
134+
)
135+
url = list_result.stdout.strip()
136+
if url:
137+
return PRResult(url=url, branch=branch)
138+
raise SystemExit(f"gh pr create failed: {stderr or result.stdout}")
139+
140+
return PRResult(url=url, branch=branch)

0 commit comments

Comments
 (0)