From 9b9b7f863ccb57ad20b2e90e3e5efbb44b3daa13 Mon Sep 17 00:00:00 2001 From: asahoo Date: Mon, 29 Jun 2026 16:37:39 -0500 Subject: [PATCH 01/10] add initial code --- architecture/03-prediction-api.md | 1 + .../coglet/src/transport/http/playground.html | 1076 +++++++++++++++++ crates/coglet/src/transport/http/routes.rs | 34 +- examples/streaming-text/README.md | 4 +- examples/streaming-text/requirements.txt | 2 +- examples/streaming-text/run.py | 4 +- 6 files changed, 1115 insertions(+), 6 deletions(-) create mode 100644 crates/coglet/src/transport/http/playground.html diff --git a/architecture/03-prediction-api.md b/architecture/03-prediction-api.md index 057c83cee8..43c38a15d2 100644 --- a/architecture/03-prediction-api.md +++ b/architecture/03-prediction-api.md @@ -12,6 +12,7 @@ The Prediction API is the HTTP interface for running model inference. It uses a | `GET /health-check` | Health | Check server status | | `GET /` | Index | List available endpoints | | `GET /openapi.json` | Schema | OpenAPI specification | +| `GET /playground` | Playground | Browser UI for testing models | By default, `POST /predictions` blocks until completion. For long-running predictions, use async mode with `Prefer: respond-async` header -- the response returns immediately with status `processing`, and progress updates are delivered via webhook. diff --git a/crates/coglet/src/transport/http/playground.html b/crates/coglet/src/transport/http/playground.html new file mode 100644 index 0000000000..2ab2f9b7a1 --- /dev/null +++ b/crates/coglet/src/transport/http/playground.html @@ -0,0 +1,1076 @@ + + + + + +Cog Playground + + + +
+

Cog Playground

+ UNKNOWN + + Schema +
+ +
+ Setup Logs +

+
+ +
+
+

Inputs

+
+
+
+ + + + + +
+
+ +
+

Output

+
+
+
+
+ Logs +

+    
+
+
+
+ + + + diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 5904737391..266304b945 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -9,7 +9,7 @@ use axum::{ extract::{DefaultBodyLimit, Path, State}, http::{HeaderMap, StatusCode}, response::{ - IntoResponse, Json, Response, + Html, IntoResponse, Json, Response, sse::{Event, KeepAlive, Sse}, }, routing::{get, post, put}, @@ -121,6 +121,7 @@ async fn root(State(service): State>) -> Json>) -> impl I } } +const PLAYGROUND_HTML: &str = include_str!("playground.html"); + +async fn playground() -> impl IntoResponse { + Html(PLAYGROUND_HTML) +} + // Training routes — same dispatch as predictions but validated against // TrainingInput schema instead of Input. @@ -912,6 +919,7 @@ const MAX_HTTP_BODY_SIZE: usize = 100 * 1024 * 1024; pub fn routes(service: Arc) -> Router { Router::new() .route("/", get(root)) + .route("/playground", get(playground)) .route("/health-check", get(health_check)) .route("/openapi.json", get(openapi_schema)) .route("/shutdown", post(shutdown)) @@ -1807,6 +1815,7 @@ mod tests { // Without a python_sdk version set, falls back to coglet version assert_eq!(json["cog_version"], crate::version::COGLET_VERSION); assert_eq!(json["docs_url"], "/docs"); + assert_eq!(json["playground_url"], "/playground"); assert_eq!(json["openapi_url"], "/openapi.json"); assert_eq!(json["shutdown_url"], "/shutdown"); assert_eq!(json["healthcheck_url"], "/health-check"); @@ -1825,6 +1834,29 @@ mod tests { assert!(json.get("trainings_cancel_url").is_none()); } + #[tokio::test] + async fn playground_returns_html() { + let service = Arc::new(PredictionService::new_no_pool()); + let app = routes(service); + + let response = app + .oneshot(Request::get("/playground").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + 
response.headers().get("content-type").unwrap(), + "text/html; charset=utf-8" + ); + + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + let html = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(html.contains("")); + assert!(html.contains("Cog Playground")); + } + #[tokio::test] async fn root_includes_training_urls_when_schema_has_training() { let service = Arc::new(PredictionService::new_no_pool()); diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md index 77265e6f1e..9d7953c4b4 100644 --- a/examples/streaming-text/README.md +++ b/examples/streaming-text/README.md @@ -9,7 +9,7 @@ This example shows how a Cog runner can yield text chunks as a model generates t From this directory: ```sh -cog predict -i prompt="Write a short haiku about databases" +cog run -i prompt="Write a short haiku about databases" ``` This returns the final accumulated output after the prediction completes. @@ -46,6 +46,6 @@ data: {"id":"streaming-demo","status":"succeeded",...} ## How it works -`predict.py` defines `run() -> Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. +`run.py` defines `run() -> Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. The normal prediction response still contains the accumulated output for compatibility. Requesting `Accept: text/event-stream` is useful when clients want to display tokens as they arrive. 
diff --git a/examples/streaming-text/requirements.txt b/examples/streaming-text/requirements.txt index a3ba48567c..334d2c30b9 100644 --- a/examples/streaming-text/requirements.txt +++ b/examples/streaming-text/requirements.txt @@ -1,3 +1,3 @@ -torch==2.12.0 +torch==2.8.0 transformers==5.0.0rc3 accelerate==1.6.0 diff --git a/examples/streaming-text/run.py b/examples/streaming-text/run.py index b51bee3332..cebde0be00 100644 --- a/examples/streaming-text/run.py +++ b/examples/streaming-text/run.py @@ -12,12 +12,12 @@ class Runner(BaseRunner): def setup(self) -> None: self.device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = torch.float16 if self.device == "cuda" else torch.float32 + dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) self.model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, - torch_dtype=dtype, + dtype=dtype, ).to(self.device) self.model.eval() From 89292d87a77fbe181e6d1598ed8f538ccc1249b8 Mon Sep 17 00:00:00 2001 From: asahoo Date: Tue, 30 Jun 2026 10:59:51 -0500 Subject: [PATCH 02/10] Add working version of playground --- architecture/03-prediction-api.md | 1 - .../coglet/src/transport/http/playground.html | 1076 ----------------- crates/coglet/src/transport/http/routes.rs | 34 +- docs/cli.md | 43 + docs/llms.txt | 43 + pkg/cli/playground-ui/api.js | 191 +++ pkg/cli/playground-ui/app.js | 521 ++++++++ pkg/cli/playground-ui/dom.js | 56 + pkg/cli/playground-ui/form.js | 303 +++++ pkg/cli/playground-ui/index.html | 655 ++++++++++ pkg/cli/playground-ui/media.js | 24 + pkg/cli/playground-ui/output.js | 137 +++ pkg/cli/playground-ui/schema.js | 122 ++ pkg/cli/playground-ui/theme.js | 16 + pkg/cli/playground.go | 323 +++++ pkg/cli/playground_test.go | 217 ++++ pkg/cli/root.go | 1 + 17 files changed, 2653 insertions(+), 1110 deletions(-) delete mode 100644 crates/coglet/src/transport/http/playground.html create mode 100644 pkg/cli/playground-ui/api.js create mode 
100644 pkg/cli/playground-ui/app.js create mode 100644 pkg/cli/playground-ui/dom.js create mode 100644 pkg/cli/playground-ui/form.js create mode 100644 pkg/cli/playground-ui/index.html create mode 100644 pkg/cli/playground-ui/media.js create mode 100644 pkg/cli/playground-ui/output.js create mode 100644 pkg/cli/playground-ui/schema.js create mode 100644 pkg/cli/playground-ui/theme.js create mode 100644 pkg/cli/playground.go create mode 100644 pkg/cli/playground_test.go diff --git a/architecture/03-prediction-api.md b/architecture/03-prediction-api.md index 43c38a15d2..057c83cee8 100644 --- a/architecture/03-prediction-api.md +++ b/architecture/03-prediction-api.md @@ -12,7 +12,6 @@ The Prediction API is the HTTP interface for running model inference. It uses a | `GET /health-check` | Health | Check server status | | `GET /` | Index | List available endpoints | | `GET /openapi.json` | Schema | OpenAPI specification | -| `GET /playground` | Playground | Browser UI for testing models | By default, `POST /predictions` blocks until completion. For long-running predictions, use async mode with `Prefer: respond-async` header -- the response returns immediately with status `processing`, and progress updates are delivered via webhook. diff --git a/crates/coglet/src/transport/http/playground.html b/crates/coglet/src/transport/http/playground.html deleted file mode 100644 index 2ab2f9b7a1..0000000000 --- a/crates/coglet/src/transport/http/playground.html +++ /dev/null @@ -1,1076 +0,0 @@ - - - - - -Cog Playground - - - -
-

