Skip to content

Commit 535b3bc

Browse files
committed
Refactor: reuse shared message type in claude client
1 parent 5f65d55 commit 535b3bc

4 files changed

Lines changed: 63 additions & 42 deletions

File tree

cmd/cli/store/store.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,17 @@ func ChangeDefault(Model types.LLMProvider) error {
244244
}
245245
}
246246

247+
found := false
248+
for _, p := range cfg.LLMProviders {
249+
if p.LLM == Model {
250+
found = true
251+
break
252+
}
253+
}
254+
if !found {
255+
return fmt.Errorf("cannot set default to %s: no saved entry", Model.String())
256+
}
257+
247258
cfg.Default = Model
248259

249260
data, err = json.MarshalIndent(cfg, "", " ")
@@ -335,12 +346,18 @@ func UpdateAPIKey(Model types.LLMProvider, APIKey string) error {
335346
}
336347
}
337348

349+
updated := false
338350
for i, p := range cfg.LLMProviders {
339351
if p.LLM == Model {
340352
cfg.LLMProviders[i].APIKey = APIKey
353+
updated = true
341354
}
342355
}
343356

357+
if !updated {
358+
return fmt.Errorf("no saved entry for %s to update", Model.String())
359+
}
360+
344361
data, err = json.MarshalIndent(cfg, "", " ")
345362
if err != nil {
346363
return err

internal/claude/claude.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,9 @@ import (
1212

1313
// ClaudeRequest describes the payload sent to Anthropic's Claude messages API.
1414
type ClaudeRequest struct {
15-
Model string `json:"model"`
16-
Messages []Message `json:"messages"`
17-
MaxTokens int `json:"max_tokens"`
18-
}
19-
20-
// Message represents a single role/content pair exchanged with Claude.
21-
type Message struct {
22-
Role string `json:"role"`
23-
Content string `json:"content"`
15+
Model string `json:"model"`
16+
Messages []types.Message `json:"messages"`
17+
MaxTokens int `json:"max_tokens"`
2418
}
2519

2620
// ClaudeResponse captures the subset of fields used from Anthropic responses.
@@ -41,7 +35,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string,
4135
reqBody := ClaudeRequest{
4236
Model: "claude-3-5-sonnet-20241022",
4337
MaxTokens: 200,
44-
Messages: []Message{
38+
Messages: []types.Message{
4539
{
4640
Role: "user",
4741
Content: prompt,

internal/grok/grok.go

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,38 @@ import (
55
"crypto/tls"
66
"encoding/json"
77
"fmt"
8-
"io/ioutil"
8+
"io"
99
"net/http"
10+
"sync"
1011
"time"
1112

1213
"github.com/dfanso/commit-msg/pkg/types"
1314
)
1415

16+
var (
17+
grokClientOnce sync.Once
18+
grokClient *http.Client
19+
)
20+
21+
func getHTTPClient() *http.Client {
22+
grokClientOnce.Do(func() {
23+
transport := &http.Transport{
24+
TLSHandshakeTimeout: 10 * time.Second,
25+
MaxIdleConns: 10,
26+
IdleConnTimeout: 30 * time.Second,
27+
DisableCompression: true,
28+
TLSClientConfig: &tls.Config{
29+
InsecureSkipVerify: false,
30+
},
31+
}
32+
grokClient = &http.Client{
33+
Timeout: 30 * time.Second,
34+
Transport: transport,
35+
}
36+
})
37+
return grokClient
38+
}
39+
1540
// GenerateCommitMessage calls X.AI's Grok API to create a commit message from
1641
// the provided Git diff and generation options.
1742
func GenerateCommitMessage(config *types.Config, changes string, apiKey string, opts *types.GenerationOptions) (string, error) {
@@ -46,22 +71,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string,
4671
req.Header.Set("Content-Type", "application/json")
4772
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
4873

49-
// Configure HTTP client with improved TLS settings
50-
transport := &http.Transport{
51-
TLSHandshakeTimeout: 10 * time.Second,
52-
MaxIdleConns: 10,
53-
IdleConnTimeout: 30 * time.Second,
54-
DisableCompression: true,
55-
// Add TLS config to handle server name mismatch
56-
TLSClientConfig: &tls.Config{
57-
InsecureSkipVerify: false, // Keep this false for security
58-
},
59-
}
60-
61-
client := &http.Client{
62-
Timeout: 30 * time.Second,
63-
Transport: transport,
64-
}
74+
client := getHTTPClient()
6575
resp, err := client.Do(req)
6676
if err != nil {
6777
return "", err
@@ -70,7 +80,7 @@ func GenerateCommitMessage(config *types.Config, changes string, apiKey string,
7080

7181
// Check response status
7282
if resp.StatusCode != http.StatusOK {
73-
bodyBytes, _ := ioutil.ReadAll(resp.Body)
83+
bodyBytes, _ := io.ReadAll(resp.Body)
7484
return "", fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
7585
}
7686

pkg/types/types.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package types
22

3+
// LLMProvider identifies the large language model backend used to author
4+
// commit messages.
35
type LLMProvider string
46

57
const (
@@ -11,35 +13,34 @@ const (
1113
ProviderOllama LLMProvider = "Ollama"
1214
)
1315

16+
// String returns the provider identifier as a plain string.
1417
func (p LLMProvider) String() string {
1518
return string(p)
1619
}
1720

21+
// IsValid reports whether the provider is part of the supported set.
1822
func (p LLMProvider) IsValid() bool {
1923
switch p {
2024
case ProviderOpenAI, ProviderClaude, ProviderGemini, ProviderGrok, ProviderGroq, ProviderOllama:
2125
return true
22-
// LLMProvider identifies the large language model backend used to author
23-
// commit messages.
2426
default:
2527
return false
2628
}
2729
}
2830

31+
// GetSupportedProviders returns all available provider enums.
2932
func GetSupportedProviders() []LLMProvider {
3033
return []LLMProvider{
3134
ProviderOpenAI,
3235
ProviderClaude,
3336
ProviderGemini,
3437
ProviderGrok,
35-
// String returns the string form of the provider identifier.
3638
ProviderGroq,
3739
ProviderOllama,
3840
}
3941
}
4042

41-
// IsValid reports whether the provider is part of the supported set.
42-
43+
// GetSupportedProviderStrings returns the human-friendly names for providers.
4344
func GetSupportedProviderStrings() []string {
4445
providers := GetSupportedProviders()
4546
strings := make([]string, len(providers))
@@ -49,42 +50,39 @@ func GetSupportedProviderStrings() []string {
4950
return strings
5051
}
5152

52-
// GetSupportedProviders returns all available provider enums.
53-
53+
// ParseLLMProvider converts a string into an LLMProvider enum when supported.
5454
func ParseLLMProvider(s string) (LLMProvider, bool) {
5555
provider := LLMProvider(s)
5656
return provider, provider.IsValid()
5757
}
5858

59-
// Configuration structure
59+
// Config stores CLI-level configuration including named repositories.
6060
type Config struct {
6161
GrokAPI string `json:"grok_api"`
6262
Repos map[string]RepoConfig `json:"repos"`
6363
}
6464

65-
// GetSupportedProviderStrings returns the human-friendly names for providers.
66-
67-
// Repository configuration
65+
// RepoConfig tracks metadata for a configured Git repository.
6866
type RepoConfig struct {
6967
Path string `json:"path"`
7068
LastRun string `json:"last_run"`
7169
}
7270

73-
// Grok/X.AI API request structure
71+
// GrokRequest represents a chat completion request sent to X.AI's API.
7472
type GrokRequest struct {
75-
// ParseLLMProvider converts a string into an LLMProvider enum when supported.
7673
Messages []Message `json:"messages"`
7774
Model string `json:"model"`
7875
Stream bool `json:"stream"`
7976
Temperature float64 `json:"temperature"`
8077
}
8178

79+
// Message captures the role/content pairs exchanged with Grok.
8280
type Message struct {
8381
Role string `json:"role"`
8482
Content string `json:"content"`
8583
}
8684

87-
// Grok/X.AI API response structure
85+
// GrokResponse contains the relevant fields parsed from X.AI responses.
8886
type GrokResponse struct {
8987
Message Message `json:"message,omitempty"`
9088
Choices []Choice `json:"choices,omitempty"`
@@ -95,12 +93,14 @@ type GrokResponse struct {
9593
Usage UsageInfo `json:"usage,omitempty"`
9694
}
9795

96+
// Choice details a single response option returned by Grok.
9897
type Choice struct {
9998
Message Message `json:"message"`
10099
Index int `json:"index"`
101100
FinishReason string `json:"finish_reason"`
102101
}
103102

103+
// UsageInfo reports token usage statistics from Grok responses.
104104
type UsageInfo struct {
105105
PromptTokens int `json:"prompt_tokens"`
106106
CompletionTokens int `json:"completion_tokens"`

0 commit comments

Comments
 (0)