Skip to content

Commit 20f9e88

Browse files
authored
Merge pull request #153 from rumpl/feat-providers
Handle more providers
2 parents 4572d18 + 723509b commit 20f9e88

4 files changed

Lines changed: 114 additions & 23 deletions

File tree

pkg/model/provider/openai/client.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7+
"fmt"
78
"log/slog"
89
"strings"
910

@@ -35,11 +36,6 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
3536
return nil, errors.New("model configuration is required")
3637
}
3738

38-
if cfg.Provider != "openai" {
39-
slog.Error("OpenAI client creation failed", "error", "model type must be 'openai'", "actual_type", cfg.Provider)
40-
return nil, errors.New("model type must be 'openai'")
41-
}
42-
4339
var globalOptions options.ModelOptions
4440
for _, opt := range opts {
4541
opt(&globalOptions)
@@ -53,13 +49,31 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
5349
}
5450
authToken := env.Get(ctx, key)
5551
if authToken == "" {
56-
return nil, errors.New("OPENAI_API_KEY environment variable is required")
52+
return nil, fmt.Errorf("%s environment variable is required", key)
53+
}
54+
55+
if cfg.Provider == "azure" {
56+
openaiConfig = openai.DefaultAzureConfig(authToken, cfg.BaseURL)
57+
} else {
58+
openaiConfig = openai.DefaultConfig(authToken)
5759
}
5860