Cog Playground

- UNKNOWN - - Schema -
- -
- Setup Logs -

-
- -
-
-

Inputs

-
-
-
- - - - - -
-
- -
-

Output

-
-
-
-
- Logs -

-    
-
-
-
- - - - diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 266304b945..5904737391 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -9,7 +9,7 @@ use axum::{ extract::{DefaultBodyLimit, Path, State}, http::{HeaderMap, StatusCode}, response::{ - Html, IntoResponse, Json, Response, + IntoResponse, Json, Response, sse::{Event, KeepAlive, Sse}, }, routing::{get, post, put}, @@ -121,7 +121,6 @@ async fn root(State(service): State>) -> Json>) -> impl I } } -const PLAYGROUND_HTML: &str = include_str!("playground.html"); - -async fn playground() -> impl IntoResponse { - Html(PLAYGROUND_HTML) -} - // Training routes — same dispatch as predictions but validated against // TrainingInput schema instead of Input. @@ -919,7 +912,6 @@ const MAX_HTTP_BODY_SIZE: usize = 100 * 1024 * 1024; pub fn routes(service: Arc) -> Router { Router::new() .route("/", get(root)) - .route("/playground", get(playground)) .route("/health-check", get(health_check)) .route("/openapi.json", get(openapi_schema)) .route("/shutdown", post(shutdown)) @@ -1815,7 +1807,6 @@ mod tests { // Without a python_sdk version set, falls back to coglet version assert_eq!(json["cog_version"], crate::version::COGLET_VERSION); assert_eq!(json["docs_url"], "/docs"); - assert_eq!(json["playground_url"], "/playground"); assert_eq!(json["openapi_url"], "/openapi.json"); assert_eq!(json["shutdown_url"], "/shutdown"); assert_eq!(json["healthcheck_url"], "/health-check"); @@ -1834,29 +1825,6 @@ mod tests { assert!(json.get("trainings_cancel_url").is_none()); } - #[tokio::test] - async fn playground_returns_html() { - let service = Arc::new(PredictionService::new_no_pool()); - let app = routes(service); - - let response = app - .oneshot(Request::get("/playground").body(Body::empty()).unwrap()) - .await - .unwrap(); - - assert_eq!(response.status(), StatusCode::OK); - assert_eq!( - 
response.headers().get("content-type").unwrap(), - "text/html; charset=utf-8" - ); - - let body = response.into_body(); - let bytes = body.collect().await.unwrap().to_bytes(); - let html = String::from_utf8(bytes.to_vec()).unwrap(); - assert!(html.contains("")); - assert!(html.contains("Cog Playground")); - } - #[tokio::test] async fn root_includes_training_urls_when_schema_has_training() { let service = Arc::new(PredictionService::new_no_pool()); diff --git a/docs/cli.md b/docs/cli.md index 9f472fae32..cddcb4caf2 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -175,6 +175,49 @@ cog login [flags] --token-stdin Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token ``` +## `cog playground` + +Open a browser playground for talking to a running model. + +Starts a local web server that serves a schema-driven UI (a Postman-like tool +for Cog models). Point it at any running Cog HTTP API -- for example one started +with 'cog serve' -- and the playground reflects that model's inputs and outputs +from its OpenAPI schema in real time. + +Requests are reverse-proxied through this server, so the target API does not +need to set CORS headers. The server also hosts a webhook sink so async +predictions can be observed in the browser. + +Async/webhook testing against a containerized model requires the webhook URL to +be reachable from inside the container. On Docker Desktop the default +'host.docker.internal' works once the server listens on a reachable interface +(e.g. --host 0.0.0.0). 
+ +``` +cog playground [flags] +``` + +**Examples** + +``` + # Start a model API in one terminal + cog serve -p 8393 + + # Open the playground pointing at it + cog playground --target http://localhost:8393 +``` + +**Options** + +``` + -h, --help help for playground + --host string Address to bind (use 0.0.0.0 to receive webhooks from containers) (default "127.0.0.1") + --no-open Do not open the browser automatically + -p, --port int Port to listen on (0 picks a free port) + --target string Default target model API URL (default "http://localhost:8393") + --webhook-host string Hostname the model uses to reach this server for webhooks (default "host.docker.internal") +``` + ## `cog push` Build a Docker image from cog.yaml and push it to a container registry. diff --git a/docs/llms.txt b/docs/llms.txt index 11cb30b033..6825d53bd6 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -421,6 +421,49 @@ cog login [flags] --token-stdin Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token ``` +## `cog playground` + +Open a browser playground for talking to a running model. + +Starts a local web server that serves a schema-driven UI (a Postman-like tool +for Cog models). Point it at any running Cog HTTP API -- for example one started +with 'cog serve' -- and the playground reflects that model's inputs and outputs +from its OpenAPI schema in real time. + +Requests are reverse-proxied through this server, so the target API does not +need to set CORS headers. The server also hosts a webhook sink so async +predictions can be observed in the browser. + +Async/webhook testing against a containerized model requires the webhook URL to +be reachable from inside the container. On Docker Desktop the default +'host.docker.internal' works once the server listens on a reachable interface +(e.g. --host 0.0.0.0). 
+ +``` +cog playground [flags] +``` + +**Examples** + +``` + # Start a model API in one terminal + cog serve -p 8393 + + # Open the playground pointing at it + cog playground --target http://localhost:8393 +``` + +**Options** + +``` + -h, --help help for playground + --host string Address to bind (use 0.0.0.0 to receive webhooks from containers) (default "127.0.0.1") + --no-open Do not open the browser automatically + -p, --port int Port to listen on (0 picks a free port) + --target string Default target model API URL (default "http://localhost:8393") + --webhook-host string Hostname the model uses to reach this server for webhooks (default "host.docker.internal") +``` + ## `cog push` Build a Docker image from cog.yaml and push it to a container registry. diff --git a/pkg/cli/playground-ui/api.js b/pkg/cli/playground-ui/api.js new file mode 100644 index 0000000000..29df38a001 --- /dev/null +++ b/pkg/cli/playground-ui/api.js @@ -0,0 +1,191 @@ +// CogApi talks to the target model only through the playground's own reverse +// proxy (same origin). Every request carries the chosen target base URL in the +// X-Cog-Target header; the Go server forwards it. This avoids CORS (Cog sets +// none) and keeps SSE streaming working. + +const PROXY_PREFIX = "/proxy"; + +export class CogApi { + constructor() { + this.target = ""; + } + + setTarget(url) { + this.target = (url || "").trim().replace(/\/+$/, ""); + } + + _headers(extra) { + return Object.assign({ "X-Cog-Target": this.target }, extra || {}); + } + + _url(endpoint, id) { + return id + ? `${PROXY_PREFIX}${endpoint}/${encodeURIComponent(id)}` + : PROXY_PREFIX + endpoint; + } + + _body(input, webhook, webhookFilter) { + const body = { input }; + if (webhook) { + body.webhook = webhook; + body.webhook_events_filter = webhookFilter; + } + return body; + } + + // getConfig returns playground server config (e.g. the webhook base URL). 
+ async getConfig() { + try { + const r = await fetch("/config"); + if (r.ok) return r.json(); + } catch { + /* ignore */ + } + return {}; + } + + async health() { + const r = await fetch(PROXY_PREFIX + "/health-check", { + headers: this._headers(), + }); + if (!r.ok) throw new Error("HTTP " + r.status); + return r.json(); + } + + async schema() { + const r = await fetch(PROXY_PREFIX + "/openapi.json", { + headers: this._headers(), + }); + if (!r.ok) throw new Error("HTTP " + r.status); + return r.json(); + } + + // submit runs a prediction/training in blocking (sync) or async mode. A + // non-empty `id` makes the request idempotent (PUT). Returns the response + // envelope (the 202 acknowledgement in async mode). + async submit({ endpoint, id, input, asyncMode, webhook, webhookFilter, signal }) { + const headers = this._headers({ "Content-Type": "application/json" }); + if (asyncMode) headers["Prefer"] = "respond-async"; + const r = await fetch(this._url(endpoint, id), { + method: id ? "PUT" : "POST", + headers, + body: JSON.stringify(this._body(input, webhook, webhookFilter)), + signal, + }); + const body = await r.json().catch(() => ({})); + if (!r.ok) throw httpError(r.status, body); + return body; + } + + // stream runs a prediction in SSE mode, yielding parsed { type, data } events. + async *stream({ endpoint, id, input, webhook, webhookFilter, signal }) { + const resp = await fetch(this._url(endpoint, id), { + method: id ? 
"PUT" : "POST", + headers: this._headers({ + "Content-Type": "application/json", + Accept: "text/event-stream", + }), + body: JSON.stringify(this._body(input, webhook, webhookFilter)), + signal, + }); + if (!resp.ok) { + const text = await resp.text(); + let body = {}; + try { + body = JSON.parse(text); + } catch { + /* not JSON */ + } + throw httpError(resp.status, body, text); + } + + const reader = resp.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + for (;;) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + let sep; + while ((sep = buffer.indexOf("\n\n")) >= 0) { + const raw = buffer.slice(0, sep); + buffer = buffer.slice(sep + 2); + const event = parseSSEEvent(raw); + if (event) { + event.raw = raw; + yield event; + } + } + } + if (buffer.trim()) { + const event = parseSSEEvent(buffer); + if (event) { + event.raw = buffer; + yield event; + } + } + } + + // cancel requests cancellation of an in-flight prediction/training by id. + async cancel(endpoint, id) { + await fetch(`${PROXY_PREFIX}${endpoint}/${encodeURIComponent(id)}/cancel`, { + method: "POST", + headers: this._headers(), + }); + } +} + +// httpError builds an Error from a non-2xx response, attaching the structured +// `detail` array (422 validation errors) when present so callers can render +// field-level messages. +function httpError(status, body, fallbackText) { + const detail = body && Array.isArray(body.detail) ? body.detail : null; + const message = + (body && (body.error || (typeof body.detail === "string" ? body.detail : null))) || + fallbackText || + "HTTP " + status; + const err = new Error(message); + err.status = status; + if (detail) err.detail = detail; + return err; +} + +// parseSSEEvent parses one "event: ...\ndata: ..." block. The data payload is +// JSON-decoded when possible. 
+export function parseSSEEvent(block) { + let eventType = ""; + const dataLines = []; + for (const line of block.split("\n")) { + if (line.startsWith("event:")) { + eventType = line.slice(6).trim(); + } else if (line.startsWith("data:")) { + dataLines.push(line.slice(5).replace(/^ /, "")); + } + } + if (!eventType) return null; + const dataStr = dataLines.join("\n"); + let data = dataStr; + try { + data = JSON.parse(dataStr); + } catch { + /* keep raw string */ + } + return { type: eventType, data }; +} + +// fileToDataURI reads a File into a base64 data: URI suitable for a cog.Path +// input. +export function fileToDataURI(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result); + reader.onerror = reject; + reader.readAsDataURL(file); + }); +} + +export function formatBytes(bytes) { + if (bytes < 1024) return bytes + " B"; + if (bytes < 1048576) return (bytes / 1024).toFixed(1) + " KB"; + return (bytes / 1048576).toFixed(1) + " MB"; +} diff --git a/pkg/cli/playground-ui/app.js b/pkg/cli/playground-ui/app.js new file mode 100644 index 0000000000..00ffef9cc5 --- /dev/null +++ b/pkg/cli/playground-ui/app.js @@ -0,0 +1,521 @@ +import { CogApi } from "./api.js"; +import { buildForm } from "./form.js"; +import { resolveRef, defaultInput } from "./schema.js"; +import { toggleTheme, currentTheme } from "./theme.js"; +import { + setBadge, + showError, + clearError, + renderValidationErrors, + renderMetrics, + renderOutput, + renderText, + renderRaw, +} from "./output.js"; + +const DEFAULT_TARGET = "http://localhost:8393"; +const STORAGE_KEY = "cog-playground-target"; +const TERMINAL = ["succeeded", "failed", "canceled"]; + +const api = new CogApi(); + +const state = { + schema: null, + inputSchema: null, + outputSchema: null, + supportsStreaming: false, + supportsAsync: false, + form: null, + mode: "form", // "form" | "json" + runMode: "sync", // "sync" | "stream" | "async" + running: false, + 
abort: null, + eventSource: null, + lastId: null, + metrics: {}, + loadToken: 0, + healthTimer: null, + showLive: false, // a run is driving the toggled output view + outputValue: null, // current output (array of chunks, scalar, or object) + rawEvents: [], // raw frames/payloads exactly as received, for the Raw view + outputView: "text", // "text" | "raw" + webhookBase: "", +}; + +const dom = {}; +const DOM_IDS = [ + "health-badge", "version-info", "target-url", "connect-btn", "target-status", + "schema-link", "theme-toggle", "setup-panel", "setup-status", "setup-logs", + "schema-error", "form-container", "json-container", "json-editor", + "json-error", "json-format", "mode-form", "mode-json", "run-mode", + "run-mode-sync", "run-mode-stream", "run-mode-async", "prediction-id", + "webhook-options", "webhook-base-note", "stream-hint", "run-btn", "stop-btn", + "reset-btn", "result-status", "output-view", "output-view-text", + "output-view-raw", "metrics-container", "error-container", "output-container", +]; + +async function init() { + for (const id of DOM_IDS) dom[id] = document.getElementById(id); + + const params = new URLSearchParams(location.search); + dom["target-url"].value = + params.get("target") || localStorage.getItem(STORAGE_KEY) || DEFAULT_TARGET; + refreshThemeLabel(); + + const config = await api.getConfig(); + state.webhookBase = config.webhookBase || ""; + + dom["connect-btn"].addEventListener("click", connect); + dom["target-url"].addEventListener("keydown", (e) => { + if (e.key === "Enter") connect(); + }); + dom["theme-toggle"].addEventListener("click", () => { + toggleTheme(); + refreshThemeLabel(); + }); + dom["run-btn"].addEventListener("click", run); + dom["stop-btn"].addEventListener("click", stop); + dom["reset-btn"].addEventListener("click", reset); + dom["mode-form"].addEventListener("click", () => setMode("form")); + dom["mode-json"].addEventListener("click", () => setMode("json")); + dom["json-format"].addEventListener("click", 
formatJSON); + dom["run-mode-sync"].addEventListener("click", () => setRunMode("sync")); + dom["run-mode-stream"].addEventListener("click", () => setRunMode("stream")); + dom["run-mode-async"].addEventListener("click", () => setRunMode("async")); + dom["output-view-text"].addEventListener("click", () => setOutputView("text")); + dom["output-view-raw"].addEventListener("click", () => setOutputView("raw")); + + connect(); +} + +function refreshThemeLabel() { + dom["theme-toggle"].textContent = currentTheme() === "dark" ? "Light" : "Dark"; +} + +function connect() { + const url = dom["target-url"].value.trim(); + if (!url) return; + api.setTarget(url); + localStorage.setItem(STORAGE_KEY, url); + dom["schema-link"].href = + "/proxy/openapi.json?cog_target=" + encodeURIComponent(url); + history.replaceState(null, "", "?target=" + encodeURIComponent(url)); + + startHealthPolling(); + loadSchema(); +} + +function startHealthPolling() { + if (state.healthTimer) clearInterval(state.healthTimer); + pollHealth(); + state.healthTimer = setInterval(pollHealth, 5000); +} + +async function pollHealth() { + try { + const data = await api.health(); + setBadge(dom["health-badge"], data.status); + dom["target-status"].textContent = data.user_healthcheck_error || ""; + updateSetup(data.setup); + updateVersion(data.version); + } catch { + setBadge(dom["health-badge"], "unreachable"); + dom["target-status"].textContent = "target unreachable"; + } +} + +function updateSetup(setup) { + if (!setup) { + dom["setup-panel"].hidden = true; + return; + } + dom["setup-panel"].hidden = false; + setBadge(dom["setup-status"], setup.status); + dom["setup-logs"].textContent = setup.logs || ""; +} + +function updateVersion(version) { + if (!version) return; + const parts = []; + if (version.coglet) parts.push("coglet " + version.coglet); + if (version.cog) parts.push("cog " + version.cog); + if (version.python) parts.push("py " + version.python); + dom["version-info"].textContent = parts.join(" · "); 
+} + +async function loadSchema() { + const token = ++state.loadToken; + try { + const schema = await api.schema(); + if (token !== state.loadToken) return; // superseded by a newer connect + applySchema(schema); + dom["schema-error"].classList.remove("visible"); + } catch (err) { + if (token !== state.loadToken) return; + showError(dom["schema-error"], "Waiting for schema… (" + err.message + ")"); + setTimeout(() => { + if (token === state.loadToken) loadSchema(); + }, 3000); + } +} + +function applySchema(schema) { + state.schema = schema; + const schemas = (schema.components || {}).schemas || {}; + const paths = schema.paths || {}; + + state.inputSchema = resolveRef(schema, schemas.Input || {}); + state.outputSchema = resolveRef(schema, schemas.Output || {}); + state.supportsStreaming = + ((paths["/predictions"] || {}).post || {})["x-cog-streaming"] === true; + // Async predictions are observed via webhooks and cancelled via the cancel + // endpoint; treat the presence of that endpoint as the async-capable signal. + state.supportsAsync = !!paths["/predictions/{prediction_id}/cancel"]; + + state.runMode = state.supportsStreaming ? "stream" : "sync"; + configureRunModes(); + rebuildForm(defaultInput(schema, state.inputSchema)); +} + +// configureRunModes shows only the run modes the model advertises. +function configureRunModes() { + dom["run-mode-stream"].hidden = !state.supportsStreaming; + dom["run-mode-async"].hidden = !state.supportsAsync; + dom["run-mode"].hidden = !(state.supportsStreaming || state.supportsAsync); + + const available = { sync: true, stream: state.supportsStreaming, async: state.supportsAsync }; + if (!available[state.runMode]) state.runMode = "sync"; + + const out = state.outputSchema || {}; + const isIterator = + out["x-cog-array-type"] === "iterator" || out["x-cog-array-display"] === "concatenate"; + dom["stream-hint"].textContent = + !state.supportsStreaming && isIterator + ? 
"Add @cog.streaming to run() for real-time output" + : ""; + + updateRunModeButtons(); +} + +function updateRunModeButtons() { + for (const m of ["sync", "stream", "async"]) { + dom["run-mode-" + m].classList.toggle("active", state.runMode === m); + } + dom["webhook-options"].hidden = state.runMode !== "async"; + dom["webhook-base-note"].textContent = state.webhookBase + ? "Webhook: " + state.webhookBase + "/webhook/…" + : "No webhook host configured (set --webhook-host)."; +} + +function setRunMode(mode) { + if (dom["run-mode-" + mode].hidden) return; + state.runMode = mode; + updateRunModeButtons(); +} + +// --- input mode toggle (Form vs JSON) --- +function setMode(mode) { + if (mode === state.mode) return; + if (mode === "json") { + syncFormToJSON(); + } else { + const parsed = parseEditor(); + if (parsed === undefined) return; // invalid JSON: stay in JSON mode + rebuildForm(parsed); + } + state.mode = mode; + dom["mode-form"].classList.toggle("active", mode === "form"); + dom["mode-json"].classList.toggle("active", mode === "json"); + dom["form-container"].hidden = mode !== "form"; + dom["json-container"].hidden = mode !== "json"; +} + +function rebuildForm(value) { + state.form = buildForm(dom["form-container"], state.schema, state.inputSchema, value); + if (state.mode === "json") syncFormToJSON(); +} + +function syncFormToJSON() { + const value = state.form ? 
state.form.collect() : {}; + dom["json-editor"].value = JSON.stringify(value, null, 2); + dom["json-error"].textContent = ""; +} + +function parseEditor() { + const raw = dom["json-editor"].value.trim(); + if (raw === "") { + dom["json-error"].textContent = ""; + return {}; + } + try { + const parsed = JSON.parse(raw); + dom["json-error"].textContent = ""; + return parsed; + } catch (err) { + dom["json-error"].textContent = "Invalid JSON: " + err.message; + return undefined; + } +} + +function formatJSON() { + const parsed = parseEditor(); + if (parsed === undefined) return; + dom["json-editor"].value = JSON.stringify(parsed, null, 2); +} + +// --- output view toggle (Text vs Raw) --- +function setOutputView(view) { + state.outputView = view; + dom["output-view-text"].classList.toggle("active", view === "text"); + dom["output-view-raw"].classList.toggle("active", view === "raw"); + if (state.showLive) renderLive(); +} + +function showOutputView(visible) { + dom["output-view"].hidden = !visible; +} + +// renderLive renders the current output in the selected view. Raw shows the +// exact frames/payloads; Text concatenates plain-string output (the streaming +// "adds up" view) and falls back to the rich renderer for media/structured +// output so non-text results still display. +function renderLive() { + const container = dom["output-container"]; + if (state.outputView === "raw") { + renderRaw(container, state.rawEvents, state.running); + return; + } + const value = state.outputValue; + if (value == null) { + renderText(container, "", state.running); + } else if (isPlainText(value)) { + renderText(container, value, state.running); + } else if (Array.isArray(value) && value.length > 0 && value.every(isPlainText)) { + renderText(container, value.join(""), state.running); + } else { + renderOutput(container, value, state.outputSchema); + } +} + +// isPlainText is true for a string that isn't a media/URL reference, i.e. 
// one that should be concatenated as text rather than rendered as media.
function isPlainText(x) {
  return typeof x === "string" && !x.startsWith("data:") && !/^https?:\/\//i.test(x);
}

