Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions backend/llm_ops/collection_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
price_role_for_source,
update_aggregated_model_identity,
)
from .source_collectors import registered_official_provider_codes
from .skill_runner import (
run_vendor_pricing_skill,
standard_catalog_run_metadata,
Expand Down Expand Up @@ -628,7 +629,7 @@ def sync_configured_official_model_prices(

def supported_official_provider_options() -> list[dict]:
"""Return official provider source presets available to operators."""
provider_codes = list(SUPPORTED_OFFICIAL_PRICE_SYNC_PROVIDER_CODES)
provider_codes = list(registered_official_provider_codes())
providers = {
provider.code: provider
for provider in LLMProvider.objects.filter(code__in=provider_codes)
Expand Down Expand Up @@ -679,7 +680,7 @@ def ensure_supported_official_provider_source(
) -> tuple[LLMProvider, PriceCollectionSource, bool, bool]:
"""Ensure an operator-selected official provider source exists."""
provider_code = str(provider_code or "").strip()
if provider_code not in SUPPORTED_OFFICIAL_PRICE_SYNC_PROVIDER_CODES:
if provider_code not in registered_official_provider_codes():
raise ValueError("Unsupported official provider source.")

defaults = official_provider_defaults(provider_code)
Expand Down
12 changes: 8 additions & 4 deletions backend/llm_ops/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional

from .models import LLMOpsGlobalConfig, PriceCollectionSource
from .collectors.official import OFFICIAL_PROVIDER_CONFIGS
from .source_collectors import registered_official_provider_codes

logger = logging.getLogger(__name__)

Expand All @@ -19,7 +19,11 @@
MODEL_PRICE_SYNC_AGENT_TASK = "llm_ops.tasks.run_model_price_sync_agent"
PRICE_SOURCE_TASK_PREFIX = "llm_ops_price_source_collect_"
PRICE_SOURCE_TASK = "llm_ops.tasks.collect_price_source_prices"
SUPPORTED_PRICE_SYNC_PROVIDER_CODES = tuple(sorted(OFFICIAL_PROVIDER_CONFIGS))


def supported_price_sync_provider_codes() -> tuple[str, ...]:
"""Return provider codes currently supported by runtime sync."""
return registered_official_provider_codes()


def price_source_task_name(source_id: int) -> str:
Expand Down Expand Up @@ -172,7 +176,7 @@ def price_sync_task_source_ids(config: LLMOpsGlobalConfig) -> list[int] | None:
def price_sync_source_queryset():
"""Return sources currently supported by runtime price sync."""
return PriceCollectionSource.objects.filter(
provider__code__in=SUPPORTED_PRICE_SYNC_PROVIDER_CODES,
provider__code__in=supported_price_sync_provider_codes(),
slug__in=official_provider_source_slugs(),
source_category=(
PriceCollectionSource.SOURCE_CATEGORY_OFFICIAL_PROVIDER
Expand All @@ -185,7 +189,7 @@ def official_provider_source_slugs() -> tuple[str, ...]:
"""Return supported provider-level official source slugs."""
return tuple(
f"{provider_code}-official"
for provider_code in SUPPORTED_PRICE_SYNC_PROVIDER_CODES
for provider_code in supported_price_sync_provider_codes()
)


Expand Down
2 changes: 2 additions & 0 deletions backend/llm_ops/source_collectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from .registry import (
collect_price_source,
get_price_source_collector,
registered_official_provider_codes,
source_supports_code_collection,
)

__all__ = [
"collect_price_source",
"get_price_source_collector",
"registered_official_provider_codes",
"source_supports_code_collection",
]
73 changes: 27 additions & 46 deletions backend/llm_ops/source_collectors/official.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
from __future__ import annotations

from llm_ops.collection_services import sync_official_provider_model_prices
from llm_ops.collectors.official import OFFICIAL_PROVIDER_CONFIGS
from llm_ops.models import PriceCollectionSource

from .base import CollectorResult

SUPPORTED_OFFICIAL_PROVIDER_CODES = tuple(
sorted(OFFICIAL_PROVIDER_CONFIGS.keys())
)
OFFICIAL_COLLECTOR_REGISTRY = {}


def register_official_provider_collector(collector_class):
"""Register a backend-implemented official price collector."""
provider_code = getattr(collector_class, "provider_code", "")
if not provider_code:
raise ValueError("Official collector must define provider_code.")
if provider_code not in OFFICIAL_PROVIDER_CONFIGS:
raise ValueError(f"Unknown official provider: {provider_code}")
OFFICIAL_COLLECTOR_REGISTRY[provider_code] = collector_class
return collector_class


def registered_official_provider_codes() -> tuple[str, ...]:
"""Return provider codes with implemented official collectors."""
return tuple(sorted(OFFICIAL_COLLECTOR_REGISTRY))


class OfficialProviderPriceSourceCollector:
Expand Down Expand Up @@ -46,13 +59,18 @@ def collect(
verify_source: bool = True,
) -> CollectorResult:
"""Collect prices using the provider-specific official sync."""
from llm_ops.collection_services import (
sync_official_provider_model_prices,
)

return sync_official_provider_model_prices(
provider=source.provider,
source=source,
verify_source=verify_source,
)


@register_official_provider_collector
class AliyunOfficialPriceSourceCollector(
OfficialProviderPriceSourceCollector,
):
Expand All @@ -61,51 +79,14 @@ class AliyunOfficialPriceSourceCollector(
provider_code = "aliyun"


class AliyunWanxOfficialPriceSourceCollector(
OfficialProviderPriceSourceCollector,
):
"""Collect prices from Aliyun Wanxiang's official pricing source."""

provider_code = "aliyun-wanx"


class BaiduOfficialPriceSourceCollector(
OfficialProviderPriceSourceCollector,
):
"""Collect prices from Baidu Qianfan's official pricing source."""

provider_code = "baidu"


class VolcengineOfficialPriceSourceCollector(
OfficialProviderPriceSourceCollector,
):
"""Collect prices from VolcEngine Ark's official pricing source."""

provider_code = "volcengine"


OFFICIAL_COLLECTOR_CLASSES = {
"aliyun": AliyunOfficialPriceSourceCollector,
"aliyun-wanx": AliyunWanxOfficialPriceSourceCollector,
"baidu": BaiduOfficialPriceSourceCollector,
"volcengine": VolcengineOfficialPriceSourceCollector,
}


def build_official_provider_collectors() -> tuple[
OfficialProviderPriceSourceCollector,
...,
]:
"""Build one collector per supported official provider."""
collectors = []
for provider_code in SUPPORTED_OFFICIAL_PROVIDER_CODES:
collector_class = OFFICIAL_COLLECTOR_CLASSES.get(
provider_code,
OfficialProviderPriceSourceCollector,
return tuple(
collector_class()
for _provider_code, collector_class in sorted(
OFFICIAL_COLLECTOR_REGISTRY.items()
)
if collector_class is OfficialProviderPriceSourceCollector:
collectors.append(collector_class(provider_code=provider_code))
else:
collectors.append(collector_class())
return tuple(collectors)
)
5 changes: 4 additions & 1 deletion backend/llm_ops/source_collectors/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from llm_ops.models import PriceCollectionSource

from .base import CollectorResult, PriceSourceCollector
from .official import build_official_provider_collectors
from .official import (
build_official_provider_collectors,
registered_official_provider_codes,
)


PRICE_SOURCE_COLLECTORS: tuple[PriceSourceCollector, ...] = (
Expand Down
3 changes: 1 addition & 2 deletions backend/llm_ops/tests/test_ops_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,7 @@ def test_supported_official_provider_options_do_not_create_rows(self):
options = supported_official_provider_options()

provider_codes = {option["provider_code"] for option in options}
self.assertIn("aliyun", provider_codes)
self.assertIn("baidu", provider_codes)
self.assertEqual(provider_codes, {"aliyun"})
self.assertFalse(LLMProvider.objects.exists())
self.assertFalse(PriceCollectionSource.objects.exists())

Expand Down
31 changes: 13 additions & 18 deletions backend/llm_ops/tests/test_source_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
from llm_ops.source_collectors import (
collect_price_source,
get_price_source_collector,
registered_official_provider_codes,
source_supports_code_collection,
)
from llm_ops.source_collectors.official import (
SUPPORTED_OFFICIAL_PROVIDER_CODES,
)


class PriceSourceCollectorRegistryTests(TestCase):
def test_all_official_configs_have_registered_collectors(self):
self.assertEqual(
set(OFFICIAL_PROVIDER_CONFIGS),
set(SUPPORTED_OFFICIAL_PROVIDER_CODES),
)
def test_registered_official_collectors_have_provider_configs(self):
provider_codes = set(registered_official_provider_codes())

self.assertEqual(provider_codes, {"aliyun"})
self.assertLessEqual(provider_codes, set(OFFICIAL_PROVIDER_CONFIGS))

def test_official_provider_source_dispatches_to_provider_collector(self):
provider = LLMProvider.objects.create(name="阿里云", code="aliyun")
Expand Down Expand Up @@ -48,8 +46,7 @@ def test_official_provider_source_dispatches_to_provider_collector(self):
self.assertTrue(source_supports_code_collection(source))

with mock.patch(
"llm_ops.source_collectors.official."
"sync_official_provider_model_prices"
"llm_ops.collection_services.sync_official_provider_model_prices"
) as mock_sync:
mock_sync.return_value = {"models": 1}
result = collect_price_source(
Expand All @@ -64,7 +61,7 @@ def test_official_provider_source_dispatches_to_provider_collector(self):
)
self.assertEqual(result, {"models": 1})

def test_generic_official_provider_source_dispatches_to_collector(self):
def test_unregistered_official_provider_source_is_not_supported(self):
provider = LLMProvider.objects.create(name="DeepSeek", code="deepseek")
source = PriceCollectionSource.objects.create(
name="DeepSeek Official",
Expand All @@ -83,11 +80,10 @@ def test_generic_official_provider_source_dispatches_to_collector(self):

collector = get_price_source_collector(source)

self.assertIsNotNone(collector)
self.assertEqual(collector.collector_id, "official_provider:deepseek")
self.assertTrue(source_supports_code_collection(source))
self.assertIsNone(collector)
self.assertFalse(source_supports_code_collection(source))

def test_baidu_official_provider_source_dispatches_to_collector(self):
def test_baidu_official_provider_source_is_not_supported(self):
provider = LLMProvider.objects.create(name="百度", code="baidu")
source = PriceCollectionSource.objects.create(
name="Baidu Official",
Expand All @@ -104,9 +100,8 @@ def test_baidu_official_provider_source_dispatches_to_collector(self):

collector = get_price_source_collector(source)

self.assertIsNotNone(collector)
self.assertEqual(collector.collector_id, "official_provider:baidu")
self.assertTrue(source_supports_code_collection(source))
self.assertIsNone(collector)
self.assertFalse(source_supports_code_collection(source))

def test_model_level_official_source_is_not_supported(self):
provider = LLMProvider.objects.create(name="阿里云", code="aliyun")
Expand Down
36 changes: 26 additions & 10 deletions backend/llm_ops/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def test_collection_sources_sync_all_submits_supported_sources(
updates_model_prices=True,
)
openai = LLMProvider.objects.create(name="OpenAI", code="openai")
openai_source = PriceCollectionSource.objects.create(
PriceCollectionSource.objects.create(
name="OpenAI Official",
slug="openai-official",
provider=openai,
Expand All @@ -967,13 +967,13 @@ def test_collection_sources_sync_all_submits_supported_sources(

self.assertEqual(response.status_code, 202)
self.assertEqual(response.data["task_id"], "task-all")
self.assertEqual(response.data["source_count"], 2)
self.assertEqual(response.data["source_count"], 1)
self.assertEqual(
response.data["source_ids"],
[aliyun_source.id, openai_source.id],
[aliyun_source.id],
)
mock_delay.assert_called_once_with(
source_ids=[aliyun_source.id, openai_source.id],
source_ids=[aliyun_source.id],
verify_source=True,
)
audit = AuditLog.objects.get(
Expand All @@ -983,7 +983,7 @@ def test_collection_sources_sync_all_submits_supported_sources(
self.assertEqual(audit.metadata["task_id"], "task-all")
self.assertEqual(
audit.metadata["source_ids"],
[aliyun_source.id, openai_source.id],
[aliyun_source.id],
)

@patch("llm_ops.views.run_model_price_sync_agent.delay")
Expand All @@ -996,7 +996,7 @@ def test_collection_sources_sync_all_rejects_without_sources(
self.assertEqual(response.status_code, 400)
mock_delay.assert_not_called()

def test_official_provider_options_include_aliyun_presets(self):
def test_official_provider_options_only_include_aliyun_preset(self):
response = self.client.get(
reverse("collection-source-official-provider-options")
)
Expand All @@ -1005,10 +1005,7 @@ def test_official_provider_options_include_aliyun_presets(self):
provider_codes = {
item["provider_code"] for item in response.data["results"]
}
self.assertIn("aliyun", provider_codes)
self.assertIn("aliyun-wanx", provider_codes)
self.assertIn("baidu", provider_codes)
self.assertIn("volcengine", provider_codes)
self.assertEqual(provider_codes, {"aliyun"})

def test_ensure_official_provider_source_creates_aliyun_source(self):
response = self.client.post(
Expand Down Expand Up @@ -1075,6 +1072,25 @@ def test_ensure_official_provider_source_is_idempotent(self):
1,
)

def test_ensure_official_provider_source_rejects_non_aliyun(self):
response = self.client.post(
reverse("collection-source-ensure-official-provider"),
{"provider_code": "baidu"},
format="json",
)

self.assertEqual(response.status_code, 400)
self.assertEqual(
response.data["detail"],
"Unsupported official provider source.",
)
self.assertFalse(LLMProvider.objects.filter(code="baidu").exists())
self.assertFalse(
PriceCollectionSource.objects.filter(
slug="baidu-official",
).exists()
)

@patch("llm_ops.views.run_model_price_sync_agent.delay")
def test_collection_source_collect_rejects_disabled_source(
self,
Expand Down
Loading