diff --git a/backend/llm_ops/collection_services.py b/backend/llm_ops/collection_services.py index 62396fb..359e25d 100644 --- a/backend/llm_ops/collection_services.py +++ b/backend/llm_ops/collection_services.py @@ -37,6 +37,7 @@ price_role_for_source, update_aggregated_model_identity, ) +from .source_collectors import registered_official_provider_codes from .skill_runner import ( run_vendor_pricing_skill, standard_catalog_run_metadata, @@ -628,7 +629,7 @@ def sync_configured_official_model_prices( def supported_official_provider_options() -> list[dict]: """Return official provider source presets available to operators.""" - provider_codes = list(SUPPORTED_OFFICIAL_PRICE_SYNC_PROVIDER_CODES) + provider_codes = list(registered_official_provider_codes()) providers = { provider.code: provider for provider in LLMProvider.objects.filter(code__in=provider_codes) @@ -679,7 +680,7 @@ def ensure_supported_official_provider_source( ) -> tuple[LLMProvider, PriceCollectionSource, bool, bool]: """Ensure an operator-selected official provider source exists.""" provider_code = str(provider_code or "").strip() - if provider_code not in SUPPORTED_OFFICIAL_PRICE_SYNC_PROVIDER_CODES: + if provider_code not in registered_official_provider_codes(): raise ValueError("Unsupported official provider source.") defaults = official_provider_defaults(provider_code) diff --git a/backend/llm_ops/global_config.py b/backend/llm_ops/global_config.py index 893d556..94fa39a 100644 --- a/backend/llm_ops/global_config.py +++ b/backend/llm_ops/global_config.py @@ -7,7 +7,7 @@ from typing import Optional from .models import LLMOpsGlobalConfig, PriceCollectionSource -from .collectors.official import OFFICIAL_PROVIDER_CONFIGS +from .source_collectors import registered_official_provider_codes logger = logging.getLogger(__name__) @@ -19,7 +19,11 @@ MODEL_PRICE_SYNC_AGENT_TASK = "llm_ops.tasks.run_model_price_sync_agent" PRICE_SOURCE_TASK_PREFIX = "llm_ops_price_source_collect_" PRICE_SOURCE_TASK = "llm_ops.tasks.collect_price_source_prices" -SUPPORTED_PRICE_SYNC_PROVIDER_CODES = tuple(sorted(OFFICIAL_PROVIDER_CONFIGS)) + + +def supported_price_sync_provider_codes() -> tuple[str, ...]: + """Return provider codes currently supported by runtime sync.""" + return registered_official_provider_codes() def price_source_task_name(source_id: int) -> str: @@ -172,7 +176,7 @@ def price_sync_task_source_ids(config: LLMOpsGlobalConfig) -> list[int] | None: def price_sync_source_queryset(): """Return sources currently supported by runtime price sync.""" return PriceCollectionSource.objects.filter( - provider__code__in=SUPPORTED_PRICE_SYNC_PROVIDER_CODES, + provider__code__in=supported_price_sync_provider_codes(), slug__in=official_provider_source_slugs(), source_category=( PriceCollectionSource.SOURCE_CATEGORY_OFFICIAL_PROVIDER @@ -185,7 +189,7 @@ def official_provider_source_slugs() -> tuple[str, ...]: """Return supported provider-level official source slugs.""" return tuple( f"{provider_code}-official" - for provider_code in SUPPORTED_PRICE_SYNC_PROVIDER_CODES + for provider_code in supported_price_sync_provider_codes() ) diff --git a/backend/llm_ops/source_collectors/__init__.py b/backend/llm_ops/source_collectors/__init__.py index c934d3c..cd64af1 100644 --- a/backend/llm_ops/source_collectors/__init__.py +++ b/backend/llm_ops/source_collectors/__init__.py @@ -3,11 +3,13 @@ from .registry import ( collect_price_source, get_price_source_collector, + registered_official_provider_codes, source_supports_code_collection, ) __all__ = [ "collect_price_source", "get_price_source_collector", + "registered_official_provider_codes", "source_supports_code_collection", ] diff --git a/backend/llm_ops/source_collectors/official.py b/backend/llm_ops/source_collectors/official.py index 1f16deb..a88b5ad 100644 --- a/backend/llm_ops/source_collectors/official.py +++ b/backend/llm_ops/source_collectors/official.py @@ -1,14 +1,27 @@ from __future__ import annotations -from llm_ops.collection_services import sync_official_provider_model_prices from llm_ops.collectors.official import OFFICIAL_PROVIDER_CONFIGS from llm_ops.models import PriceCollectionSource from .base import CollectorResult -SUPPORTED_OFFICIAL_PROVIDER_CODES = tuple( - sorted(OFFICIAL_PROVIDER_CONFIGS.keys()) -) +OFFICIAL_COLLECTOR_REGISTRY = {} + + +def register_official_provider_collector(collector_class): + """Register a backend-implemented official price collector.""" + provider_code = getattr(collector_class, "provider_code", "") + if not provider_code: + raise ValueError("Official collector must define provider_code.") + if provider_code not in OFFICIAL_PROVIDER_CONFIGS: + raise ValueError(f"Unknown official provider: {provider_code}") + OFFICIAL_COLLECTOR_REGISTRY[provider_code] = collector_class + return collector_class + + +def registered_official_provider_codes() -> tuple[str, ...]: + """Return provider codes with implemented official collectors.""" + return tuple(sorted(OFFICIAL_COLLECTOR_REGISTRY)) class OfficialProviderPriceSourceCollector: @@ -46,6 +59,10 @@ def collect( verify_source: bool = True, ) -> CollectorResult: """Collect prices using the provider-specific official sync.""" + from llm_ops.collection_services import ( + sync_official_provider_model_prices, + ) + return sync_official_provider_model_prices( provider=source.provider, source=source, @@ -53,6 +70,7 @@ def collect( ) +@register_official_provider_collector class AliyunOfficialPriceSourceCollector( OfficialProviderPriceSourceCollector, ): @@ -61,51 +79,14 @@ class AliyunOfficialPriceSourceCollector( provider_code = "aliyun" -class AliyunWanxOfficialPriceSourceCollector( - OfficialProviderPriceSourceCollector, -): - """Collect prices from Aliyun Wanxiang's official pricing source.""" - - provider_code = "aliyun-wanx" - - -class BaiduOfficialPriceSourceCollector( - OfficialProviderPriceSourceCollector, -): - """Collect prices from Baidu Qianfan's official pricing source.""" - - provider_code = "baidu" - - -class VolcengineOfficialPriceSourceCollector( - OfficialProviderPriceSourceCollector, -): - """Collect prices from VolcEngine Ark's official pricing source.""" - - provider_code = "volcengine" - - -OFFICIAL_COLLECTOR_CLASSES = { - "aliyun": AliyunOfficialPriceSourceCollector, - "aliyun-wanx": AliyunWanxOfficialPriceSourceCollector, - "baidu": BaiduOfficialPriceSourceCollector, - "volcengine": VolcengineOfficialPriceSourceCollector, -} - - def build_official_provider_collectors() -> tuple[ OfficialProviderPriceSourceCollector, ..., ]: """Build one collector per supported official provider.""" - collectors = [] - for provider_code in SUPPORTED_OFFICIAL_PROVIDER_CODES: - collector_class = OFFICIAL_COLLECTOR_CLASSES.get( - provider_code, - OfficialProviderPriceSourceCollector, + return tuple( + collector_class() + for _provider_code, collector_class in sorted( + OFFICIAL_COLLECTOR_REGISTRY.items() ) - if collector_class is OfficialProviderPriceSourceCollector: - collectors.append(collector_class(provider_code=provider_code)) - else: - collectors.append(collector_class()) - return tuple(collectors) + ) diff --git a/backend/llm_ops/source_collectors/registry.py b/backend/llm_ops/source_collectors/registry.py index c3951de..aa5e5cb 100644 --- a/backend/llm_ops/source_collectors/registry.py +++ b/backend/llm_ops/source_collectors/registry.py @@ -3,7 +3,10 @@ from llm_ops.models import PriceCollectionSource from .base import CollectorResult, PriceSourceCollector -from .official import build_official_provider_collectors +from .official import ( + build_official_provider_collectors, + registered_official_provider_codes, +) PRICE_SOURCE_COLLECTORS: tuple[PriceSourceCollector, ...] = ( diff --git a/backend/llm_ops/tests/test_ops_bootstrap.py b/backend/llm_ops/tests/test_ops_bootstrap.py index c2cb802..f1d8e64 100644 --- a/backend/llm_ops/tests/test_ops_bootstrap.py +++ b/backend/llm_ops/tests/test_ops_bootstrap.py @@ -267,8 +267,7 @@ def test_supported_official_provider_options_do_not_create_rows(self): options = supported_official_provider_options() provider_codes = {option["provider_code"] for option in options} - self.assertIn("aliyun", provider_codes) - self.assertIn("baidu", provider_codes) + self.assertEqual(provider_codes, {"aliyun"}) self.assertFalse(LLMProvider.objects.exists()) self.assertFalse(PriceCollectionSource.objects.exists()) diff --git a/backend/llm_ops/tests/test_source_collectors.py b/backend/llm_ops/tests/test_source_collectors.py index 9d86d32..d4f1b05 100644 --- a/backend/llm_ops/tests/test_source_collectors.py +++ b/backend/llm_ops/tests/test_source_collectors.py @@ -7,19 +7,17 @@ from llm_ops.source_collectors import ( collect_price_source, get_price_source_collector, + registered_official_provider_codes, source_supports_code_collection, ) -from llm_ops.source_collectors.official import ( - SUPPORTED_OFFICIAL_PROVIDER_CODES, -) class PriceSourceCollectorRegistryTests(TestCase): - def test_all_official_configs_have_registered_collectors(self): - self.assertEqual( - set(OFFICIAL_PROVIDER_CONFIGS), - set(SUPPORTED_OFFICIAL_PROVIDER_CODES), - ) + def test_registered_official_collectors_have_provider_configs(self): + provider_codes = set(registered_official_provider_codes()) + + self.assertEqual(provider_codes, {"aliyun"}) + self.assertLessEqual(provider_codes, set(OFFICIAL_PROVIDER_CONFIGS)) def test_official_provider_source_dispatches_to_provider_collector(self): provider = LLMProvider.objects.create(name="阿里云", code="aliyun") @@ -48,8 +46,7 @@ def test_official_provider_source_dispatches_to_provider_collector(self): self.assertTrue(source_supports_code_collection(source)) with mock.patch( - "llm_ops.source_collectors.official." - "sync_official_provider_model_prices" + "llm_ops.collection_services.sync_official_provider_model_prices" ) as mock_sync: mock_sync.return_value = {"models": 1} result = collect_price_source( @@ -64,7 +61,7 @@ def test_official_provider_source_dispatches_to_provider_collector(self): ) self.assertEqual(result, {"models": 1}) - def test_generic_official_provider_source_dispatches_to_collector(self): + def test_unregistered_official_provider_source_is_not_supported(self): provider = LLMProvider.objects.create(name="DeepSeek", code="deepseek") source = PriceCollectionSource.objects.create( name="DeepSeek Official", @@ -83,11 +80,10 @@ def test_generic_official_provider_source_dispatches_to_collector(self): collector = get_price_source_collector(source) - self.assertIsNotNone(collector) - self.assertEqual(collector.collector_id, "official_provider:deepseek") - self.assertTrue(source_supports_code_collection(source)) + self.assertIsNone(collector) + self.assertFalse(source_supports_code_collection(source)) - def test_baidu_official_provider_source_dispatches_to_collector(self): + def test_baidu_official_provider_source_is_not_supported(self): provider = LLMProvider.objects.create(name="百度", code="baidu") source = PriceCollectionSource.objects.create( name="Baidu Official", @@ -104,9 +100,8 @@ def test_baidu_official_provider_source_dispatches_to_collector(self): collector = get_price_source_collector(source) - self.assertIsNotNone(collector) - self.assertEqual(collector.collector_id, "official_provider:baidu") - self.assertTrue(source_supports_code_collection(source)) + self.assertIsNone(collector) + self.assertFalse(source_supports_code_collection(source)) def test_model_level_official_source_is_not_supported(self): provider = LLMProvider.objects.create(name="阿里云", code="aliyun") diff --git a/backend/llm_ops/tests/test_views.py b/backend/llm_ops/tests/test_views.py index fd66048..3218c50 100644 --- a/backend/llm_ops/tests/test_views.py +++ b/backend/llm_ops/tests/test_views.py @@ -950,7 +950,7 @@ def test_collection_sources_sync_all_submits_supported_sources( updates_model_prices=True, ) openai = LLMProvider.objects.create(name="OpenAI", code="openai") - openai_source = PriceCollectionSource.objects.create( + PriceCollectionSource.objects.create( name="OpenAI Official", slug="openai-official", provider=openai, @@ -967,13 +967,13 @@ def test_collection_sources_sync_all_submits_supported_sources( self.assertEqual(response.status_code, 202) self.assertEqual(response.data["task_id"], "task-all") - self.assertEqual(response.data["source_count"], 2) + self.assertEqual(response.data["source_count"], 1) self.assertEqual( response.data["source_ids"], - [aliyun_source.id, openai_source.id], + [aliyun_source.id], ) mock_delay.assert_called_once_with( - source_ids=[aliyun_source.id, openai_source.id], + source_ids=[aliyun_source.id], verify_source=True, ) audit = AuditLog.objects.get( @@ -983,7 +983,7 @@ def test_collection_sources_sync_all_submits_supported_sources( self.assertEqual(audit.metadata["task_id"], "task-all") self.assertEqual( audit.metadata["source_ids"], - [aliyun_source.id, openai_source.id], + [aliyun_source.id], ) @patch("llm_ops.views.run_model_price_sync_agent.delay") @@ -996,7 +996,7 @@ def test_collection_sources_sync_all_rejects_without_sources( self.assertEqual(response.status_code, 400) mock_delay.assert_not_called() - def test_official_provider_options_include_aliyun_presets(self): + def test_official_provider_options_only_include_aliyun_preset(self): response = self.client.get( reverse("collection-source-official-provider-options") ) @@ -1005,10 +1005,7 @@ def test_official_provider_options_include_aliyun_presets(self): provider_codes = { item["provider_code"] for item in response.data["results"] } - self.assertIn("aliyun", provider_codes) - self.assertIn("aliyun-wanx", provider_codes) - self.assertIn("baidu", provider_codes) - self.assertIn("volcengine", provider_codes) + self.assertEqual(provider_codes, {"aliyun"}) def test_ensure_official_provider_source_creates_aliyun_source(self): response = self.client.post( @@ -1075,6 +1072,25 @@ def test_ensure_official_provider_source_is_idempotent(self): 1, ) + def test_ensure_official_provider_source_rejects_non_aliyun(self): + response = self.client.post( + reverse("collection-source-ensure-official-provider"), + {"provider_code": "baidu"}, + format="json", + ) + + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.data["detail"], + "Unsupported official provider source.", + ) + self.assertFalse(LLMProvider.objects.filter(code="baidu").exists()) + self.assertFalse( + PriceCollectionSource.objects.filter( + slug="baidu-official", + ).exists() + ) + @patch("llm_ops.views.run_model_price_sync_agent.delay") def test_collection_source_collect_rejects_disabled_source( self, diff --git a/frontend/src/components/llm-ops/ChannelModelDrawer.vue b/frontend/src/components/llm-ops/ChannelModelDrawer.vue index 8b886cc..88c5527 100644 --- a/frontend/src/components/llm-ops/ChannelModelDrawer.vue +++ b/frontend/src/components/llm-ops/ChannelModelDrawer.vue @@ -302,6 +302,8 @@ selectedProviderByModelKey[item.group.key] || '' " :options="item.options" + class-name="w-full" + :menu-min-width="260" placeholder="选择渠道上游" searchable search-placeholder="搜索上游 / 币种 / 类型" @@ -314,15 +316,32 @@ 无可用上游
{{ t('llmOps.priceSourceModal.sections.basicHint') }}
+{{ t('llmOps.priceSourceModal.sections.basicHint') }}
+