// --- running ---

// activeInput returns the input for the current editing mode: the parsed JSON
// editor contents in "json" mode, otherwise whatever the form handle collects.
function activeInput() {
  return state.mode === "json" ? parseEditor() : state.form.collect();
}

// currentId returns the trimmed prediction-id field, or undefined when blank.
function currentId() {
  const id = dom["prediction-id"].value.trim();
  return id || undefined;
}

// run starts a prediction in the selected run mode ("async", "stream", or the
// synchronous default). It is a no-op while a run is in flight or when the
// active input is undefined. Before dispatching it clears the previous result
// and per-run state, and installs a fresh AbortController for stop().
function run() {
  if (state.running) return;
  const input = activeInput();
  if (input === undefined) return; // invalid JSON

  clearError(dom["error-container"]);
  renderOutput(dom["output-container"], null);
  renderMetrics(dom["metrics-container"], {});
  dom["result-status"].textContent = "";
  state.metrics = {};
  state.outputValue = null;
  state.rawEvents = [];
  state.lastId = currentId() || null;

  setRunning(true);
  state.abort = new AbortController();

  let start = runSync;
  if (state.runMode === "async") start = runAsync;
  else if (state.runMode === "stream") start = runStream;

  // The runners are async and not awaited. Anything that escapes their own
  // try/catch must still release the UI, otherwise Run stays disabled forever.
  start(input).catch((err) => {
    reportRunError(err);
    finishAsync();
  });
}

