
Commit d174cd6

Consolidate openai and dmr stream adapters
dmr is OpenAI compatible, so to minimize drift between the implementations, use the same stream adapter with both.

Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 3ddec4d commit d174cd6

3 files changed: 131 additions & 203 deletions


pkg/model/provider/dmr/adapter.go

Lines changed: 4 additions & 96 deletions
@@ -4,102 +4,10 @@ import (
     "github.com/sashabaranov/go-openai"

     "github.com/docker/cagent/pkg/chat"
-    "github.com/docker/cagent/pkg/tools"
+    "github.com/docker/cagent/pkg/model/provider/oaistream"
 )

-// streamAdapter adapts the OpenAI stream to our interface
-type streamAdapter struct {
-    stream    *openai.ChatCompletionStream
-    toolCalls map[int]string
-}
-
-func newStreamAdapter(stream *openai.ChatCompletionStream) *streamAdapter {
-    return &streamAdapter{
-        stream:    stream,
-        toolCalls: make(map[int]string),
-    }
-}
-
-// Recv gets the next completion chunk
-func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
-    openaiResponse, err := a.stream.Recv()
-    if err != nil {
-        return chat.MessageStreamResponse{}, err
-    }
-
-    // Convert the OpenAI response to our generic format
-    response := chat.MessageStreamResponse{
-        ID:      openaiResponse.ID,
-        Object:  openaiResponse.Object,
-        Created: openaiResponse.Created,
-        Model:   openaiResponse.Model,
-        Choices: make([]chat.MessageStreamChoice, len(openaiResponse.Choices)),
-    }
-
-    if openaiResponse.Usage != nil {
-        response.Usage = &chat.Usage{
-            InputTokens:        openaiResponse.Usage.PromptTokens,
-            OutputTokens:       openaiResponse.Usage.CompletionTokens,
-            CachedOutputTokens: 0,
-        }
-        if openaiResponse.Usage.PromptTokensDetails != nil {
-            response.Usage.CachedInputTokens = openaiResponse.Usage.PromptTokensDetails.CachedTokens
-        }
-
-        response.Choices = append(response.Choices, chat.MessageStreamChoice{
-            FinishReason: chat.FinishReasonStop,
-        })
-    }
-
-    // Convert the choices
-    for i := range openaiResponse.Choices {
-        choice := &openaiResponse.Choices[i]
-        finishReason := chat.FinishReason(choice.FinishReason)
-
-        response.Choices[i] = chat.MessageStreamChoice{
-            Index:        choice.Index,
-            FinishReason: finishReason,
-            Delta: chat.MessageDelta{
-                Role:    choice.Delta.Role,
-                Content: choice.Delta.Content,
-            },
-        }
-
-        // Convert function call if present
-        if choice.Delta.FunctionCall != nil {
-            response.Choices[i].Delta.FunctionCall = &tools.FunctionCall{
-                Name:      choice.Delta.FunctionCall.Name,
-                Arguments: choice.Delta.FunctionCall.Arguments,
-            }
-        }
-
-        // Convert tool calls if present
-        if len(choice.Delta.ToolCalls) > 0 {
-            response.Choices[i].Delta.ToolCalls = make([]tools.ToolCall, len(choice.Delta.ToolCalls))
-            for j, toolCall := range choice.Delta.ToolCalls {
-                id := toolCall.ID
-                if existing, ok := a.toolCalls[*toolCall.Index]; ok {
-                    id = existing
-                } else {
-                    a.toolCalls[*toolCall.Index] = id
-                }
-
-                response.Choices[i].Delta.ToolCalls[j] = tools.ToolCall{
-                    ID:   id,
-                    Type: tools.ToolType(toolCall.Type),
-                    Function: tools.FunctionCall{
-                        Name:      toolCall.Function.Name,
-                        Arguments: toolCall.Function.Arguments,
-                    },
-                }
-            }
-        }
-    }
-
-    return response, nil
-}
-
-// Close closes the stream
-func (a *streamAdapter) Close() {
-    a.stream.Close()
+// newStreamAdapter returns the shared OpenAI stream adapter implementation
+func newStreamAdapter(stream *openai.ChatCompletionStream) chat.MessageStream {
+    return oaistream.NewStreamAdapter(stream)
 }
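Note the signature change: newStreamAdapter now returns chat.MessageStream rather than the package-private *streamAdapter. That interface is not part of this diff; judging only from the methods the adapters implement, it is presumably something along these lines (a sketch inferred from usage, not the actual definition in pkg/chat):

// Sketch only: inferred from how the adapters in this commit are used;
// the real definition lives in github.com/docker/cagent/pkg/chat.
type MessageStream interface {
    // Recv returns the next converted chunk; the underlying go-openai
    // stream reports end-of-stream as io.EOF.
    Recv() (MessageStreamResponse, error)
    // Close releases the underlying stream.
    Close()
}

Returning the interface lets both providers hand out the one shared implementation without re-exporting a concrete type of their own.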
pkg/model/provider/oaistream/adapter.go

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+package oaistream
+
+/*
+This is a shared adapter for OpenAI-compatible streams.
+*/
+
+import (
+    "github.com/sashabaranov/go-openai"
+
+    "github.com/docker/cagent/pkg/chat"
+    "github.com/docker/cagent/pkg/tools"
+)
+
+// StreamAdapter adapts the OpenAI stream to our interface
+type StreamAdapter struct {
+    stream           *openai.ChatCompletionStream
+    lastFinishReason chat.FinishReason
+    toolCalls        map[int]string
+}
+
+func NewStreamAdapter(stream *openai.ChatCompletionStream) *StreamAdapter {
+    return &StreamAdapter{
+        stream:    stream,
+        toolCalls: make(map[int]string),
+    }
+}
+
+// Recv gets the next completion chunk
+func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
+    openaiResponse, err := a.stream.Recv()
+    if err != nil {
+        return chat.MessageStreamResponse{}, err
+    }
+
+    // Convert the OpenAI response to our generic format
+    response := chat.MessageStreamResponse{
+        ID:      openaiResponse.ID,
+        Object:  openaiResponse.Object,
+        Created: openaiResponse.Created,
+        Model:   openaiResponse.Model,
+        Choices: make([]chat.MessageStreamChoice, len(openaiResponse.Choices)),
+    }
+
+    if openaiResponse.Usage != nil {
+        response.Usage = &chat.Usage{
+            InputTokens:        openaiResponse.Usage.PromptTokens,
+            OutputTokens:       openaiResponse.Usage.CompletionTokens,
+            CachedInputTokens:  0,
+            CachedOutputTokens: 0,
+        }
+        if openaiResponse.Usage.PromptTokensDetails != nil {
+            response.Usage.CachedInputTokens = openaiResponse.Usage.PromptTokensDetails.CachedTokens
+        }
+        // Use the tracked finish reason instead of hardcoding stop
+        finishReason := a.lastFinishReason
+        if finishReason == "" {
+            finishReason = chat.FinishReasonStop
+        }
+        response.Choices = append(response.Choices, chat.MessageStreamChoice{
+            FinishReason: finishReason,
+        })
+    }
+
+    // Convert the choices
+    for i := range openaiResponse.Choices {
+        choice := &openaiResponse.Choices[i]
+        if choice.FinishReason == openai.FinishReasonStop {
+            choice.FinishReason = openai.FinishReasonNull
+        }
+
+        finishReason := chat.FinishReason(choice.FinishReason)
+        // Track the finish reason for when we get usage info
+        if finishReason != chat.FinishReasonNull && finishReason != "" {
+            a.lastFinishReason = finishReason
+        }
+
+        response.Choices[i] = chat.MessageStreamChoice{
+            Index:        choice.Index,
+            FinishReason: finishReason,
+            Delta: chat.MessageDelta{
+                Role:    choice.Delta.Role,
+                Content: choice.Delta.Content,
+            },
+        }
+
+        // Convert function call if present
+        if choice.Delta.FunctionCall != nil {
+            response.Choices[i].Delta.FunctionCall = &tools.FunctionCall{
+                Name:      choice.Delta.FunctionCall.Name,
+                Arguments: choice.Delta.FunctionCall.Arguments,
+            }
+        }
+
+        // Convert tool calls if present
+        if len(choice.Delta.ToolCalls) > 0 {
+            response.Choices[i].Delta.ToolCalls = make([]tools.ToolCall, len(choice.Delta.ToolCalls))
+            for j, toolCall := range choice.Delta.ToolCalls {
+                id := toolCall.ID
+                if existing, ok := a.toolCalls[*toolCall.Index]; ok {
+                    id = existing
+                } else {
+                    a.toolCalls[*toolCall.Index] = id
+                }
+
+                response.Choices[i].Delta.ToolCalls[j] = tools.ToolCall{
+                    ID:   id,
+                    Type: tools.ToolType(toolCall.Type),
+                    Function: tools.FunctionCall{
+                        Name:      toolCall.Function.Name,
+                        Arguments: toolCall.Function.Arguments,
+                    },
+                }
+            }
+        }
+    }
+
+    return response, nil
+}
+
+// Close closes the stream
+func (a *StreamAdapter) Close() {
+    a.stream.Close()
+}
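For illustration, here is how a provider might wire this shared adapter to a go-openai stream. This is a hypothetical sketch: the base URL, model name, and prompt are placeholders rather than values from this commit; only oaistream.NewStreamAdapter and the go-openai calls are APIs shown or implied by the diff.

package main

import (
    "context"
    "errors"
    "fmt"
    "io"

    "github.com/sashabaranov/go-openai"

    "github.com/docker/cagent/pkg/model/provider/oaistream"
)

func main() {
    // Works against any OpenAI-compatible endpoint; a DMR-style provider
    // only needs to point the base URL at its local server (placeholder URL).
    cfg := openai.DefaultConfig("api-key")
    cfg.BaseURL = "http://localhost:12434/engines/v1"
    client := openai.NewClientWithConfig(cfg)

    stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
        Model:         "some-model", // placeholder
        Messages:      []openai.ChatCompletionMessage{{Role: openai.ChatMessageRoleUser, Content: "Hello"}},
        StreamOptions: &openai.StreamOptions{IncludeUsage: true}, // requests the trailing usage chunk
    })
    if err != nil {
        panic(err)
    }

    // Wrap the provider-specific stream in the shared adapter.
    adapter := oaistream.NewStreamAdapter(stream)
    defer adapter.Close()

    for {
        resp, err := adapter.Recv()
        if errors.Is(err, io.EOF) {
            break // stream finished
        }
        if err != nil {
            panic(err)
        }
        for _, choice := range resp.Choices {
            fmt.Print(choice.Delta.Content)
        }
    }
}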

pkg/model/provider/openai/adapter.go

Lines changed: 4 additions & 107 deletions
@@ -4,113 +4,10 @@ import (
     "github.com/sashabaranov/go-openai"

     "github.com/docker/cagent/pkg/chat"
-    "github.com/docker/cagent/pkg/tools"
+    "github.com/docker/cagent/pkg/model/provider/oaistream"
 )

-// streamAdapter adapts the OpenAI stream to our interface
-type streamAdapter struct {
-    stream           *openai.ChatCompletionStream
-    lastFinishReason chat.FinishReason
-    toolCalls        map[int]string
-}
-
-func newStreamAdapter(stream *openai.ChatCompletionStream) *streamAdapter {
-    return &streamAdapter{
-        stream:    stream,
-        toolCalls: make(map[int]string),
-    }
-}
-
-// Recv gets the next completion chunk
-func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
-    openaiResponse, err := a.stream.Recv()
-    if err != nil {
-        return chat.MessageStreamResponse{}, err
-    }
-
-    // Convert the OpenAI response to our generic format
-    response := chat.MessageStreamResponse{
-        ID:      openaiResponse.ID,
-        Object:  openaiResponse.Object,
-        Created: openaiResponse.Created,
-        Model:   openaiResponse.Model,
-        Choices: make([]chat.MessageStreamChoice, len(openaiResponse.Choices)),
-    }
-
-    if openaiResponse.Usage != nil {
-        response.Usage = &chat.Usage{
-            InputTokens:        openaiResponse.Usage.PromptTokens,
-            OutputTokens:       openaiResponse.Usage.CompletionTokens,
-            CachedInputTokens:  openaiResponse.Usage.PromptTokensDetails.CachedTokens,
-            CachedOutputTokens: 0,
-        }
-        // Use the tracked finish reason instead of hardcoding stop
-        finishReason := a.lastFinishReason
-        if finishReason == "" {
-            finishReason = chat.FinishReasonStop
-        }
-        response.Choices = append(response.Choices, chat.MessageStreamChoice{
-            FinishReason: finishReason,
-        })
-    }
-
-    // Convert the choices
-    for i := range openaiResponse.Choices {
-        choice := &openaiResponse.Choices[i]
-        if choice.FinishReason == openai.FinishReasonStop {
-            choice.FinishReason = openai.FinishReasonNull
-        }
-
-        finishReason := chat.FinishReason(choice.FinishReason)
-        // Track the finish reason for when we get usage info
-        if finishReason != chat.FinishReasonNull && finishReason != "" {
-            a.lastFinishReason = finishReason
-        }
-
-        response.Choices[i] = chat.MessageStreamChoice{
-            Index:        choice.Index,
-            FinishReason: finishReason,
-            Delta: chat.MessageDelta{
-                Role:    choice.Delta.Role,
-                Content: choice.Delta.Content,
-            },
-        }
-
-        // Convert function call if present
-        if choice.Delta.FunctionCall != nil {
-            response.Choices[i].Delta.FunctionCall = &tools.FunctionCall{
-                Name:      choice.Delta.FunctionCall.Name,
-                Arguments: choice.Delta.FunctionCall.Arguments,
-            }
-        }
-
-        // Convert tool calls if present
-        if len(choice.Delta.ToolCalls) > 0 {
-            response.Choices[i].Delta.ToolCalls = make([]tools.ToolCall, len(choice.Delta.ToolCalls))
-            for j, toolCall := range choice.Delta.ToolCalls {
-                id := toolCall.ID
-                if existing, ok := a.toolCalls[*toolCall.Index]; ok {
-                    id = existing
-                } else {
-                    a.toolCalls[*toolCall.Index] = id
-                }
-
-                response.Choices[i].Delta.ToolCalls[j] = tools.ToolCall{
-                    ID:   id,
-                    Type: tools.ToolType(toolCall.Type),
-                    Function: tools.FunctionCall{
-                        Name:      toolCall.Function.Name,
-                        Arguments: toolCall.Function.Arguments,
-                    },
-                }
-            }
-        }
-    }
-
-    return response, nil
-}
-
-// Close closes the stream
-func (a *streamAdapter) Close() {
-    a.stream.Close()
+// newStreamAdapter returns the shared OpenAI stream adapter implementation
+func newStreamAdapter(stream *openai.ChatCompletionStream) chat.MessageStream {
+    return oaistream.NewStreamAdapter(stream)
 }
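The lastFinishReason bookkeeping deleted here is the same logic that moved into the shared adapter above. It exists because, with usage reporting enabled, OpenAI-compatible servers send token counts in a trailing chunk whose choices list is empty: the adapter suppresses the stop finish reason on the chunk that carries it, and re-emits a finish reason on the usage-only chunk instead. A standalone sketch of that behavior, using simplified stand-in types rather than the package's real ones:

package main

import "fmt"

// chunk is a simplified stand-in for a streamed response, for illustration.
type chunk struct {
    content      string
    finishReason string // "" plays the role of finish_reason: null
    hasUsage     bool   // true for the trailing usage-only chunk
}

func main() {
    // Hypothetical chunk sequence with stream usage reporting enabled.
    chunks := []chunk{
        {content: "Hel"},
        {content: "lo", finishReason: "stop"},
        {hasUsage: true}, // no choices, only token counts
    }

    var lastFinishReason string
    for _, c := range chunks {
        if c.hasUsage {
            // Usage-only chunk: replay the remembered reason, defaulting to stop.
            final := lastFinishReason
            if final == "" {
                final = "stop"
            }
            fmt.Printf("usage chunk: finish_reason=%s\n", final)
            continue
        }
        fr := c.finishReason
        if fr == "stop" {
            fr = "" // suppressed here; surfaced on the usage chunk instead
        } else if fr != "" {
            lastFinishReason = fr // e.g. "tool_calls" or "length" pass through and are remembered
        }
        fmt.Printf("delta %q finish_reason=%q\n", c.content, fr)
    }
}

This keeps a single terminal choice, carrying both the finish reason and the usage numbers, as the last thing downstream consumers see.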
