From d0e929b58072b936fea691ff42c011abc2b0a70b Mon Sep 17 00:00:00 2001 From: OpenClaw Agent Date: Fri, 20 Feb 2026 17:09:41 +0000 Subject: [PATCH] feat: implement two-level model pooling and clustering Level 1 Pools: aggregate the same model across multiple providers under a single virtual model ID. Supports priority and round-robin selection strategies, with cross-member failover on 429/503 errors. Level 2 Clusters: group multiple pools (or raw model IDs) under semantic names (e.g., coding-high, coding-fast, chat, reasoning). Changes: - internal/config/config.go: ModelPoolConfig, ModelPool, ModelCluster, PoolMember, ClusterMember types; Config.ModelPools field - internal/pool/resolver.go: Resolver with Resolve, Reload, IsPoolOrCluster, ListVirtualModels; thread-safe hot-reload - internal/pool/resolver_test.go: 11 unit tests covering all scenarios - sdk/api/handlers/handlers.go: PoolResolver field on BaseAPIHandler; VirtualModels() helper; UpdateConfig hook for hot-reload; pool failover wired into Execute*, ExecuteCount*, ExecuteStream* - sdk/api/handlers/pool_execution.go: ExecuteWithPoolFailover and ExecuteStreamWithPoolFailover; isPoolFailoverError; member iteration - sdk/api/handlers/*/: Models() in OpenAI/Claude/Gemini handlers now append virtual pool+cluster models to /v1/models listing - internal/api/handlers/management/model_pools.go: GET/PUT/PATCH/DELETE management API endpoints for pools and clusters - internal/api/server.go: register /v0/management/model-pools routes - config.example.yaml: full documentation with examples --- config.example.yaml | 64 +++++ .../api/handlers/management/model_pools.go | 144 ++++++++++ internal/api/server.go | 8 + internal/config/config.go | 82 ++++++ internal/pool/resolver.go | 255 ++++++++++++++++++ internal/pool/resolver_test.go | 232 ++++++++++++++++ sdk/api/handlers/claude/code_handlers.go | 7 +- sdk/api/handlers/gemini/gemini_handlers.go | 7 +- sdk/api/handlers/handlers.go | 87 +++++- sdk/api/handlers/openai/openai_handlers.go | 9 +- .../openai/openai_responses_handlers.go | 7 +- sdk/api/handlers/pool_execution.go | 243 +++++++++++++++++ 12 files changed, 1134 insertions(+), 11 deletions(-) create mode 100644 internal/api/handlers/management/model_pools.go create mode 100644 internal/pool/resolver.go create mode 100644 internal/pool/resolver_test.go create mode 100644 sdk/api/handlers/pool_execution.go diff --git a/config.example.yaml b/config.example.yaml index 67e416a..b80a788 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -386,3 +386,67 @@ oauth-model-alias: # protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex # params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON) # "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}" + +# Model Pooling & Clustering +# Two-level system for aggregating models across providers under virtual model IDs. +# +# Level 1 — Pools: Route the same logical model from multiple providers under one virtual ID. +# Useful when you have the same model available via different subscriptions/providers. +# The gateway tries members in priority order and fails over on 429/503 errors. +# +# Level 2 — Clusters: Group multiple pools (or raw model IDs) under a semantic name. +# Use this to expose intent-based model IDs like "coding-high", "coding-fast", etc. +# +# Strategies: +# - "round-robin" (default for pools): rotate across members on each request +# - "priority" (default for clusters): try members in order, fail over on exhaustion +# +# model-pools: +# enabled: true +# default-strategy: round-robin +# +# # Level 1: Pools — same model, multiple providers +# pools: +# - id: "claude-sonnet-4" # Virtual model ID exposed to clients +# strategy: "priority" # Override default strategy for this pool +# members: +# - model: "kiro-claude-sonnet-4-5" # Actual model name sent to provider +# provider: "kiro" # Provider type +# priority: 0 # Lower = preferred (used with "priority" strategy) +# weight: 2 # Higher = more requests (used with "round-robin") +# - model: "claude-sonnet-4-20250514" +# provider: "claude" +# priority: 1 +# +# - id: "gpt-4.1" +# members: +# - model: "gpt-4.1" +# provider: "codex" +# +# # Level 2: Clusters — semantic names grouping multiple pools +# clusters: +# - id: "coding-high" # Semantic model ID exposed to clients +# description: "Best available model for complex coding tasks" +# strategy: "priority" +# members: +# - pool: "claude-sonnet-4" # Reference a Level 1 pool by ID +# priority: 0 +# - pool: "gpt-4.1" +# priority: 1 +# +# - id: "coding-fast" +# description: "Fast, cost-efficient model for simple coding tasks" +# members: +# - model: "kiro-claude-haiku-4-5" # Can reference raw models directly +# - model: "gpt-4.1-mini" +# +# - id: "chat" +# description: "General-purpose conversational model" +# members: +# - pool: "claude-sonnet-4" +# +# - id: "reasoning" +# description: "Models optimized for complex reasoning and analysis" +# members: +# - model: "claude-sonnet-4-20250514(64000)" # Supports thinking suffixes +# - model: "gpt-4.1" diff --git a/internal/api/handlers/management/model_pools.go b/internal/api/handlers/management/model_pools.go new file mode 100644 index 0000000..3ed62c1 --- /dev/null +++ b/internal/api/handlers/management/model_pools.go @@ -0,0 +1,144 @@ +package management + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +// GetModelPools returns the current model-pools configuration. +func (h *Handler) GetModelPools(c *gin.Context) { + h.mu.Lock() + defer h.mu.Unlock() + c.JSON(http.StatusOK, h.cfg.ModelPools) +} + +// PutModelPools replaces the entire model-pools configuration. +func (h *Handler) PutModelPools(c *gin.Context) { + var pools config.ModelPoolConfig + if err := c.ShouldBindJSON(&pools); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + h.mu.Lock() + h.cfg.ModelPools = pools + h.mu.Unlock() + + h.persist(c) +} + +// PatchModelPool adds or updates a single pool by ID. +func (h *Handler) PatchModelPool(c *gin.Context) { + var pool config.ModelPool + if err := c.ShouldBindJSON(&pool); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if pool.ID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "pool id is required"}) + return + } + + h.mu.Lock() + found := false + for i, p := range h.cfg.ModelPools.Pools { + if p.ID == pool.ID { + h.cfg.ModelPools.Pools[i] = pool + found = true + break + } + } + if !found { + h.cfg.ModelPools.Pools = append(h.cfg.ModelPools.Pools, pool) + } + h.mu.Unlock() + + h.persist(c) +} + +// DeleteModelPool removes a pool by ID. +func (h *Handler) DeleteModelPool(c *gin.Context) { + id := c.Query("id") + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id query parameter is required"}) + return + } + + h.mu.Lock() + pools := h.cfg.ModelPools.Pools + found := false + for i, p := range pools { + if p.ID == id { + h.cfg.ModelPools.Pools = append(pools[:i], pools[i+1:]...) + found = true + break + } + } + h.mu.Unlock() + + if !found { + c.JSON(http.StatusNotFound, gin.H{"error": "pool not found"}) + return + } + + h.persist(c) +} + +// PatchModelCluster adds or updates a single cluster by ID. +func (h *Handler) PatchModelCluster(c *gin.Context) { + var cluster config.ModelCluster + if err := c.ShouldBindJSON(&cluster); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if cluster.ID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "cluster id is required"}) + return + } + + h.mu.Lock() + found := false + for i, cl := range h.cfg.ModelPools.Clusters { + if cl.ID == cluster.ID { + h.cfg.ModelPools.Clusters[i] = cluster + found = true + break + } + } + if !found { + h.cfg.ModelPools.Clusters = append(h.cfg.ModelPools.Clusters, cluster) + } + h.mu.Unlock() + + h.persist(c) +} + +// DeleteModelCluster removes a cluster by ID. +func (h *Handler) DeleteModelCluster(c *gin.Context) { + id := c.Query("id") + if id == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "id query parameter is required"}) + return + } + + h.mu.Lock() + clusters := h.cfg.ModelPools.Clusters + found := false + for i, cl := range clusters { + if cl.ID == id { + h.cfg.ModelPools.Clusters = append(clusters[:i], clusters[i+1:]...) + found = true + break + } + } + h.mu.Unlock() + + if !found { + c.JSON(http.StatusNotFound, gin.H{"error": "cluster not found"}) + return + } + + h.persist(c) +} diff --git a/internal/api/server.go b/internal/api/server.go index b406a2d..a5b3da6 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -678,6 +678,14 @@ func (s *Server) registerManagementRoutes() { mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys) mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys) + // Model pools and clusters + mgmt.GET("/model-pools", s.mgmt.GetModelPools) + mgmt.PUT("/model-pools", s.mgmt.PutModelPools) + mgmt.PATCH("/model-pools/pool", s.mgmt.PatchModelPool) + mgmt.DELETE("/model-pools/pool", s.mgmt.DeleteModelPool) + mgmt.PATCH("/model-pools/cluster", s.mgmt.PatchModelCluster) + mgmt.DELETE("/model-pools/cluster", s.mgmt.DeleteModelCluster) + mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry) diff --git a/internal/config/config.go b/internal/config/config.go index 65fdfca..a32a54b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -135,6 +135,11 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + // ModelPools defines model pooling and clustering configuration. + // Level 1 pools aggregate the same model across multiple providers under one ID. + // Level 2 clusters group multiple pools under a semantic name (e.g., "coding-high"). + ModelPools ModelPoolConfig `yaml:"model-pools" json:"model-pools"` + // IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode. // This is useful when you want to login with a different account without logging out // from your current session. Default: false. @@ -364,6 +369,83 @@ type PayloadModelRule struct { Protocol string `yaml:"protocol" json:"protocol"` } +// ModelPoolConfig defines model pooling and clustering. +type ModelPoolConfig struct { + // Enabled toggles the pooling feature. When false, pool/cluster model IDs are not recognized. + Enabled bool `yaml:"enabled" json:"enabled"` + + // Pools defines Level 1 pools: same logical model aggregated across providers. + // Each pool exposes a virtual model ID that resolves to concrete models from multiple providers. + Pools []ModelPool `yaml:"pools,omitempty" json:"pools,omitempty"` + + // Clusters defines Level 2 clusters: named groups of pools for semantic routing. + // A cluster ID (e.g., "coding-high") resolves to an ordered list of pool or model IDs. + Clusters []ModelCluster `yaml:"clusters,omitempty" json:"clusters,omitempty"` + + // DefaultStrategy is the selection strategy for pools/clusters when not overridden. + // Supported values: "round-robin" (default), "priority", "latency" (future). + DefaultStrategy string `yaml:"default-strategy,omitempty" json:"default-strategy,omitempty"` +} + +// ModelPool defines a Level 1 pool: a virtual model ID backed by concrete models from one or more providers. +type ModelPool struct { + // ID is the virtual model ID exposed to clients (e.g., "claude-sonnet-4"). + // If empty, the first member's model name is used. + ID string `yaml:"id" json:"id"` + + // Members lists the concrete model+provider pairs that back this pool. + Members []PoolMember `yaml:"members" json:"members"` + + // Strategy overrides the default selection strategy for this pool. + Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` +} + +// PoolMember is a concrete model+provider pair within a pool. +type PoolMember struct { + // Model is the actual model ID as known to the provider (e.g., "kiro-claude-sonnet-4-5"). + Model string `yaml:"model" json:"model"` + + // Provider is the provider type (e.g., "kiro", "claude", "codex"). Optional—if empty, + // all providers that offer this model are included. + Provider string `yaml:"provider,omitempty" json:"provider,omitempty"` + + // Priority is an optional numeric priority (lower = preferred). Only used with "priority" strategy. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` + + // Weight is an optional weight for weighted round-robin. Default is 1. + Weight int `yaml:"weight,omitempty" json:"weight,omitempty"` +} + +// ModelCluster defines a Level 2 cluster: a named group of pools or models. +type ModelCluster struct { + // ID is the cluster name exposed to clients (e.g., "coding-high"). + ID string `yaml:"id" json:"id"` + + // Description is an optional human-readable description. + Description string `yaml:"description,omitempty" json:"description,omitempty"` + + // Members lists pool IDs or model IDs in priority order. + // The first available member is selected. Members can reference pool IDs or raw model IDs. + Members []ClusterMember `yaml:"members" json:"members"` + + // Strategy overrides the default selection strategy for this cluster. + // "priority" (default for clusters): try members in order, use first available. + // "round-robin": rotate across available members. + Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"` +} + +// ClusterMember references a pool or model within a cluster. +type ClusterMember struct { + // Pool references a Level 1 pool ID. Mutually exclusive with Model. + Pool string `yaml:"pool,omitempty" json:"pool,omitempty"` + + // Model references a raw model ID directly (bypasses pool lookup). Mutually exclusive with Pool. + Model string `yaml:"model,omitempty" json:"model,omitempty"` + + // Priority is an optional numeric priority (lower = preferred). Only used with "priority" strategy. + Priority int `yaml:"priority,omitempty" json:"priority,omitempty"` +} + // CloakConfig configures request cloaking for non-Claude-Code clients. // Cloaking disguises API requests to appear as originating from the official Claude Code CLI. type CloakConfig struct { diff --git a/internal/pool/resolver.go b/internal/pool/resolver.go new file mode 100644 index 0000000..7591aa1 --- /dev/null +++ b/internal/pool/resolver.go @@ -0,0 +1,255 @@ +// Package pool implements model pooling and clustering for the AI gateway. +// It provides a two-level resolution system: +// - Level 1 (Pools): aggregate the same model from different providers under one virtual ID. +// - Level 2 (Clusters): group multiple pools/models under a semantic name (e.g., "coding-high"). +// +// The Resolver is built from config and used during request routing to expand +// virtual model IDs into concrete model+provider pairs. +package pool + +import ( + "strings" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +// ResolvedTarget is a concrete model+provider pair produced by pool/cluster resolution. +type ResolvedTarget struct { + // Model is the actual model ID to send to the provider. + Model string + // Providers lists provider types that can serve this model. Empty means "all registered providers". + Providers []string + // Priority is the selection priority (lower = preferred). + Priority int + // Weight for weighted selection. + Weight int +} + +// Resolution is the result of resolving a model ID through the pool/cluster system. +type Resolution struct { + // Matched is true if the model ID was a pool or cluster ID. + Matched bool + // Targets are the concrete model+provider pairs, ordered by priority/config order. + Targets []ResolvedTarget + // Strategy is the selection strategy for this resolution. + Strategy string +} + +// Resolver resolves virtual model IDs (pools/clusters) to concrete model+provider pairs. +type Resolver struct { + mu sync.RWMutex + pools map[string]*config.ModelPool + clusters map[string]*config.ModelCluster + strategy string + enabled bool +} + +// NewResolver creates a Resolver from the given config. +func NewResolver(cfg *config.ModelPoolConfig) *Resolver { + r := &Resolver{ + pools: make(map[string]*config.ModelPool), + clusters: make(map[string]*config.ModelCluster), + } + if cfg == nil { + return r + } + r.enabled = cfg.Enabled + r.strategy = cfg.DefaultStrategy + if r.strategy == "" { + r.strategy = "round-robin" + } + for i := range cfg.Pools { + p := &cfg.Pools[i] + id := strings.TrimSpace(p.ID) + if id == "" && len(p.Members) > 0 { + id = p.Members[0].Model + } + if id != "" { + r.pools[strings.ToLower(id)] = p + } + } + for i := range cfg.Clusters { + c := &cfg.Clusters[i] + id := strings.TrimSpace(c.ID) + if id != "" { + r.clusters[strings.ToLower(id)] = c + } + } + return r +} + +// Reload replaces the resolver's config. Thread-safe. +func (r *Resolver) Reload(cfg *config.ModelPoolConfig) { + fresh := NewResolver(cfg) + r.mu.Lock() + r.pools = fresh.pools + r.clusters = fresh.clusters + r.strategy = fresh.strategy + r.enabled = fresh.enabled + r.mu.Unlock() +} + +// Resolve attempts to resolve a model ID as a pool or cluster. +// Returns a Resolution with Matched=false if the ID is not a pool or cluster. +func (r *Resolver) Resolve(modelID string) Resolution { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.enabled { + return Resolution{} + } + + key := strings.ToLower(strings.TrimSpace(modelID)) + + // Try cluster first (Level 2) + if cluster, ok := r.clusters[key]; ok { + return r.resolveCluster(cluster) + } + + // Try pool (Level 1) + if pool, ok := r.pools[key]; ok { + return r.resolvePool(pool) + } + + return Resolution{} +} + +// IsPoolOrCluster checks if a model ID is a known pool or cluster without full resolution. +func (r *Resolver) IsPoolOrCluster(modelID string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.enabled { + return false + } + key := strings.ToLower(strings.TrimSpace(modelID)) + if _, ok := r.pools[key]; ok { + return true + } + if _, ok := r.clusters[key]; ok { + return true + } + return false +} + +// ListVirtualModels returns all pool and cluster IDs that should appear in model listings. +func (r *Resolver) ListVirtualModels() []VirtualModel { + r.mu.RLock() + defer r.mu.RUnlock() + + if !r.enabled { + return nil + } + + var models []VirtualModel + for _, p := range r.pools { + id := p.ID + if id == "" && len(p.Members) > 0 { + id = p.Members[0].Model + } + if id != "" { + models = append(models, VirtualModel{ + ID: id, + Type: "pool", + }) + } + } + for _, c := range r.clusters { + if c.ID != "" { + models = append(models, VirtualModel{ + ID: c.ID, + Type: "cluster", + Description: c.Description, + }) + } + } + return models +} + +// VirtualModel describes a pool or cluster for model listing purposes. +type VirtualModel struct { + ID string + Type string // "pool" or "cluster" + Description string +} + +func (r *Resolver) resolvePool(pool *config.ModelPool) Resolution { + strategy := pool.Strategy + if strategy == "" { + strategy = r.strategy + } + + var targets []ResolvedTarget + for _, m := range pool.Members { + weight := m.Weight + if weight <= 0 { + weight = 1 + } + providers := r.resolveProviders(m) + targets = append(targets, ResolvedTarget{ + Model: m.Model, + Providers: providers, + Priority: m.Priority, + Weight: weight, + }) + } + + return Resolution{ + Matched: true, + Targets: targets, + Strategy: strategy, + } +} + +func (r *Resolver) resolveCluster(cluster *config.ModelCluster) Resolution { + strategy := cluster.Strategy + if strategy == "" { + strategy = "priority" + } + + var targets []ResolvedTarget + for _, member := range cluster.Members { + if member.Pool != "" { + poolKey := strings.ToLower(strings.TrimSpace(member.Pool)) + if pool, ok := r.pools[poolKey]; ok { + for _, m := range pool.Members { + weight := m.Weight + if weight <= 0 { + weight = 1 + } + providers := r.resolveProviders(m) + targets = append(targets, ResolvedTarget{ + Model: m.Model, + Providers: providers, + Priority: member.Priority + m.Priority, + Weight: weight, + }) + } + } + } else if member.Model != "" { + providers := registry.GetGlobalRegistry().GetModelProviders(member.Model) + targets = append(targets, ResolvedTarget{ + Model: member.Model, + Providers: providers, + Priority: member.Priority, + Weight: 1, + }) + } + } + + return Resolution{ + Matched: true, + Targets: targets, + Strategy: strategy, + } +} + +func (r *Resolver) resolveProviders(m config.PoolMember) []string { + if m.Provider != "" { + return []string{m.Provider} + } + // Look up all providers from the model registry + return registry.GetGlobalRegistry().GetModelProviders(m.Model) +} diff --git a/internal/pool/resolver_test.go b/internal/pool/resolver_test.go new file mode 100644 index 0000000..a813151 --- /dev/null +++ b/internal/pool/resolver_test.go @@ -0,0 +1,232 @@ +package pool + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" +) + +func TestNewResolver_Nil(t *testing.T) { + r := NewResolver(nil) + if r == nil { + t.Fatal("expected non-nil resolver") + } + res := r.Resolve("anything") + if res.Matched { + t.Error("expected no match for nil config") + } +} + +func TestNewResolver_Disabled(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: false, + Pools: []config.ModelPool{ + {ID: "test-pool", Members: []config.PoolMember{{Model: "m1"}}}, + }, + }) + res := r.Resolve("test-pool") + if res.Matched { + t.Error("expected no match when disabled") + } +} + +func TestResolvePool_Basic(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Pools: []config.ModelPool{ + { + ID: "claude-sonnet-4", + Members: []config.PoolMember{ + {Model: "kiro-claude-sonnet-4-5", Provider: "kiro", Priority: 0, Weight: 2}, + {Model: "claude-sonnet-4-20250514", Provider: "claude", Priority: 1}, + }, + Strategy: "priority", + }, + }, + }) + + res := r.Resolve("claude-sonnet-4") + if !res.Matched { + t.Fatal("expected match") + } + if res.Strategy != "priority" { + t.Errorf("expected strategy 'priority', got %q", res.Strategy) + } + if len(res.Targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(res.Targets)) + } + if res.Targets[0].Model != "kiro-claude-sonnet-4-5" { + t.Errorf("expected first target model 'kiro-claude-sonnet-4-5', got %q", res.Targets[0].Model) + } + if res.Targets[0].Weight != 2 { + t.Errorf("expected weight 2, got %d", res.Targets[0].Weight) + } + if res.Targets[1].Priority != 1 { + t.Errorf("expected priority 1, got %d", res.Targets[1].Priority) + } +} + +func TestResolvePool_CaseInsensitive(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Pools: []config.ModelPool{ + {ID: "My-Pool", Members: []config.PoolMember{{Model: "m1"}}}, + }, + }) + res := r.Resolve("my-pool") + if !res.Matched { + t.Error("expected case-insensitive match") + } +} + +func TestResolvePool_DefaultWeight(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Pools: []config.ModelPool{ + {ID: "pool", Members: []config.PoolMember{{Model: "m1"}}}, + }, + }) + res := r.Resolve("pool") + if !res.Matched { + t.Fatal("expected match") + } + if res.Targets[0].Weight != 1 { + t.Errorf("expected default weight 1, got %d", res.Targets[0].Weight) + } +} + +func TestResolveCluster_Basic(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Pools: []config.ModelPool{ + { + ID: "claude-sonnet", + Members: []config.PoolMember{ + {Model: "kiro-claude-sonnet-4-5", Provider: "kiro"}, + }, + }, + { + ID: "gpt-4.1", + Members: []config.PoolMember{ + {Model: "gpt-4.1", Provider: "codex"}, + }, + }, + }, + Clusters: []config.ModelCluster{ + { + ID: "coding-high", + Description: "High-quality coding models", + Members: []config.ClusterMember{ + {Pool: "claude-sonnet", Priority: 0}, + {Pool: "gpt-4.1", Priority: 1}, + }, + }, + }, + }) + + res := r.Resolve("coding-high") + if !res.Matched { + t.Fatal("expected match") + } + if res.Strategy != "priority" { + t.Errorf("expected default cluster strategy 'priority', got %q", res.Strategy) + } + if len(res.Targets) != 2 { + t.Fatalf("expected 2 targets, got %d", len(res.Targets)) + } + if res.Targets[0].Model != "kiro-claude-sonnet-4-5" { + t.Errorf("expected first target 'kiro-claude-sonnet-4-5', got %q", res.Targets[0].Model) + } +} + +func TestResolveCluster_DirectModel(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Clusters: []config.ModelCluster{ + { + ID: "fast", + Members: []config.ClusterMember{ + {Model: "gpt-4.1-mini"}, + }, + }, + }, + }) + res := r.Resolve("fast") + if !res.Matched { + t.Fatal("expected match") + } + if res.Targets[0].Model != "gpt-4.1-mini" { + t.Errorf("unexpected model: %q", res.Targets[0].Model) + } +} + +func TestIsPoolOrCluster(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Pools: []config.ModelPool{{ID: "p1", Members: []config.PoolMember{{Model: "m1"}}}}, + Clusters: []config.ModelCluster{{ID: "c1", Members: []config.ClusterMember{{Pool: "p1"}}}}, + }) + if !r.IsPoolOrCluster("p1") { + t.Error("expected pool match") + } + if !r.IsPoolOrCluster("c1") { + t.Error("expected cluster match") + } + if r.IsPoolOrCluster("unknown") { + t.Error("expected no match for unknown") + } +} + +func TestListVirtualModels(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Pools: []config.ModelPool{{ID: "p1", Members: []config.PoolMember{{Model: "m1"}}}}, + Clusters: []config.ModelCluster{{ID: "c1", Description: "desc", Members: []config.ClusterMember{{Pool: "p1"}}}}, + }) + models := r.ListVirtualModels() + if len(models) != 2 { + t.Fatalf("expected 2 virtual models, got %d", len(models)) + } + foundPool, foundCluster := false, false + for _, m := range models { + if m.ID == "p1" && m.Type == "pool" { + foundPool = true + } + if m.ID == "c1" && m.Type == "cluster" && m.Description == "desc" { + foundCluster = true + } + } + if !foundPool { + t.Error("missing pool in virtual models") + } + if !foundCluster { + t.Error("missing cluster in virtual models") + } +} + +func TestReload(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{Enabled: true, Pools: []config.ModelPool{{ID: "a"}}}) + if !r.IsPoolOrCluster("a") { + t.Fatal("expected 'a'") + } + r.Reload(&config.ModelPoolConfig{Enabled: true, Pools: []config.ModelPool{{ID: "b"}}}) + if r.IsPoolOrCluster("a") { + t.Error("'a' should be gone after reload") + } + if !r.IsPoolOrCluster("b") { + t.Error("expected 'b' after reload") + } +} + +func TestPoolIDFallback(t *testing.T) { + r := NewResolver(&config.ModelPoolConfig{ + Enabled: true, + Pools: []config.ModelPool{ + {Members: []config.PoolMember{{Model: "fallback-model"}}}, + }, + }) + res := r.Resolve("fallback-model") + if !res.Matched { + t.Error("expected match using first member model as ID") + } +} diff --git a/sdk/api/handlers/claude/code_handlers.go b/sdk/api/handlers/claude/code_handlers.go index b05d824..130826e 100644 --- a/sdk/api/handlers/claude/code_handlers.go +++ b/sdk/api/handlers/claude/code_handlers.go @@ -53,7 +53,12 @@ func (h *ClaudeCodeAPIHandler) HandlerType() string { func (h *ClaudeCodeAPIHandler) Models() []map[string]any { // Get dynamic models from the global registry modelRegistry := registry.GetGlobalRegistry() - return modelRegistry.GetAvailableModels("claude") + models := modelRegistry.GetAvailableModels("claude") + // Append virtual pool/cluster models + if vms := h.VirtualModels(); len(vms) > 0 { + models = append(models, vms...) + } + return models } // ClaudeMessages handles Claude-compatible streaming chat completions. diff --git a/sdk/api/handlers/gemini/gemini_handlers.go b/sdk/api/handlers/gemini/gemini_handlers.go index 27d8d1f..521b027 100644 --- a/sdk/api/handlers/gemini/gemini_handlers.go +++ b/sdk/api/handlers/gemini/gemini_handlers.go @@ -42,7 +42,12 @@ func (h *GeminiAPIHandler) HandlerType() string { func (h *GeminiAPIHandler) Models() []map[string]any { // Get dynamic models from the global registry modelRegistry := registry.GetGlobalRegistry() - return modelRegistry.GetAvailableModels("gemini") + models := modelRegistry.GetAvailableModels("gemini") + // Append virtual pool/cluster models + if vms := h.VirtualModels(); len(vms) > 0 { + models = append(models, vms...) + } + return models } // GeminiModels handles the Gemini models listing endpoint. diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index ac02822..5459c34 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -18,6 +18,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/knowledge" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/mcp" + "github.com/router-for-me/CLIProxyAPI/v6/internal/pool" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -189,6 +190,9 @@ type BaseAPIHandler struct { // MCPService handles MCP tool discovery and execution when configured. MCPService *mcp.Service + + // PoolResolver resolves virtual model IDs (pools/clusters) to concrete model+provider pairs. + PoolResolver *pool.Resolver } // NewBaseAPIHandlers creates a new API handlers instance. @@ -202,9 +206,10 @@ type BaseAPIHandler struct { // - *BaseAPIHandler: A new API handlers instance func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler { h := &BaseAPIHandler{ - Cfg: cfg, - AuthManager: authManager, - MCPService: mcp.NewService(nil), + Cfg: cfg, + AuthManager: authManager, + MCPService: mcp.NewService(nil), + PoolResolver: pool.NewResolver(nil), } return h } @@ -216,10 +221,11 @@ func NewBaseAPIHandlersWithConfig(cfg *config.Config, authManager *coreauth.Mana return NewBaseAPIHandlers(nil, authManager) } return &BaseAPIHandler{ - Cfg: &cfg.SDKConfig, - FullCfg: cfg, - AuthManager: authManager, - MCPService: mcp.NewService(cfg), + Cfg: &cfg.SDKConfig, + FullCfg: cfg, + AuthManager: authManager, + MCPService: mcp.NewService(cfg), + PoolResolver: pool.NewResolver(&cfg.ModelPools), } } @@ -248,6 +254,12 @@ func (h *BaseAPIHandler) UpdateConfig(cfg *config.Config) { } else { h.MCPService.UpdateConfig(cfg) } + // Reload pool resolver with updated config + if h.PoolResolver == nil { + h.PoolResolver = pool.NewResolver(&cfg.ModelPools) + } else { + h.PoolResolver.Reload(&cfg.ModelPools) + } } // UpdateKnowledgeManager refreshes the knowledge manager used for profile RAG. @@ -255,6 +267,31 @@ func (h *BaseAPIHandler) UpdateKnowledgeManager(manager *knowledge.Manager) { h.KnowledgeManager = manager } +// VirtualModels returns pool and cluster virtual models in OpenAI model listing format. +// Callers (e.g., OpenAIModels, ClaudeModels) should append these to their model lists. +func (h *BaseAPIHandler) VirtualModels() []map[string]any { + if h.PoolResolver == nil { + return nil + } + vms := h.PoolResolver.ListVirtualModels() + if len(vms) == 0 { + return nil + } + result := make([]map[string]any, 0, len(vms)) + for _, vm := range vms { + m := map[string]any{ + "id": vm.ID, + "object": "model", + "owned_by": "ai-gateway-" + vm.Type, + } + if vm.Description != "" { + m["description"] = vm.Description + } + result = append(result, m) + } + return result +} + // GetAlt extracts the 'alt' parameter from the request query string. // It checks both 'alt' and '$alt' parameters and returns the appropriate value. // @@ -431,6 +468,11 @@ func appendAPIResponse(c *gin.Context, data []byte) { // ExecuteWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { + // Pool/cluster failover path: try members in priority order before standard routing. + if result, errMsg, handled := h.ExecuteWithPoolFailover(ctx, handlerType, modelName, rawJSON, alt); handled { + return result, errMsg + } + providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { return nil, errMsg @@ -470,6 +512,11 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { + // Pool/cluster failover path: try members in priority order before standard routing. + if result, errMsg, handled := h.ExecuteWithPoolFailover(ctx, handlerType, modelName, rawJSON, alt); handled { + return result, errMsg + } + providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { return nil, errMsg @@ -509,6 +556,11 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // This path is the only supported execution route. func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { + // Pool/cluster failover path: try members in priority order before standard routing. + if dataChan, errChan, handled := h.ExecuteStreamWithPoolFailover(ctx, handlerType, modelName, rawJSON, alt); handled { + return dataChan, errChan + } + providers, normalizedModel, errMsg := h.getRequestDetails(modelName) if errMsg != nil { errChan := make(chan *interfaces.ErrorMessage, 1) @@ -658,6 +710,27 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string parsed := thinking.ParseSuffix(resolvedModelName) baseModel := strings.TrimSpace(parsed.ModelName) + // Pool/cluster resolution: if the model ID matches a pool or cluster, + // resolve it to the first concrete target's model and providers. + if h.PoolResolver != nil { + resolution := h.PoolResolver.Resolve(baseModel) + if resolution.Matched && len(resolution.Targets) > 0 { + // Use the first target (priority-based or first in config order). + // The conductor's existing multi-provider failover handles retries. + target := resolution.Targets[0] + resolvedModel := target.Model + if parsed.HasSuffix { + resolvedModel = fmt.Sprintf("%s(%s)", target.Model, parsed.RawSuffix) + } + if len(target.Providers) > 0 { + return target.Providers, resolvedModel, nil + } + // Fall through to normal provider lookup with the resolved model name + resolvedModelName = resolvedModel + baseModel = target.Model + } + } + providers = util.GetProviderName(baseModel) // Fallback: if baseModel has no provider but differs from resolvedModelName, // try using the full model name. This handles edge cases where custom models diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go index fadfd44..3e8ea66 100644 --- a/sdk/api/handlers/openai/openai_handlers.go +++ b/sdk/api/handlers/openai/openai_handlers.go @@ -54,7 +54,14 @@ func (h *OpenAIAPIHandler) HandlerType() string { func (h *OpenAIAPIHandler) Models() []map[string]any { // Get dynamic models from the global registry modelRegistry := registry.GetGlobalRegistry() - return modelRegistry.GetAvailableModels("openai") + models := modelRegistry.GetAvailableModels("openai") + + // Append virtual pool/cluster models + if vms := h.VirtualModels(); len(vms) > 0 { + models = append(models, vms...) + } + + return models } // OpenAIModels handles the /v1/models endpoint. diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index 045f7b6..dbaa30e 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -50,7 +50,12 @@ func (h *OpenAIResponsesAPIHandler) HandlerType() string { func (h *OpenAIResponsesAPIHandler) Models() []map[string]any { // Get dynamic models from the global registry modelRegistry := registry.GetGlobalRegistry() - return modelRegistry.GetAvailableModels("openai") + models := modelRegistry.GetAvailableModels("openai") + // Append virtual pool/cluster models + if vms := h.VirtualModels(); len(vms) > 0 { + models = append(models, vms...) + } + return models } // OpenAIResponsesModels handles the /v1/models endpoint. diff --git a/sdk/api/handlers/pool_execution.go b/sdk/api/handlers/pool_execution.go new file mode 100644 index 0000000..4f7e8f5 --- /dev/null +++ b/sdk/api/handlers/pool_execution.go @@ -0,0 +1,243 @@ +package handlers + +import ( + "context" + "fmt" + "net/http" + + coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/pool" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" +) + +// isPoolFailoverError returns true if the status code indicates all credentials for a provider +// are exhausted or cooling down, meaning the next pool member should be tried. +// 429 (rate limited/cooling down) and 503 (service unavailable/quota) are retriable. +func isPoolFailoverError(status int) bool { + switch status { + case http.StatusTooManyRequests, // 429 + http.StatusServiceUnavailable: // 503 + return true + } + return false +} + +// resolvedModelWithSuffix re-applies a thinking suffix to a pool target's model name, +// preserving any suffix the user included in their original request. +func resolvedModelWithSuffix(targetModel string, suffix thinking.SuffixResult) string { + if suffix.HasSuffix { + return fmt.Sprintf("%s(%s)", targetModel, suffix.RawSuffix) + } + return targetModel +} + +// poolOrClusterResolution checks if modelName is a pool/cluster and returns its targets. +// Returns nil + zero SuffixResult if not a pool/cluster or if the resolver is unavailable. +func (h *BaseAPIHandler) poolOrClusterResolution(modelName string) (*pool.Resolution, thinking.SuffixResult) { + if h.PoolResolver == nil { + return nil, thinking.SuffixResult{} + } + suffix := thinking.ParseSuffix(modelName) + res := h.PoolResolver.Resolve(suffix.ModelName) + if !res.Matched || len(res.Targets) == 0 { + return nil, thinking.SuffixResult{} + } + return &res, suffix +} + +// buildPoolExecutionRequest builds the coreexecutor.Request and Options for a given target model, +// preserving the original payload while rewriting the model field. +func buildPoolExecutionRequest(targetModel string, rawJSON []byte, handlerType string, stream bool, alt string) (coreexecutor.Request, coreexecutor.Options) { + req := coreexecutor.Request{ + Model: targetModel, + Payload: cloneBytes(rawJSON), + } + opts := coreexecutor.Options{ + Stream: stream, + Alt: alt, + OriginalRequest: cloneBytes(rawJSON), + SourceFormat: sdktranslator.FromString(handlerType), + } + return req, opts +} + +// resolveTargetProviders returns the providers for a pool target. +// Uses the explicitly configured providers list, falling back to registry lookup. +func resolveTargetProviders(target pool.ResolvedTarget) []string { + if len(target.Providers) > 0 { + return target.Providers + } + return util.GetProviderName(target.Model) +} + +// ExecuteWithPoolFailover executes a non-streaming request with pool-aware failover. +// If modelName resolves to a pool or cluster, members are tried in priority order. +// When a member's providers return a retriable error (429/503), the next member is tried. +// Returns (response, error, handled). handled=false means the model is not a pool/cluster; +// the caller should fall back to the standard Execute path. +func (h *BaseAPIHandler) ExecuteWithPoolFailover(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage, bool) { + resolution, suffix := h.poolOrClusterResolution(modelName) + if resolution == nil { + return nil, nil, false + } + + targets := resolution.Targets + var lastErr *interfaces.ErrorMessage + + for i, target := range targets { + resolvedModel := resolvedModelWithSuffix(target.Model, suffix) + req, opts := buildPoolExecutionRequest(resolvedModel, rawJSON, handlerType, false, alt) + + reqMeta := requestExecutionMetadata(ctx) + reqMeta[coreexecutor.RequestedModelMetadataKey] = resolvedModel + opts.Metadata = reqMeta + + providers := resolveTargetProviders(target) + if len(providers) == 0 { + // Skip: no resolvable providers for this member + continue + } + + resp, err := h.AuthManager.Execute(ctx, providers, req, opts) + if err != nil { + status := statusFromError(err) + var addon http.Header + if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + lastErr = &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + + if isPoolFailoverError(status) && i < len(targets)-1 { + // Retriable error and more members remain: try next + continue + } + // Non-retriable, or last member + return nil, lastErr, true + } + + return cloneBytes(resp.Payload), nil, true + } + + if lastErr != nil { + return nil, lastErr, true + } + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusServiceUnavailable, + Error: fmt.Errorf("all pool members exhausted for model %q", modelName), + }, true +} + +// ExecuteStreamWithPoolFailover executes a streaming request with pool-aware failover. +// When the first pool member is exhausted before any bytes are sent, it transparently +// tries the next member. Returns (dataChan, errChan, handled bool). +// handled=false means the model is not a pool/cluster; caller should use the standard path. +func (h *BaseAPIHandler) ExecuteStreamWithPoolFailover(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage, bool) { + resolution, suffix := h.poolOrClusterResolution(modelName) + if resolution == nil { + return nil, nil, false + } + + targets := resolution.Targets + dataChan := make(chan []byte) + errChan := make(chan *interfaces.ErrorMessage, 1) + + go func() { + defer close(dataChan) + defer close(errChan) + + for i, target := range targets { + resolvedModel := resolvedModelWithSuffix(target.Model, suffix) + req, opts := buildPoolExecutionRequest(resolvedModel, rawJSON, handlerType, true, alt) + + reqMeta := requestExecutionMetadata(ctx) + reqMeta[coreexecutor.RequestedModelMetadataKey] = resolvedModel + opts.Metadata = reqMeta + + providers := resolveTargetProviders(target) + if len(providers) == 0 { + continue + } + + chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + if err != nil { + status := statusFromError(err) + if isPoolFailoverError(status) && i < len(targets)-1 { + continue + } + var addon http.Header + if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } + } + errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + return + } + + sentPayload := false + memberExhausted := false + + streamLoop: + for { + var chunk coreexecutor.StreamChunk + var ok bool + if ctx != nil { + select { + case <-ctx.Done(): + return + case chunk, ok = <-chunks: + } + } else { + chunk, ok = <-chunks + } + if !ok { + // Channel closed cleanly — stream done + return + } + if chunk.Err != nil { + // If no payload sent yet and this is a retriable error with more members, fail over + if !sentPayload { + status := statusFromError(chunk.Err) + if isPoolFailoverError(status) && i < len(targets)-1 { + memberExhausted = true + break streamLoop + } + } + // Payload already sent, or non-retriable error: propagate + streamStatus := statusFromError(chunk.Err) + if streamStatus == 0 { + streamStatus = http.StatusInternalServerError + } + errChan <- &interfaces.ErrorMessage{StatusCode: streamStatus, Error: chunk.Err} + return + } + sentPayload = true + select { + case <-ctx.Done(): + return + case dataChan <- cloneBytes(chunk.Payload): + } + } + + if !memberExhausted { + // Clean exit + return + } + // memberExhausted=true: continue loop to try next target + } + + // All members failed before sending any data + errChan <- &interfaces.ErrorMessage{ + StatusCode: http.StatusServiceUnavailable, + Error: fmt.Errorf("all pool members exhausted for model %q", modelName), + } + }() + + return dataChan, errChan, true +}