// runSync submits the prediction and waits for the complete response, then
// applies its envelope and stores the output and the raw JSON payload.
async function runSync(input) {
  state.showLive = true;
  showOutputView(true);
  setBadge(dom["result-status"], "processing");
  try {
    const response = await api.submit({
      endpoint: "/predictions",
      id: currentId(),
      input,
      signal: state.abort.signal,
    });
    state.lastId = response.id || state.lastId;
    applyEnvelope(response);
    // `??` binds tighter than `?:`; the parentheses only make that explicit.
    state.outputValue = response.error ? null : (response.output ?? null);
    state.rawEvents = [JSON.stringify(response, null, 2)];
  } catch (err) {
    reportRunError(err);
  } finally {
    setRunning(false);
    renderLive();
  }
}

// runStream consumes the events yielded by api.stream, recording each raw
// frame and re-rendering after every event.
async function runStream(input) {
  state.showLive = true;
  state.outputValue = [];
  showOutputView(true);
  setBadge(dom["result-status"], "processing");
  try {
    for await (const event of api.stream({
      endpoint: "/predictions",
      id: currentId(),
      input,
      signal: state.abort.signal,
    })) {
      if (event.raw != null) state.rawEvents.push(event.raw);
      handleStreamEvent(event);
      renderLive();
    }
  } catch (err) {
    reportRunError(err);
  } finally {
    setRunning(false);
    renderLive(); // final render without the streaming cursor
  }
}

