Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
Support for OpenAI Realtime API, LLM, TTS, and STT APIs.

Also includes support for a large number of OpenAI-compatible APIs including Azure OpenAI, Cerebras,
Fireworks, Perplexity, Telnyx, xAI, Ollama, DeepSeek, OpenRouter, and OVHcloud AI Endpoints.
Fireworks, Perplexity, Telnyx, xAI, Ollama, DeepSeek, OpenRouter, Cloudflare AI Gateway, and
OVHcloud AI Endpoints.

See https://docs.livekit.io/agents/integrations/openai/ and
https://docs.livekit.io/agents/integrations/llm/ for more information.
Expand All @@ -27,6 +28,8 @@
from .embeddings import EmbeddingData, create_embeddings
from .llm import LLM, LLMStream
from .models import (
CloudflareCustomCost,
CloudflareGatewayOptions,
OpenRouterProviderPreferences,
OpenRouterWebPlugin,
STTModels,
Expand All @@ -42,6 +45,8 @@
"TTS",
"LLM",
"LLMStream",
"CloudflareCustomCost",
"CloudflareGatewayOptions",
"OpenRouterProviderPreferences",
"OpenRouterWebPlugin",
"STTModels",
Expand Down
121 changes: 121 additions & 0 deletions livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import json
import os
from dataclasses import asdict, dataclass
from typing import Any, Literal
Expand Down Expand Up @@ -42,6 +43,7 @@
from .models import (
CerebrasChatModels,
ChatModels,
CloudflareGatewayOptions,
CometAPIChatModels,
DeepSeekChatModels,
NebiusChatModels,
Expand Down Expand Up @@ -938,6 +940,125 @@ def with_letta(
tool_choice=NOT_GIVEN,
)

@staticmethod
def with_cloudflare(
*,
model: str,
account_id: str | None = None,
gateway_id: str = "default",
base_url: str | None = None,
api_key: str | None = None,
cf_aig_token: str | None = None,
gateway_options: CloudflareGatewayOptions | None = None,
client: openai.AsyncClient | None = None,
user: NotGivenOr[str] = NOT_GIVEN,
temperature: NotGivenOr[float] = NOT_GIVEN,
parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
tool_choice: ToolChoice = "auto",
reasoning_effort: NotGivenOr[ReasoningEffort] = NOT_GIVEN,
safety_identifier: NotGivenOr[str] = NOT_GIVEN,
prompt_cache_key: NotGivenOr[str] = NOT_GIVEN,
top_p: NotGivenOr[float] = NOT_GIVEN,
timeout: httpx.Timeout | None = None,
) -> LLM:
"""
Create a new instance of an LLM backed by the Cloudflare AI Gateway.

The gateway exposes a unified OpenAI-compatible endpoint. The endpoint URL is built
from ``account_id`` and ``gateway_id`` unless an explicit ``base_url`` is given, and the
model is a ``provider/model`` string.

Args:
model (str): Model in ``provider/model`` form, e.g.
``"workers-ai/@cf/meta/llama-3.3-70b-instruct-fp8-fast"`` or ``"openai/gpt-4o"``.
account_id (str | None, optional): Cloudflare account ID used to build the gateway
URL. Falls back to ``CLOUDFLARE_ACCOUNT_ID``. Required unless ``base_url`` is set.
gateway_id (str): Gateway name used to build the URL. Defaults to ``"default"``, which
Cloudflare creates automatically on first request.
base_url (str | None, optional): Full gateway endpoint, e.g.
``"https://gateway.ai.cloudflare.com/v1/<account_id>/<gateway_id>/compat"``.
Overrides ``account_id`` / ``gateway_id`` when provided.
api_key (str | None, optional): Downstream provider key for "bring your own key"
mode, sent as the ``Authorization`` header. Falls back to ``CLOUDFLARE_API_KEY``.
cf_aig_token (str | None, optional): Gateway token sent as the
``cf-aig-authorization`` header, required for authenticated gateways. Falls
back to ``CLOUDFLARE_AI_GATEWAY_TOKEN``.
gateway_options (CloudflareGatewayOptions | None, optional): Per-request gateway
options (caching, retries, timeout, metadata, custom cost), translated into
``cf-aig-*`` request headers.

Returns:
LLM: A configured LLM instance routed through the Cloudflare AI Gateway.
"""

if base_url is None:
account_id = account_id or os.environ.get("CLOUDFLARE_ACCOUNT_ID")
if account_id is None:
raise ValueError(
"Cloudflare account_id is required, either as argument or set "
"CLOUDFLARE_ACCOUNT_ID environment variable (or pass base_url directly)"
)
base_url = f"https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway_id}/compat"

parsed = urlparse(base_url)
if parsed.scheme not in {"http", "https"}:
raise ValueError(f"Invalid URL scheme: '{parsed.scheme}'. Must be 'http' or 'https'.")
if not parsed.netloc:
raise ValueError(f"URL '{base_url}' is missing a network location (e.g., domain name).")

cf_aig_token = cf_aig_token or os.environ.get("CLOUDFLARE_AI_GATEWAY_TOKEN")
api_key = api_key or os.environ.get("CLOUDFLARE_API_KEY")
if cf_aig_token is None and api_key is None:
raise ValueError(
"Cloudflare authentication is required: set api_key/CLOUDFLARE_API_KEY "
"(bring your own provider key) and/or cf_aig_token/CLOUDFLARE_AI_GATEWAY_TOKEN "
"(gateway token)"
)

default_headers: dict[str, str] = {}
if cf_aig_token:
default_headers["cf-aig-authorization"] = f"Bearer {cf_aig_token}"

if gateway_options:
if "cache_ttl" in gateway_options:
default_headers["cf-aig-cache-ttl"] = str(gateway_options["cache_ttl"])
if gateway_options.get("skip_cache"):
default_headers["cf-aig-skip-cache"] = "true"
if "cache_key" in gateway_options:
default_headers["cf-aig-cache-key"] = gateway_options["cache_key"]
if "request_timeout" in gateway_options:
default_headers["cf-aig-request-timeout"] = str(gateway_options["request_timeout"])
if "max_attempts" in gateway_options:
default_headers["cf-aig-max-attempts"] = str(gateway_options["max_attempts"])
if "retry_delay" in gateway_options:
default_headers["cf-aig-retry-delay"] = str(gateway_options["retry_delay"])
if "backoff" in gateway_options:
default_headers["cf-aig-backoff"] = gateway_options["backoff"]
if "metadata" in gateway_options:
default_headers["cf-aig-metadata"] = json.dumps(gateway_options["metadata"])
if "custom_cost" in gateway_options:
default_headers["cf-aig-custom-cost"] = json.dumps(gateway_options["custom_cost"])

return LLM(
model=model,
# The OpenAI SDK requires a non-empty api_key for the Authorization header; in
# gateway-stored-keys mode the real auth rides on cf-aig-authorization, so a
# placeholder is used (matches the with_ollama precedent).
api_key=api_key or "cloudflare",
base_url=base_url,
client=client,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
reasoning_effort=reasoning_effort,
safety_identifier=safety_identifier,
prompt_cache_key=prompt_cache_key,
top_p=top_p,
extra_headers=default_headers,
timeout=timeout,
)

def chat(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,36 @@ class OpenRouterProviderPreferences(TypedDict, total=False):
quantizations: list[str]
sort: Literal["price", "throughput", "latency"]
max_price: dict[str, float]


class CloudflareCustomCost(TypedDict):
"""Custom per-token cost reported to the Cloudflare AI Gateway dashboard."""

per_token_in: float
per_token_out: float


class CloudflareGatewayOptions(TypedDict, total=False):
"""Per-request Cloudflare AI Gateway options, mapped to ``cf-aig-*`` request headers.

See https://developers.cloudflare.com/ai-gateway/configuration/ for details.
"""

cache_ttl: int
"""Cache duration in seconds (``cf-aig-cache-ttl``)."""
skip_cache: bool
"""Bypass the cache for this request (``cf-aig-skip-cache``)."""
cache_key: str
"""Override the default cache key (``cf-aig-cache-key``)."""
request_timeout: int
"""Per-request timeout in milliseconds (``cf-aig-request-timeout``)."""
max_attempts: int
"""Maximum number of request attempts (``cf-aig-max-attempts``)."""
retry_delay: int
"""Delay between retries in milliseconds (``cf-aig-retry-delay``)."""
backoff: Literal["constant", "linear", "exponential"]
"""Retry backoff strategy (``cf-aig-backoff``)."""
metadata: dict[str, str | int | bool]
"""Custom metadata attached to the request (``cf-aig-metadata``)."""
custom_cost: CloudflareCustomCost
"""Custom per-token cost for this request (``cf-aig-custom-cost``)."""
130 changes: 130 additions & 0 deletions tests/test_openai_with_cloudflare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

import json

import pytest

from livekit.plugins import openai

pytestmark = pytest.mark.unit

_BASE_URL = "https://gateway.ai.cloudflare.com/v1/acct/default/compat"


@pytest.fixture(autouse=True)
def _clear_cloudflare_env(monkeypatch: pytest.MonkeyPatch) -> None:
# keep construction deterministic regardless of the host environment
monkeypatch.delenv("CLOUDFLARE_API_KEY", raising=False)
monkeypatch.delenv("CLOUDFLARE_AI_GATEWAY_TOKEN", raising=False)
monkeypatch.delenv("CLOUDFLARE_ACCOUNT_ID", raising=False)


def test_builds_url_from_account_and_default_gateway() -> None:
llm = openai.LLM.with_cloudflare(model="openai/gpt-4o", account_id="acct", cf_aig_token="t")
assert (
str(llm._client.base_url).rstrip("/")
== "https://gateway.ai.cloudflare.com/v1/acct/default/compat"
)


def test_builds_url_with_custom_gateway_id() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o", account_id="acct", gateway_id="prod", cf_aig_token="t"
)
assert "/v1/acct/prod/compat" in str(llm._client.base_url)


def test_account_id_falls_back_to_env(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CLOUDFLARE_ACCOUNT_ID", "env-acct")
llm = openai.LLM.with_cloudflare(model="openai/gpt-4o", cf_aig_token="t")
assert "/v1/env-acct/default/compat" in str(llm._client.base_url)


def test_base_url_overrides_account_id() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o", account_id="ignored", base_url=_BASE_URL, cf_aig_token="t"
)
assert str(llm._client.base_url).rstrip("/") == _BASE_URL


def test_missing_account_id_raises() -> None:
with pytest.raises(ValueError):
openai.LLM.with_cloudflare(model="openai/gpt-4o", cf_aig_token="t")


def test_byok_forwards_provider_key_and_base_url() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o", base_url=_BASE_URL, api_key="sk-provider"
)
assert str(llm._client.base_url).rstrip("/") == _BASE_URL
assert llm._client.api_key == "sk-provider"
# no gateway token -> no cf-aig-authorization header
assert "cf-aig-authorization" not in (llm._opts.extra_headers or {})


def test_gateway_token_sets_header_and_placeholder_key() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o", base_url=_BASE_URL, cf_aig_token="cf-tok"
)
# no provider key -> SDK still needs a non-empty Authorization, so a placeholder is used
assert llm._client.api_key == "cloudflare"
assert llm._opts.extra_headers["cf-aig-authorization"] == "Bearer cf-tok"


