Skip to content

Commit e4cbc97

Browse files
authored
Merge pull request #1647 from dgageot/fix-429-openai
Preserve 429 error details on OpenAI
2 parents 48d3846 + cd8bcb1 commit e4cbc97

3 files changed

Lines changed: 301 additions & 0 deletions

File tree

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package oaistream
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"net/http"
8+
9+
"github.com/openai/openai-go/v3/option"
10+
)
11+
12+
// ErrorBodyMiddleware returns an OpenAI SDK middleware that preserves full
13+
// error details in HTTP error responses.
14+
//
15+
// The OpenAI SDK extracts only the "error" field from error response bodies
16+
// (via gjson). When a provider returns a body without an "error" object
17+
// (e.g. a string "error" field, plain text, or a different JSON structure),
18+
// the details are silently lost. This middleware rewrites such responses into
19+
// {"error": <original body>} so the SDK preserves the full content.
20+
func ErrorBodyMiddleware() option.Middleware {
21+
return func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
22+
resp, err := next(req)
23+
if err != nil || resp == nil || resp.StatusCode < 400 {
24+
return resp, err
25+
}
26+
27+
body, err := io.ReadAll(resp.Body)
28+
resp.Body.Close()
29+
if err != nil || hasErrorObject(body) {
30+
resp.Body = io.NopCloser(bytes.NewReader(body))
31+
return resp, nil
32+
}
33+
34+
wrapped := wrapErrorBody(body, resp.StatusCode)
35+
resp.Body = io.NopCloser(bytes.NewReader(wrapped))
36+
resp.ContentLength = int64(len(wrapped))
37+
return resp, nil
38+
}
39+
}
40+
41+
// hasErrorObject reports whether body is a JSON object with an "error" key
42+
// whose value is itself a JSON object — the format the OpenAI SDK expects.
43+
func hasErrorObject(body []byte) bool {
44+
var raw map[string]json.RawMessage
45+
if err := json.Unmarshal(body, &raw); err != nil {
46+
return false
47+
}
48+
errVal, ok := raw["error"]
49+
if !ok {
50+
return false
51+
}
52+
return len(bytes.TrimLeft(errVal, " \t\n\r")) > 0 && bytes.TrimLeft(errVal, " \t\n\r")[0] == '{'
53+
}
54+
55+
// wrapErrorBody produces {"error": <body>} when body is valid JSON, or
56+
// {"error": {"message": "<body>"}} otherwise, so the SDK's gjson extraction
57+
// always finds useful content.
58+
func wrapErrorBody(body []byte, statusCode int) []byte {
59+
if len(body) == 0 {
60+
body = []byte(http.StatusText(statusCode))
61+
}
62+
if json.Valid(body) {
63+
return append(append([]byte(`{"error":`), body...), '}')
64+
}
65+
wrapped, err := json.Marshal(map[string]any{
66+
"error": map[string]any{"message": string(body)},
67+
})
68+
if err != nil {
69+
return body
70+
}
71+
return wrapped
72+
}
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
package oaistream
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
"net/http"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestHasErrorObject(t *testing.T) {
14+
t.Parallel()
15+
16+
tests := []struct {
17+
name string
18+
body string
19+
want bool
20+
}{
21+
{"standard openai error", `{"error":{"message":"rate limit","type":"tokens"}}`, true},
22+
{"error is a string", `{"error":"Rate limit exceeded","retryAfterMs":30000}`, false},
23+
{"error field null", `{"error":null}`, false},
24+
{"error is a number", `{"error":429}`, false},
25+
{"error is a boolean", `{"error":true}`, false},
26+
{"error is an array", `{"error":["rate limit"]}`, false},
27+
{"no error field", `{"message":"rate limit exceeded"}`, false},
28+
{"plain text", `rate limit exceeded`, false},
29+
{"empty body", ``, false},
30+
{"empty object", `{}`, false},
31+
}
32+
33+
for _, tt := range tests {
34+
t.Run(tt.name, func(t *testing.T) {
35+
t.Parallel()
36+
got := hasErrorObject([]byte(tt.body))
37+
assert.Equal(t, tt.want, got)
38+
})
39+
}
40+
}
41+
42+
func TestWrapErrorBody(t *testing.T) {
43+
t.Parallel()
44+
45+
t.Run("json body is preserved verbatim as error value", func(t *testing.T) {
46+
t.Parallel()
47+
body := `{"consumed":24000000000000,"error":"Rate limit exceeded","limit":50000000000000}`
48+
wrapped := wrapErrorBody([]byte(body), http.StatusTooManyRequests)
49+
assert.JSONEq(t, `{"error":`+body+`}`, string(wrapped))
50+
})
51+
52+
t.Run("json without error field", func(t *testing.T) {
53+
t.Parallel()
54+
body := `{"message":"quota exceeded","retry_after":30}`
55+
wrapped := wrapErrorBody([]byte(body), http.StatusTooManyRequests)
56+
assert.JSONEq(t, `{"error":`+body+`}`, string(wrapped))
57+
})
58+
59+
t.Run("plain text body wrapped as message", func(t *testing.T) {
60+
t.Parallel()
61+
wrapped := wrapErrorBody([]byte("rate limit exceeded"), http.StatusTooManyRequests)
62+
var parsed struct {
63+
Error struct {
64+
Message string `json:"message"`
65+
} `json:"error"`
66+
}
67+
require.NoError(t, json.Unmarshal(wrapped, &parsed))
68+
assert.Equal(t, "rate limit exceeded", parsed.Error.Message)
69+
})
70+
71+
t.Run("empty body uses status text", func(t *testing.T) {
72+
t.Parallel()
73+
wrapped := wrapErrorBody(nil, http.StatusTooManyRequests)
74+
// "Too Many Requests" is not valid JSON, so it gets message-wrapped
75+
var parsed struct {
76+
Error struct {
77+
Message string `json:"message"`
78+
} `json:"error"`
79+
}
80+
require.NoError(t, json.Unmarshal(wrapped, &parsed))
81+
assert.Equal(t, "Too Many Requests", parsed.Error.Message)
82+
})
83+
}
84+
85+
func TestErrorBodyMiddleware(t *testing.T) {
86+
t.Parallel()
87+
88+
middleware := ErrorBodyMiddleware()
89+
90+
t.Run("passes through successful responses", func(t *testing.T) {
91+
t.Parallel()
92+
93+
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://example.com/v1/chat/completions", http.NoBody)
94+
require.NoError(t, err)
95+
96+
next := func(_ *http.Request) (*http.Response, error) {
97+
return &http.Response{
98+
StatusCode: http.StatusOK,
99+
Body: io.NopCloser(nil),
100+
}, nil
101+
}
102+
103+
resp, err := middleware(req, next)
104+
require.NoError(t, err)
105+
assert.Equal(t, http.StatusOK, resp.StatusCode)
106+
})
107+
108+
t.Run("passes through standard OpenAI errors", func(t *testing.T) {
109+
t.Parallel()
110+
111+
originalBody := `{"error":{"message":"Rate limit exceeded","type":"rate_limit","code":"rate_limit_exceeded"}}`
112+
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://example.com/v1/chat/completions", http.NoBody)
113+
require.NoError(t, err)
114+
115+
next := func(_ *http.Request) (*http.Response, error) {
116+
return &http.Response{
117+
StatusCode: http.StatusTooManyRequests,
118+
Body: io.NopCloser(newStringReader(originalBody)),
119+
}, nil
120+
}
121+
122+
resp, err := middleware(req, next)
123+
require.NoError(t, err)
124+
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
125+
126+
body, err := io.ReadAll(resp.Body)
127+
require.NoError(t, err)
128+
assert.JSONEq(t, originalBody, string(body))
129+
})
130+
131+
t.Run("wraps non-standard error bodies", func(t *testing.T) {
132+
t.Parallel()
133+
134+
originalBody := `{"message":"You have exceeded your rate limit","retry_after":30}`
135+
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://example.com/v1/chat/completions", http.NoBody)
136+
require.NoError(t, err)
137+
138+
next := func(_ *http.Request) (*http.Response, error) {
139+
return &http.Response{
140+
StatusCode: http.StatusTooManyRequests,
141+
Body: io.NopCloser(newStringReader(originalBody)),
142+
}, nil
143+
}
144+
145+
resp, err := middleware(req, next)
146+
require.NoError(t, err)
147+
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
148+
149+
body, err := io.ReadAll(resp.Body)
150+
require.NoError(t, err)
151+
assert.JSONEq(t, `{"error":`+originalBody+`}`, string(body))
152+
})
153+
154+
t.Run("wraps error bodies where error is a string not an object", func(t *testing.T) {
155+
t.Parallel()
156+
157+
originalBody := `{"consumed":24000000000000,"error":"Rate limit exceeded","limit":50000000000000,"remaining":-621927204133070,"resetTime":1803944783,"retryAfterMs":33315396472,"retryAfterSeconds":33315396}`
158+
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://example.com/v1/chat/completions", http.NoBody)
159+
require.NoError(t, err)
160+
161+
next := func(_ *http.Request) (*http.Response, error) {
162+
return &http.Response{
163+
StatusCode: http.StatusTooManyRequests,
164+
Body: io.NopCloser(newStringReader(originalBody)),
165+
}, nil
166+
}
167+
168+
resp, err := middleware(req, next)
169+
require.NoError(t, err)
170+
assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
171+
172+
body, err := io.ReadAll(resp.Body)
173+
require.NoError(t, err)
174+
// The full original body is placed as the "error" value
175+
assert.JSONEq(t, `{"error":`+originalBody+`}`, string(body))
176+
})
177+
178+
t.Run("wraps plain text error bodies", func(t *testing.T) {
179+
t.Parallel()
180+
181+
originalBody := "Service temporarily unavailable"
182+
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://example.com/v1/chat/completions", http.NoBody)
183+
require.NoError(t, err)
184+
185+
next := func(_ *http.Request) (*http.Response, error) {
186+
return &http.Response{
187+
StatusCode: http.StatusServiceUnavailable,
188+
Body: io.NopCloser(newStringReader(originalBody)),
189+
}, nil
190+
}
191+
192+
resp, err := middleware(req, next)
193+
require.NoError(t, err)
194+
195+
body, err := io.ReadAll(resp.Body)
196+
require.NoError(t, err)
197+
198+
var parsed struct {
199+
Error struct {
200+
Message string `json:"message"`
201+
} `json:"error"`
202+
}
203+
require.NoError(t, json.Unmarshal(body, &parsed))
204+
assert.Equal(t, originalBody, parsed.Error.Message)
205+
})
206+
}
207+
208+
func newStringReader(s string) io.Reader {
209+
return io.NopCloser(newBytesReader([]byte(s)))
210+
}
211+
212+
func newBytesReader(b []byte) io.Reader {
213+
return &bytesReader{data: b}
214+
}
215+
216+
type bytesReader struct {
217+
data []byte
218+
pos int
219+
}
220+
221+
func (r *bytesReader) Read(p []byte) (n int, err error) {
222+
if r.pos >= len(r.data) {
223+
return 0, io.EOF
224+
}
225+
n = copy(p, r.data[r.pos:])
226+
r.pos += n
227+
return n, nil
228+
}

pkg/model/provider/openai/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
128128
option.WithAPIKey(authToken),
129129
option.WithBaseURL(baseURL),
130130
option.WithHTTPClient(httpclient.NewHTTPClient(httpOptions...)),
131+
option.WithMiddleware(oaistream.ErrorBodyMiddleware()),
131132
)
132133

133134
return &client, nil

0 commit comments

Comments
 (0)