// newWebhookToken returns a random UUID-shaped token for the webhook sink.
// crypto.randomUUID only exists in secure contexts (HTTPS or localhost), so a
// playground opened over plain HTTP falls back to crypto.getRandomValues and
// formats the 16 random bytes as a version-4 UUID.
function newWebhookToken() {
  if (typeof crypto.randomUUID === "function") return crypto.randomUUID();
  const bytes = crypto.getRandomValues(new Uint8Array(16));
  bytes[6] = (bytes[6] & 0x0f) | 0x40; // version 4
  bytes[8] = (bytes[8] & 0x3f) | 0x80; // RFC 4122 variant
  const hex = Array.from(bytes, (b) => b.toString(16).padStart(2, "0")).join("");
  return [
    hex.slice(0, 8),
    hex.slice(8, 12),
    hex.slice(12, 16),
    hex.slice(16, 20),
    hex.slice(20),
  ].join("-");
}

// runAsync submits with Prefer: respond-async and a webhook pointing at the
// playground server's sink, then observes delivered events over /events (SSE).
//
// All setup happens inside the try block so that a failure while creating the
// token or the EventSource is reported and the run is finished, instead of
// leaving the UI stuck in the running state.
async function runAsync(input) {
  state.showLive = true;
  showOutputView(true);
  setBadge(dom["result-status"], "starting");

  try {
    let webhook = null;
    if (state.webhookBase) {
      const token = newWebhookToken();
      webhook = `${state.webhookBase}/webhook/${token}`;

      const es = new EventSource("/events?token=" + encodeURIComponent(token));
      state.eventSource = es;
      es.onmessage = (e) => {
        state.rawEvents.push(e.data);
        let data;
        try {
          data = JSON.parse(e.data);
        } catch {
          // Not JSON: it is still shown in the Raw view, nothing else to apply.
          renderLive();
          return;
        }
        applyEnvelope(data);
        if (!data.error && data.output != null) {
          state.outputValue = data.output;
        }
        renderLive();
        if (TERMINAL.includes(data.status)) finishAsync();
      };
      es.onerror = () => {
        // The browser reconnects by itself after transient failures; only a
        // stream that has given up (CLOSED) means no further events can arrive.
        if (state.eventSource !== es || es.readyState !== EventSource.CLOSED) return;
        showError(dom["error-container"], "event stream closed before the prediction finished");
        finishAsync();
      };
    }

    const response = await api.submit({
      endpoint: "/predictions",
      id: currentId(),
      input,
      asyncMode: true,
      webhook,
      webhookFilter: collectWebhookFilter(),
      signal: state.abort.signal,
    });
    state.lastId = response.id || state.lastId;
    // A terminal webhook may already have finished the run; do not overwrite
    // its status badge with the (older) submit response.
    if (state.running) setBadge(dom["result-status"], response.status || "starting");
    if (!webhook) finishAsync(); // nothing to observe; stop the spinner
  } catch (err) {
    reportRunError(err);
    finishAsync();
  }
}

