Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions client/adaptive_ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -211,10 +212,11 @@ type tpmEntry struct {
var _ Provider = (*AdaptiveRateLimitProvider)(nil)

// NewAdaptiveRateLimitProvider wraps inner with adaptive rate limiting.
// inner must not be nil. config may be zero-valued for sensible defaults.
func NewAdaptiveRateLimitProvider(inner Provider, config AdaptiveRateLimitConfig) *AdaptiveRateLimitProvider {
// inner must not be nil (an error is returned otherwise). config may be
// zero-valued for sensible defaults.
func NewAdaptiveRateLimitProvider(inner Provider, config AdaptiveRateLimitConfig) (*AdaptiveRateLimitProvider, error) {
if inner == nil {
panic("eyrie: NewAdaptiveRateLimitProvider inner provider must not be nil")
return nil, errors.New("eyrie: NewAdaptiveRateLimitProvider inner provider must not be nil")
}
if config.ThresholdPercent <= 0 {
config.ThresholdPercent = 10
Expand All @@ -233,7 +235,7 @@ func NewAdaptiveRateLimitProvider(inner Provider, config AdaptiveRateLimitConfig
tpmWindow: make([]tpmEntry, 0, 64),
tpmRemaining: -1,
tpmLimit: tpmLimit,
}
}, nil
}

// Name returns the inner provider name suffixed with "/adaptive-ratelimit".
Expand Down
61 changes: 36 additions & 25 deletions client/adaptive_ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,23 @@ func (m *mockProvider) StreamChat(ctx context.Context, messages []EyrieMessage,

func TestAdaptiveRateLimitProvider_Name(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{})
if p.Name() != "test/adaptive-ratelimit" {
t.Errorf("expected name 'test/adaptive-ratelimit', got %q", p.Name())
}
}

func TestAdaptiveRateLimitProvider_Ping(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{})
if err := p.Ping(context.Background()); err != nil {
t.Errorf("expected no error, got %v", err)
}
}

func TestAdaptiveRateLimitProvider_Chat(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{})

resp, err := p.Chat(context.Background(), nil, ChatOptions{})
if err != nil {
Expand All @@ -80,7 +80,7 @@ func TestAdaptiveRateLimitProvider_Chat(t *testing.T) {

func TestAdaptiveRateLimitProvider_ChatTracksUsage(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{})

// Make a few calls
for i := 0; i < 3; i++ {
Expand All @@ -105,7 +105,7 @@ func TestAdaptiveRateLimitProvider_ChatTracksUsage(t *testing.T) {
func TestAdaptiveRateLimitProvider_NearLimitThrottle(t *testing.T) {
// Set up a provider with a very low RPM limit (5 RPM)
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
RPMLimit: 5,
ThresholdPercent: 50, // throttle when <50% remaining (i.e., <3 remaining)
MaxDelay: 50 * time.Millisecond,
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestAdaptiveRateLimitProvider_NearLimitThrottle(t *testing.T) {
func TestAdaptiveRateLimitProvider_RPMExceeded(t *testing.T) {
// Set up with limit of 3, no delay (just error)
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
RPMLimit: 3,
ThresholdPercent: 10,
MaxDelay: 0, // don't delay, just error
Expand Down Expand Up @@ -173,7 +173,7 @@ func TestAdaptiveRateLimitProvider_RPMExceeded(t *testing.T) {

func TestAdaptiveRateLimitProvider_TPMTracking(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
TPMLimit: 500,
ThresholdPercent: 10,
})
Expand All @@ -194,7 +194,7 @@ func TestAdaptiveRateLimitProvider_TPMTracking(t *testing.T) {

func TestAdaptiveRateLimitProvider_StreamChat(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{})

result, err := p.StreamChat(context.Background(), nil, ChatOptions{})
if err != nil {
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestAdaptiveRateLimitProvider_StreamChat(t *testing.T) {

func TestAdaptiveRateLimitProvider_UpdateFromHeaders(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
HeaderExtractor: CommonHeaderExtractor,
})

Expand Down Expand Up @@ -253,7 +253,7 @@ func TestAdaptiveRateLimitProvider_UpdateFromHeaders(t *testing.T) {

func TestAdaptiveRateLimitProvider_AnthropicHeaders(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
HeaderExtractor: CommonHeaderExtractor,
})

Expand All @@ -279,7 +279,7 @@ func TestAdaptiveRateLimitProvider_AnthropicHeaders(t *testing.T) {

func TestAdaptiveRateLimitProvider_HeaderDrivenThrottle(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
RPMLimit: 100,
ThresholdPercent: 10,
MaxDelay: 50 * time.Millisecond,
Expand All @@ -304,7 +304,7 @@ func TestAdaptiveRateLimitProvider_HeaderDrivenThrottle(t *testing.T) {

func TestAdaptiveRateLimitProvider_ConcurrentSafety(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{})

var wg sync.WaitGroup
var errCount atomic.Int64
Expand Down Expand Up @@ -333,7 +333,7 @@ func TestAdaptiveRateLimitProvider_ConcurrentSafety(t *testing.T) {

func TestAdaptiveRateLimitProvider_ContextCancellation(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
RPMLimit: 1,
ThresholdPercent: 10,
MaxDelay: 10 * time.Second,
Expand Down Expand Up @@ -400,18 +400,15 @@ func TestParseResetTime(t *testing.T) {
}
}

func TestAdaptiveRateLimitProvider_NilInnerPanics(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic for nil inner provider")
}
}()
NewAdaptiveRateLimitProvider(nil, AdaptiveRateLimitConfig{})
// TestAdaptiveRateLimitProvider_NilInnerErrors verifies that constructing the
// provider with a nil inner provider reports an error.
func TestAdaptiveRateLimitProvider_NilInnerErrors(t *testing.T) {
	_, err := NewAdaptiveRateLimitProvider(nil, AdaptiveRateLimitConfig{})
	if err != nil {
		// The constructor rejected the nil inner provider, as required.
		return
	}
	t.Error("expected error for nil inner provider")
}

func TestAdaptiveRateLimitProvider_DefaultConfig(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{})

// With no limits set, calls should go through freely
for i := 0; i < 100; i++ {
Expand All @@ -432,7 +429,7 @@ func TestAdaptiveRateLimitProvider_DefaultConfig(t *testing.T) {

func TestAdaptiveRateLimitProvider_WindowExpiry(t *testing.T) {
inner := &mockProvider{name: "test"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(t, inner, AdaptiveRateLimitConfig{
RPMLimit: 2,
ThresholdPercent: 10,
})
Expand All @@ -458,7 +455,7 @@ func TestAdaptiveRateLimitProvider_WindowExpiry(t *testing.T) {

func BenchmarkAdaptiveRateLimitProvider_Chat(b *testing.B) {
inner := &mockProvider{name: "bench"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{})
p := mustAdaptiveRateLimitProvider(b, inner, AdaptiveRateLimitConfig{})
ctx := context.Background()

b.ResetTimer()
Expand All @@ -469,7 +466,7 @@ func BenchmarkAdaptiveRateLimitProvider_Chat(b *testing.B) {

func BenchmarkAdaptiveRateLimitProvider_ChatWithLimits(b *testing.B) {
inner := &mockProvider{name: "bench"}
p := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
p := mustAdaptiveRateLimitProvider(b, inner, AdaptiveRateLimitConfig{
RPMLimit: 100000,
TPMLimit: 1000000,
ThresholdPercent: 10,
Expand All @@ -487,13 +484,17 @@ func ExampleAdaptiveRateLimitProvider() {
var inner Provider = &mockProvider{name: "openai"}

// Wrap with adaptive rate limiting
provider := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
provider, err := NewAdaptiveRateLimitProvider(inner, AdaptiveRateLimitConfig{
RPMLimit: 60, // 60 requests per minute
TPMLimit: 90000, // 90k tokens per minute
ThresholdPercent: 10, // throttle when <10% remaining
MaxDelay: 5 * time.Second,
HeaderExtractor: CommonHeaderExtractor, // parse rate limit headers
})
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}

// Use the provider normally
resp, err := provider.Chat(context.Background(), []EyrieMessage{
Expand All @@ -511,3 +512,13 @@ func ExampleAdaptiveRateLimitProvider() {
fmt.Printf("Tokens used: %d/%d\n", status.TPMUsed, status.TPMLimit)
fmt.Printf("Throttle count: %d\n", status.ThrottleCount)
}

// mustAdaptiveRateLimitProvider wraps inner via NewAdaptiveRateLimitProvider
// using config and returns the resulting provider. If construction reports an
// error, the calling test or benchmark is failed through tb.Fatalf. It is
// marked as a helper so failures are attributed to the caller's line.
func mustAdaptiveRateLimitProvider(tb testing.TB, inner Provider, config AdaptiveRateLimitConfig) *AdaptiveRateLimitProvider {
	tb.Helper()

	provider, buildErr := NewAdaptiveRateLimitProvider(inner, config)
	if buildErr == nil {
		return provider
	}

	tb.Fatalf("NewAdaptiveRateLimitProvider: %v", buildErr)
	return provider
}
9 changes: 5 additions & 4 deletions client/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"context"
"errors"
"log/slog"
"sync"
"time"
Expand Down Expand Up @@ -44,15 +45,15 @@ type CallbackProvider struct {
var _ Provider = (*CallbackProvider)(nil)

// NewCallbackProvider wraps the given provider with callback support.
// The inner provider must not be nil.
func NewCallbackProvider(inner Provider) *CallbackProvider {
// The inner provider must not be nil; an error is returned otherwise.
func NewCallbackProvider(inner Provider) (*CallbackProvider, error) {
if inner == nil {
panic("eyrie: NewCallbackProvider inner provider must not be nil")
return nil, errors.New("eyrie: NewCallbackProvider inner provider must not be nil")
}
return &CallbackProvider{
inner: inner,
logger: slog.Default(),
}
}, nil
}

// SetLogger sets the logger used for panic-recovery messages.
Expand Down
Loading
Loading