59-
openaiConfig = openai.DefaultConfig(authToken)
6061
if cfg.BaseURL != "" {
6162
openaiConfig.BaseURL = cfg.BaseURL
6263
}
64+
65+
// TODO: Move this logic to ProviderAliases as a config function
66+
if cfg.ProviderOpts != nil {
67+
switch cfg.Provider { //nolint:gocritic
68+
case "azure":
69+
if apiVersion, exists := cfg.ProviderOpts["api_version"]; exists {
70+
slog.Debug("Setting API version", "api_version", apiVersion)
71+
if apiVersionStr, ok := apiVersion.(string); ok {
72+
openaiConfig.APIVersion = apiVersionStr
73+
}
74+
}
75+
}
76+
}
6377
} else {
6478
authToken := desktop.GetToken(ctx)
6579
if authToken == "" {

pkg/model/provider/provider.go

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,26 @@ import (
1616
"github.com/docker/cagent/pkg/tools"
1717
)
1818

19+
// Alias defines the configuration for a provider alias
20+
type Alias struct {
21+
ApiType string // The actual API type to use (openai, anthropic, etc.)
22+
BaseURL string // Default base URL for the provider
23+
TokenEnvVar string // Environment variable name for the API token
24+
}
25+
26+
// ProviderAliases maps provider names to their corresponding configurations
27+
var ProviderAliases = map[string]Alias{
28+
"requesty": {
29+
ApiType: "openai",
30+
BaseURL: "https://router.requesty.ai/v1",
31+
TokenEnvVar: "REQUESTY_API_KEY",
32+
},
33+
"azure": {
34+
ApiType: "openai",
35+
TokenEnvVar: "AZURE_API_KEY",
36+
},
37+
}
38+
1939
// Provider defines the interface for model providers
2040
type Provider interface {
2141
// ID returns the model provider ID
@@ -37,21 +57,69 @@ type Provider interface {
3757
func New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) {
3858
slog.Debug("Creating model provider", "type", cfg.Provider, "model", cfg.Model)
3959

40-
switch cfg.Provider {
60+
// Apply provider alias defaults to the config
61+
enhancedCfg := applyProviderDefaults(cfg)
62+
apiType := ""
63+
if alias, exists := ProviderAliases[cfg.Provider]; exists {
64+
apiType = alias.ApiType
65+
}
66+
67+
// Resolve the actual API type from aliases or direct specification
68+
providerType := resolveProviderType(cfg.Provider, apiType)
69+
70+
switch providerType {
4171
case "openai":
42-
return openai.NewClient(ctx, cfg, env, opts...)
72+
return openai.NewClient(ctx, enhancedCfg, env, opts...)
4373

4474
case "anthropic":
45-
return anthropic.NewClient(ctx, cfg, env, opts...)
75+
return anthropic.NewClient(ctx, enhancedCfg, env, opts...)
4676

4777
case "google":
48-
return gemini.NewClient(ctx, cfg, env, opts...)
78+
return gemini.NewClient(ctx, enhancedCfg, env, opts...)
4979

5080
case "dmr":
51-
return dmr.NewClient(ctx, cfg, opts...)
81+
return dmr.NewClient(ctx, enhancedCfg, opts...)
5282

5383
default:
54-
slog.Error("Unknown provider type", "type", cfg.Provider)
55-
return nil, fmt.Errorf("unknown provider type: %s", cfg.Provider)
84+
slog.Error("Unknown provider type", "type", providerType)
85+
return nil, fmt.Errorf("unknown provider type: %s", providerType)
5686
}
5787
}
88+
89+
// applyProviderDefaults applies default configuration from provider aliases to the model config
90+
// This sets default base URLs and token keys if not already specified
91+
func applyProviderDefaults(cfg *latest.ModelConfig) *latest.ModelConfig {
92+
// Create a copy to avoid modifying the original
93+
enhancedCfg := *cfg
94+
95+
// Check if provider has alias configuration
96+
if alias, exists := ProviderAliases[cfg.Provider]; exists {
97+
// Set default base URL if not already specified
98+
if enhancedCfg.BaseURL == "" && alias.BaseURL != "" {
99+
enhancedCfg.BaseURL = alias.BaseURL
100+
}
101+
102+
// Set default token key if not already specified
103+
if enhancedCfg.TokenKey == "" && alias.TokenEnvVar != "" {
104+
enhancedCfg.TokenKey = alias.TokenEnvVar
105+
}
106+
}
107+
108+
return &enhancedCfg
109+
}
110+
111+
// resolveProviderType resolves the actual API type from the provider name and optional apiType
112+
func resolveProviderType(provider, apiType string) string {
113+
// If apiType is explicitly provided, use it
114+
if apiType != "" {
115+
return apiType
116+
}
117+
118+
// Check if provider has an alias mapping
119+
if resolved, exists := ProviderAliases[provider]; exists {
120+
return resolved.ApiType
121+
}
122+
123+
// Fall back to the provider name itself
124+
return provider
125+
}

pkg/modelsdev/store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
120120
id = actualID
121121
}
122122

123-
parts := strings.Split(id, "/")
123+
parts := strings.SplitN(id, "/", 2)
124124
if len(parts) != 2 {
125125
return nil, fmt.Errorf("invalid model ID: %q", id)
126126
}

pkg/teamloader/teamloader.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,23 @@ func checkRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environme
8585

8686
// Models
8787
if runtimeConfig.ModelsGateway == "" {
88-
for _, model := range cfg.Models {
89-
switch model.Provider {
90-
case "openai":
91-
requiredEnv["OPENAI_API_KEY"] = true
92-
case "anthropic":
93-
requiredEnv["ANTHROPIC_API_KEY"] = true
94-
case "google":
95-
requiredEnv["GOOGLE_API_KEY"] = true
88+
for name := range cfg.Models {
89+
model := cfg.Models[name]
90+
// Use the token environment variable from the alias if available
91+
if alias, exists := provider.ProviderAliases[model.Provider]; exists {
92+
if alias.TokenEnvVar != "" {
93+
requiredEnv[alias.TokenEnvVar] = true
94+
}
95+
} else {
96+
// Fallback to hardcoded mappings for unknown providers
97+
switch model.Provider {
98+
case "openai":
99+
requiredEnv["OPENAI_API_KEY"] = true
100+
case "anthropic":
101+
requiredEnv["ANTHROPIC_API_KEY"] = true
102+
case "google":
103+
requiredEnv["GOOGLE_API_KEY"] = true
104+
}
96105
}
97106
}
98107

0 commit comments

Comments
 (0)