// finishAsync closes the event stream (if one is open), leaves the running
// state, and renders the output one last time.
function finishAsync() {
  if (state.eventSource) {
    state.eventSource.close();
    state.eventSource = null;
  }
  setRunning(false);
  renderLive();
}

// collectWebhookFilter returns the values of the ticked `.wh-filter` checkboxes.
function collectWebhookFilter() {
  return Array.from(document.querySelectorAll(".wh-filter:checked")).map((c) => c.value);
}

// handleStreamEvent applies one streamed event to the UI state. Events that
// arrive without a data payload are treated as carrying an empty object, so a
// bare event cannot abort the whole stream with a TypeError.
function handleStreamEvent(event) {
  const data = event.data ?? {};
  switch (event.type) {
    case "start":
      setBadge(dom["result-status"], data.status || "starting");
      break;
    case "output":
      if (!Array.isArray(state.outputValue)) state.outputValue = [];
      state.outputValue.push(data.chunk);
      break;
    case "metric":
      state.metrics[data.name] = data.value;
      renderMetrics(dom["metrics-container"], state.metrics);
      break;
    case "error":
      // Transport-level SSE error (e.g. replay truncated, broadcast lagged).
      showError(dom["error-container"], data.error || "stream error");
      break;
    case "completed":
      applyEnvelope(data);
      break;
  }
}

// applyEnvelope updates status/metrics/error from a prediction envelope (shared
// by sync responses, the streamed "completed" event, and webhooks).
function applyEnvelope(data) {
  if (!data) return;
  setBadge(dom["result-status"], data.status || "unknown");
  if (data.metrics) renderMetrics(dom["metrics-container"], data.metrics);
  if (data.error) showError(dom["error-container"], data.error);
}

// reportRunError shows a failed run. An AbortError only sets the "canceled"
// badge. An error carrying `detail` goes to the validation renderer; anything
// else shows its message. Optional chaining keeps a thrown non-Error value
// (undefined, a string, ...) from crashing the error path itself.
function reportRunError(err) {
  if (err?.name === "AbortError") {
    setBadge(dom["result-status"], "canceled");
    return;
  }
  if (err?.detail) {
    renderValidationErrors(dom["error-container"], err.detail);
  } else {
    showError(dom["error-container"], err?.message ?? String(err));
  }
  setBadge(dom["result-status"], "failed");
}

// setRunning records the running flag and enables/disables the Run, Stop and
// Reset buttons to match.
function setRunning(running) {
  state.running = running;
  dom["run-btn"].disabled = running;
  dom["stop-btn"].disabled = !running;
  dom["reset-btn"].disabled = running;
}

// stop aborts the local request/stream and, if we know the prediction id, asks
// the model to cancel it (the only way to stop an async prediction).
function stop() {
  const controller = state.abort;
  if (controller) controller.abort();

  const id = state.lastId;
  if (id) {
    // Best-effort: a failed cancel request is deliberately ignored.
    api.cancel("/predictions", id).catch(() => {});
  }
  finishAsync();
}

// reset returns the playground to its initial look: errors, output, metrics
// and the status text are cleared, the live output view is hidden, and the
// form is rebuilt from the schema defaults. It does nothing while a run is in
// flight or before a schema has been loaded.
function reset() {
  const busy = state.running;
  if (busy || !state.schema) return;

  clearError(dom["error-container"]);
  renderOutput(dom["output-container"], null);
  renderMetrics(dom["metrics-container"], {});
  dom["result-status"].textContent = "";

  state.showLive = false;
  state.outputValue = null;
  state.rawEvents = [];
  showOutputView(false);

  const defaults = defaultInput(state.schema, state.inputSchema);
  rebuildForm(defaults);
}

