Skip to content

Commit 7eb41ff

Browse files
authored
Merge pull request #2018 from dgageot/board/unify-retry-logic-d5b3914d
Unify streamAdapter/betaStreamAdapter retry logic into generic retryableStream
2 parents 050398e + fbf4e10 commit 7eb41ff

5 files changed

Lines changed: 69 additions & 58 deletions

File tree

pkg/model/provider/anthropic/adapter.go

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7-
"io"
87
"net/http"
98
"strconv"
109
"strings"
@@ -19,19 +18,16 @@ import (
1918

2019
// streamAdapter adapts the Anthropic stream to our interface
2120
type streamAdapter struct {
22-
stream *ssestream.Stream[anthropic.MessageStreamEventUnion]
23-
trackUsage bool
24-
toolCall bool
25-
toolID string
26-
// For single retry on context length error
27-
retryFn func() *streamAdapter
28-
retried bool
21+
retryableStream[anthropic.MessageStreamEventUnion]
22+
trackUsage bool
23+
toolCall bool
24+
toolID string
2925
getResponseTrailer func() http.Header
3026
}
3127

3228
func (c *Client) newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter {
3329
return &streamAdapter{
34-
stream: stream,
30+
retryableStream: retryableStream[anthropic.MessageStreamEventUnion]{stream: stream},
3531
trackUsage: trackUsage,
3632
getResponseTrailer: c.getResponseTrailer,
3733
}
@@ -72,21 +68,9 @@ func isContextLengthError(err error) bool {
7268

7369
// Recv gets the next completion chunk
7470
func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
75-
if !a.stream.Next() {
76-
err := a.stream.Err()
77-
// Single retry on context length error
78-
if err != nil && !a.retried && a.retryFn != nil && isContextLengthError(err) {
79-
a.retried = true
80-
if retry := a.retryFn(); retry != nil {
81-
a.stream.Close()
82-
a.stream = retry.stream
83-
return a.Recv()
84-
}
85-
}
86-
if err != nil {
87-
return chat.MessageStreamResponse{}, err
88-
}
89-
return chat.MessageStreamResponse{}, io.EOF
71+
ok, err := a.next()
72+
if !ok {
73+
return chat.MessageStreamResponse{}, err
9074
}
9175

9276
event := a.stream.Current()
@@ -192,7 +176,5 @@ func parseHeaderInt64(headerValue string) int64 {
192176

193177
// Close closes the stream
194178
func (a *streamAdapter) Close() {
195-
if a.stream != nil {
196-
a.stream.Close()
197-
}
179+
a.stream.Close()
198180
}

pkg/model/provider/anthropic/beta_adapter.go

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package anthropic
22

33
import (
44
"fmt"
5-
"io"
65
"log/slog"
76
"net/http"
87

@@ -15,42 +14,27 @@ import (
1514

1615
// betaStreamAdapter adapts the Anthropic Beta stream to our interface
1716
type betaStreamAdapter struct {
18-
stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion]
19-
trackUsage bool
20-
toolCall bool
21-
toolID string
22-
// For single retry on context length error
23-
retryFn func() *betaStreamAdapter
24-
retried bool
17+
retryableStream[anthropic.BetaRawMessageStreamEventUnion]
18+
trackUsage bool
19+
toolCall bool
20+
toolID string
2521
getResponseTrailer func() http.Header
2622
}
2723

2824
// newBetaStreamAdapter creates a new Beta stream adapter
2925
func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion], trackUsage bool) *betaStreamAdapter {
3026
return &betaStreamAdapter{
31-
stream: stream,
27+
retryableStream: retryableStream[anthropic.BetaRawMessageStreamEventUnion]{stream: stream},
3228
trackUsage: trackUsage,
3329
getResponseTrailer: c.getResponseTrailer,
3430
}
3531
}
3632

3733
// Recv gets the next completion chunk from the Beta stream
3834
func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
39-
if !a.stream.Next() {
40-
err := a.stream.Err()
41-
// Single retry on context length error
42-
if err != nil && !a.retried && a.retryFn != nil && isContextLengthError(err) {
43-
a.retried = true
44-
if retry := a.retryFn(); retry != nil {
45-
a.stream.Close()
46-
a.stream = retry.stream
47-
return a.Recv()
48-
}
49-
}
50-
if err != nil {
51-
return chat.MessageStreamResponse{}, err
52-
}
53-
return chat.MessageStreamResponse{}, io.EOF
35+
ok, err := a.next()
36+
if !ok {
37+
return chat.MessageStreamResponse{}, err
5438
}
5539

5640
event := a.stream.Current()
@@ -137,7 +121,5 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
137121

138122
// Close closes the Beta stream
139123
func (a *betaStreamAdapter) Close() {
140-
if a.stream != nil {
141-
a.stream.Close()
142-
}
124+
a.stream.Close()
143125
}

pkg/model/provider/anthropic/beta_client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func (c *Client) createBetaStream(
128128
ad := c.newBetaStreamAdapter(stream, trackUsage)
129129

130130
// Set up single retry for context length errors
131-
ad.retryFn = func() *betaStreamAdapter {
131+
ad.retryFn = func() *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion] {
132132
used, err := countAnthropicTokensBeta(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools)
133133
if err != nil {
134134
slog.Warn("Failed to count tokens for retry, skipping", "error", err)
@@ -142,7 +142,7 @@ func (c *Client) createBetaStream(
142142
slog.Warn("Retrying with clamped max_tokens after context length error", "original", maxTokens, "clamped", newMaxTokens, "used", used)
143143
retryParams := params
144144
retryParams.MaxTokens = newMaxTokens
145-
return c.newBetaStreamAdapter(client.Beta.Messages.NewStreaming(ctx, retryParams), trackUsage)
145+
return client.Beta.Messages.NewStreaming(ctx, retryParams)
146146
}
147147

148148
slog.Debug("Anthropic Beta API chat completion stream created successfully", "model", c.ModelConfig.Model)

pkg/model/provider/anthropic/client.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/anthropics/anthropic-sdk-go"
1616
"github.com/anthropics/anthropic-sdk-go/option"
1717
"github.com/anthropics/anthropic-sdk-go/packages/param"
18+
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
1819

1920
"github.com/docker/docker-agent/pkg/chat"
2021
"github.com/docker/docker-agent/pkg/config/latest"
@@ -349,7 +350,7 @@ func (c *Client) CreateChatCompletionStream(
349350
ad := c.newStreamAdapter(stream, trackUsage)
350351

351352
// Set up single retry for context length errors
352-
ad.retryFn = func() *streamAdapter {
353+
ad.retryFn = func() *ssestream.Stream[anthropic.MessageStreamEventUnion] {
353354
used, err := countAnthropicTokens(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools)
354355
if err != nil {
355356
slog.Warn("Failed to count tokens for retry, skipping", "error", err)
@@ -363,7 +364,7 @@ func (c *Client) CreateChatCompletionStream(
363364
slog.Warn("Retrying with clamped max_tokens after context length error", "original max_tokens", maxTokens, "clamped max_tokens", newMaxTokens, "used tokens", used)
364365
retryParams := params
365366
retryParams.MaxTokens = newMaxTokens
366-
return c.newStreamAdapter(client.Messages.NewStreaming(ctx, retryParams, betaHeader), trackUsage)
367+
return client.Messages.NewStreaming(ctx, retryParams, betaHeader)
367368
}
368369

369370
slog.Debug("Anthropic chat completion stream created successfully", "model", c.ModelConfig.Model)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package anthropic
2+
3+
import (
4+
"io"
5+
6+
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
7+
)
8+
9+
// retryableStream wraps an ssestream.Stream and adds a single-retry mechanism
10+
// for context length errors. Both the standard and Beta stream adapters embed
11+
// this to share the retry logic.
12+
type retryableStream[T any] struct {
13+
stream *ssestream.Stream[T]
14+
// retryFn, when non-nil, is called once on a context-length error.
15+
// It should return a new stream to use, or nil to skip retrying.
16+
retryFn func() *ssestream.Stream[T]
17+
retried bool
18+
}
19+
20+
// next moves the stream forward. If the stream is exhausted it returns
21+
// (false, io.EOF). If it encounters an error it attempts a single retry when
22+
// the error is a context-length error and a retryFn is configured.
23+
// On success it returns (true, nil).
24+
func (r *retryableStream[T]) next() (bool, error) {
25+
if r.stream.Next() {
26+
return true, nil
27+
}
28+
29+
err := r.stream.Err()
30+
if err != nil && !r.retried && r.retryFn != nil && isContextLengthError(err) {
31+
r.retried = true
32+
if newStream := r.retryFn(); newStream != nil {
33+
r.stream.Close()
34+
r.stream = newStream
35+
ok, err := r.next()
36+
if !ok && err != nil {
37+
r.stream.Close() // Clean up on retry failure
38+
}
39+
return ok, err
40+
}
41+
}
42+
if err != nil {
43+
return false, err
44+
}
45+
return false, io.EOF
46+
}

0 commit comments

Comments
 (0)