def test_gateway_options_map_to_headers() -> None:
llm = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
base_url=_BASE_URL,
cf_aig_token="cf-tok",
gateway_options={
"cache_ttl": 3600,
"cache_key": "k1",
"request_timeout": 2000,
"max_attempts": 3,
"retry_delay": 500,
"backoff": "exponential",
"metadata": {"room": "r1", "turn": 4, "live": True},
"custom_cost": {"per_token_in": 0.000001, "per_token_out": 0.000002},
},
)
headers = llm._opts.extra_headers
assert headers["cf-aig-cache-ttl"] == "3600"
assert headers["cf-aig-cache-key"] == "k1"
assert headers["cf-aig-request-timeout"] == "2000"
assert headers["cf-aig-max-attempts"] == "3"
assert headers["cf-aig-retry-delay"] == "500"
assert headers["cf-aig-backoff"] == "exponential"
# metadata and custom_cost are JSON-encoded
assert json.loads(headers["cf-aig-metadata"]) == {"room": "r1", "turn": 4, "live": True}
assert json.loads(headers["cf-aig-custom-cost"]) == {
"per_token_in": 0.000001,
"per_token_out": 0.000002,
}


def test_skip_cache_header_only_emitted_when_true() -> None:
enabled = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
base_url=_BASE_URL,
cf_aig_token="t",
gateway_options={"skip_cache": True},
)
assert enabled._opts.extra_headers["cf-aig-skip-cache"] == "true"

disabled = openai.LLM.with_cloudflare(
model="openai/gpt-4o",
base_url=_BASE_URL,
cf_aig_token="t",
gateway_options={"skip_cache": False},
)
assert "cf-aig-skip-cache" not in disabled._opts.extra_headers


def test_invalid_base_url_raises() -> None:
with pytest.raises(ValueError):
openai.LLM.with_cloudflare(model="openai/gpt-4o", base_url="not-a-url", cf_aig_token="t")


def test_missing_auth_raises() -> None:
with pytest.raises(ValueError):
openai.LLM.with_cloudflare(model="openai/gpt-4o", base_url=_BASE_URL)