document.addEventListener("DOMContentLoaded", init);

// ---- patch metadata, kept verbatim: next file in the diff ----
// diff --git a/pkg/cli/playground-ui/dom.js b/pkg/cli/playground-ui/dom.js
// new file mode 100644
// index 0000000000..ea75d81620
// --- /dev/null
// +++ b/pkg/cli/playground-ui/dom.js
// @@ -0,0 +1,56 @@

// Tiny DOM-building helpers: a thin layer over document.createElement that
// lets node trees be written as nested calls, with no framework or build step.
//
//   el("div", { class: "field" }, el("label", { text: "name" }), input)
//
// Prop handling: `class` maps to className and `text` to textContent; there is
// deliberately no `html` prop (untrusted HTML is never injected). Function
// props named `on*` (`onclick`, `oninput`, ...) are attached as listeners. A
// value of `true` produces a bare attribute; null/false/undefined are skipped.
export function el(tag, props = {}, ...children) {
  const element = document.createElement(tag);
  Object.entries(props).forEach(([name, setting]) => {
    setProp(element, name, setting);
  });
  append(element, children);
  return element;
}

// setProp applies a single prop. Known DOM properties are set directly; an
// `on*` function becomes an event listener; everything else is an attribute
// (a `true` value renders as a bare boolean attribute).
+function setProp(node, key, value) { + if (value == null || value === false) return; + switch (key) { + case "class": + node.className = value; + break; + case "text": + node.textContent = value; + break; + case "value": + node.value = value; + break; + case "checked": + node.checked = Boolean(value); + break; + default: + if (key.startsWith("on") && typeof value === "function") { + node.addEventListener(key.slice(2).toLowerCase(), value); + } else { + node.setAttribute(key, value === true ? "" : value); + } + } +} + +// append flattens arrays and turns primitives into text nodes. +export function append(node, children) { + for (const child of children.flat()) { + if (child == null || child === false) continue; + node.append(child.nodeType ? child : document.createTextNode(String(child))); + } +} + +// clear removes all children of a node. +export function clear(node) { + while (node.firstChild) node.removeChild(node.firstChild); +} diff --git a/pkg/cli/playground-ui/form.js b/pkg/cli/playground-ui/form.js new file mode 100644 index 0000000000..fad12a7221 --- /dev/null +++ b/pkg/cli/playground-ui/form.js @@ -0,0 +1,303 @@ +import { el, clear } from "./dom.js"; +import { fieldKind, orderedInputs, coerceEnum } from "./schema.js"; +import { fileToDataURI, formatBytes } from "./api.js"; +import { mediaNode } from "./media.js"; + +// buildForm renders the Input fields into `container` and returns a handle with +// collect(), which reads the current values on demand. There is no reactive +// state: inputs are built once and queried when the user runs a prediction. +export function buildForm(container, root, inputSchema, value = {}) { + clear(container); + const inputs = orderedInputs(inputSchema); + if (inputs.length === 0) { + container.append(el("p", { class: "muted", text: "This model takes no inputs." 
})); + return { collect: () => ({}) }; + } + + const fields = []; + for (const { name, prop, required } of inputs) { + const field = buildField(root, name, prop, required, value[name]); + container.append(field.element); + fields.push({ name, read: field.read, included: field.included }); + } + + return { + // collect includes required fields always and optional fields only when + // their include checkbox is ticked (ticking happens automatically when the + // field is edited). + collect() { + const out = {}; + for (const { name, included, read } of fields) { + if (included()) out[name] = read(); + } + return out; + }, + }; +} + +// buildField renders one labelled field and returns its value reader plus an +// `included` predicate. Optional fields get an include checkbox so they can be +// omitted from the request; it auto-ticks when the field is edited. +function buildField(root, name, prop, required, initial) { + const kind = fieldKind(root, prop); + const widget = buildWidget(root, kind, initial); + + const label = el("label"); + let includeBox = null; + if (!required) { + includeBox = el("input", { + type: "checkbox", + class: "include-box", + checked: initial !== undefined, + }); + label.append(includeBox); + const touch = () => { + includeBox.checked = true; + }; + widget.element.addEventListener("input", touch); + widget.element.addEventListener("change", touch); + } + label.append(name); + if (required) label.append(el("span", { class: "req", text: " *" })); + if (kind.prop.deprecated) { + label.append(el("span", { class: "deprecated-tag", text: " (deprecated)" })); + } + + const field = el("div", { class: "field" }, label); + if (kind.prop.description) { + field.append(el("small", { class: "desc", text: kind.prop.description })); + } + const hint = constraintText(kind.prop); + if (hint) field.append(el("small", { class: "constraint", text: hint })); + field.append(widget.element); + + return { + element: field, + read: widget.read, + included: () => 
required || includeBox.checked, + }; +} + +// constraintText summarizes the numeric/string constraints emitted in the +// schema (minimum/maximum, minLength/maxLength, pattern) for display. +function constraintText(prop) { + const parts = []; + if (prop.minimum !== undefined && prop.maximum !== undefined) { + parts.push(`${prop.minimum}–${prop.maximum}`); + } else if (prop.minimum !== undefined) { + parts.push(`min ${prop.minimum}`); + } else if (prop.maximum !== undefined) { + parts.push(`max ${prop.maximum}`); + } + if (prop.minLength !== undefined && prop.maxLength !== undefined) { + parts.push(`${prop.minLength}–${prop.maxLength} chars`); + } else if (prop.minLength !== undefined) { + parts.push(`min ${prop.minLength} chars`); + } else if (prop.maxLength !== undefined) { + parts.push(`max ${prop.maxLength} chars`); + } + if (prop.pattern) parts.push(`pattern: ${prop.pattern}`); + return parts.join(" · "); +} + +// buildWidget maps a field kind to a DOM widget + value reader. Reused for both +// top-level fields and array items. +function buildWidget(root, kind, initial) { + switch (kind.kind) { + case "enum": + return enumWidget(kind.choices, kind.prop, initial); + case "file": + return fileWidget(initial); + case "secret": + return textWidget("password", initial ?? kind.prop.default); + case "string": + return textareaWidget(initial ?? kind.prop.default); + case "integer": + return numberWidget(kind.prop, true, initial); + case "number": + return numberWidget(kind.prop, false, initial); + case "boolean": + return booleanWidget(initial ?? kind.prop.default); + case "array": + return arrayWidget(root, kind.items, initial); + default: + return objectWidget(kind.prop, initial); + } +} + +function textWidget(type, initial) { + const input = el("input", { type, value: initial ?? "" }); + return { element: input, read: () => input.value }; +} + +function textareaWidget(initial) { + const input = el("textarea", { rows: "2", value: initial ?? 
"" }); + return { element: input, read: () => input.value }; +} + +function numberWidget(prop, isInt, initial) { + const input = el("input", { + type: "number", + value: initial ?? prop.default ?? "", + min: prop.minimum, + max: prop.maximum, + step: isInt ? "1" : "any", + }); + return { + element: input, + read: () => { + if (input.value === "") return ""; + return isInt ? parseInt(input.value, 10) : parseFloat(input.value); + }, + }; +} + +function booleanWidget(initial) { + const input = el("input", { type: "checkbox", checked: initial === true }); + return { element: input, read: () => input.checked }; +} + +function enumWidget(choices, prop, initial) { + const current = initial ?? prop.default; + const select = el("select"); + if (current === undefined || current === null) { + select.append(el("option", { value: "", text: "— select —" })); + } + for (const choice of choices) { + const option = el("option", { value: String(choice), text: String(choice) }); + if (choice === current) option.selected = true; + select.append(option); + } + return { element: select, read: () => coerceEnum(choices, select.value) }; +} + +// fileWidget: upload a file (-> data: URI) OR paste a URL. Mutually exclusive; +// reads as a single string value that round-trips into the JSON editor. Shows +// an inline preview for image/audio/video so you can confirm the input. +function fileWidget(initial) { + let currentValue = typeof initial === "string" ? 
initial : ""; + + const fileInput = el("input", { type: "file" }); + const fileName = el("span", { class: "file-name" }); + const urlInput = el("input", { + type: "text", + class: "url-input", + placeholder: "https://...", + value: currentValue, + }); + const preview = el("div", { class: "input-preview" }); + + function updatePreview() { + clear(preview); + const node = mediaNode(currentValue); + if (node) preview.append(node); + } + + fileInput.addEventListener("change", async () => { + const file = fileInput.files[0]; + if (!file) return; + currentValue = await fileToDataURI(file); + urlInput.value = ""; + fileName.textContent = `${file.name} (${formatBytes(file.size)})`; + updatePreview(); + }); + + urlInput.addEventListener("input", () => { + currentValue = urlInput.value; + fileInput.value = ""; + fileName.textContent = ""; + updatePreview(); + }); + + const controls = el( + "div", + { class: "file-widget" }, + fileInput, + fileName, + el("span", { class: "muted", text: "or URL" }), + urlInput, + ); + const element = el("div", {}, controls, preview); + updatePreview(); + return { element, read: () => currentValue }; +} + +function objectWidget(prop, initial) { + const text = + initial === undefined + ? prop.default !== undefined + ? JSON.stringify(prop.default, null, 2) + : "" + : typeof initial === "string" + ? initial + : JSON.stringify(initial, null, 2); + + const textarea = el("textarea", { rows: "3", class: "mono", value: text }); + const error = el("small", { class: "field-error" }); + const element = el("div", {}, textarea, error); + + return { + element, + read: () => { + const raw = textarea.value.trim(); + if (raw === "") { + error.textContent = ""; + return ""; + } + try { + const parsed = JSON.parse(raw); + error.textContent = ""; + return parsed; + } catch (err) { + error.textContent = "Invalid JSON: " + err.message; + return ""; + } + }, + }; +} + +// arrayWidget renders a growable list of item widgets. 
function arrayWidget(root, items, initial) {
  const list = el("div");
  const itemKind = fieldKind(root, items);
  // One reader per row currently on screen, in row order.
  const liveReaders = [];

  const appendRow = (seed) => {
    const itemWidget = buildWidget(root, itemKind, seed);
    const readItem = itemWidget.read;
    liveReaders.push(readItem);

    const dropRow = () => {
      rowNode.remove();
      // Forget the reader together with its row so read() no longer sees it.
      const at = liveReaders.indexOf(readItem);
      if (at >= 0) liveReaders.splice(at, 1);
    };
    const removeBtn = el("button", {
      type: "button",
      class: "ghost-btn danger",
      text: "−",
      onclick: dropRow,
    });
    const rowNode = el("div", { class: "array-row" }, itemWidget.element, removeBtn);
    list.append(rowNode);
  };

  const addBtn = el("button", {
    type: "button",
    class: "ghost-btn",
    text: "+ Add",
    onclick: () => appendRow(undefined),
  });

  // Start from the given items; always show at least one (empty) row.
  if (Array.isArray(initial)) {
    initial.forEach((item) => appendRow(item));
  }
  if (liveReaders.length === 0) appendRow(undefined);

  const element = el("div", { class: "array-input" }, list, addBtn);

  // Rows that read as "", null or undefined are left out of the result.
  const read = () => {
    const values = [];
    for (const readItem of liveReaders) {
      const v = readItem();
      if (v !== "" && v !== null && v !== undefined) values.push(v);
    }
    return values;
  };

  return { element, read };
}

// ---- patch metadata, kept verbatim: next file in the diff ----
// diff --git a/pkg/cli/playground-ui/index.html b/pkg/cli/playground-ui/index.html
// new file mode 100644
// index 0000000000..5f9bcec627
// --- /dev/null
// +++ b/pkg/cli/playground-ui/index.html
// @@ -0,0 +1,655 @@
// + + + + + + Cog Playground + + + +
+

Cog Playground

+ unknown + + + + Schema +
+ +
+ + + + +
+ + + +
+
+
+

Input

+
+ + +
+
+ +
+ +
+ + + +
+
+ + + +
+ +
+ + + +
+ + + + +
+
+ +
+
+

Output

+ +
+ +
+
+
+
+
+
+ + + + diff --git a/pkg/cli/playground-ui/media.js b/pkg/cli/playground-ui/media.js new file mode 100644 index 0000000000..5ecba5ec03 --- /dev/null +++ b/pkg/cli/playground-ui/media.js @@ -0,0 +1,24 @@ +import { el } from "./dom.js"; + +// Recognized media extensions for plain URLs (data: URIs carry their own MIME). +const IMAGE_EXT = /\.(?:png|jpe?g|gif|webp|avif|bmp|svg)(?:[?#]|$)/i; +const AUDIO_EXT = /\.(?:mp3|wav|ogg|oga|flac|m4a|aac|opus)(?:[?#]|$)/i; +const VIDEO_EXT = /\.(?:mp4|webm|ogv|mov|m4v)(?:[?#]|$)/i; + +// mediaNode returns an /