diff --git a/AGENTS.md b/AGENTS.md index e46cec3..84e3cba 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -67,7 +67,7 @@ c := New( ``` ### Error Hierarchy -`errors.go` defines `APIError` as the base struct with `StatusCode`, `Code`, `Message`, `Details`. Each HTTP status gets a wrapper type (`NotFoundError`, `RateLimitError`, etc.) that embeds `APIError`. All error types implement `Error()` and `Unwrap()` for `errors.Is`/`errors.As` compatibility: +`errors.go` defines `APIError` as the base struct with `StatusCode`, `Code`, `Message`, `Details`. Each HTTP status gets a wrapper type (`NotFoundError`, `RateLimitError`, etc.) that embeds `APIError`. Subtypes inherit `Error()` via embedding (only `RateLimitError` overrides it to append retry info) and implement `Unwrap()` for `errors.Is`/`errors.As` compatibility: ```go var notFound *NotFoundError if errors.As(err, &notFound) { @@ -145,7 +145,7 @@ if !errors.As(err, &notFound) { t.Error("expected NotFoundError") } - **Do not change**: `APIError.Error()` format string — tests assert exact string output - **Do not change**: JSON struct tags — they match the daemon's API contract - **Do not change**: `Unwrap()` implementations — `errors.As` depends on them for type matching -- **Safe to extend**: add new error types by creating a struct embedding `APIError`, adding to `parseAPIError` switch, and implementing `Error()` + `Unwrap()` +- **Safe to extend**: add new error types by creating a struct embedding `APIError` (which promotes `Error()`), adding to `parseAPIError` switch, and implementing `Unwrap()` - **Safe to extend**: add new client methods by following the `get()`/`post()` delegation pattern - **When adding streaming endpoints**: follow `ChatStream` pattern — set `Accept: text/event-stream`, check status before wrapping in `newStreamReader` - **Concurrency**: any new mutable state on `Agent` must be protected by `a.mu` diff --git a/agent.go b/agent.go index a62453b..07e375f 100644 --- a/agent.go +++ b/agent.go @@ -38,6 +38,13 
@@ type MemoryConfig struct { // Agent wraps a Client with declarative configuration, providing a // simplified interface for conversational AI interactions. +// +// Concurrency: Agent is safe for concurrent use. The session ID is read and +// updated under an internal mutex. Each Chat or ChatStream call captures the +// session ID at the moment the request is built; a stream returned by +// ChatStream continues to use the session ID captured at call time even if a +// concurrent Chat call establishes a new session while the stream is being +// consumed. type Agent struct { client *Client config AgentConfig @@ -94,7 +101,15 @@ func (a *Agent) Chat(ctx context.Context, message string) (*ChatResponse, error) // ChatStream sends a message and returns a streaming response reader. // Note: streaming with tools is not automatically looped; use Chat for // full tool loop support. +// +// The session ID is captured under the agent's lock when the request is +// built, so the entire stream lifecycle uses that snapshot: a concurrent +// Chat call that mutates the agent's session ID does not affect an +// in-flight stream. func (a *Agent) ChatStream(ctx context.Context, message string) (*StreamReader, error) { + // Capture the session ID into the request under the lock. req holds the + // captured value by copy, so the stream is immune to later mutations of + // a.sessionID by concurrent Chat calls. 
a.mu.Lock() req := a.buildRequest(message) a.mu.Unlock() diff --git a/agent_test.go b/agent_test.go index 5b8773f..94217b2 100644 --- a/agent_test.go +++ b/agent_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "sync" "sync/atomic" "testing" ) @@ -197,6 +198,67 @@ func TestAgent_ChatStream(t *testing.T) { } } +// TestAgent_ConcurrentChatAndStream runs Chat and ChatStream concurrently +// (run with -race) to verify that ChatStream's session ID snapshot is not +// affected by concurrent Chat calls mutating a.sessionID, and that no data +// race exists between request building and session updates. +func TestAgent_ConcurrentChatAndStream(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Accept") == "text/event-stream" { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte("data: chunk\n\n")) + w.Write([]byte("event: done\ndata: {}\n\n")) + return + } + json.NewEncoder(w).Encode(ChatResponse{ + SessionID: "race-sess", + Response: "ok", + }) + })) + defer srv.Close() + + c := New(WithBaseURL(srv.URL)) + agent := NewAgent(c, AgentConfig{Model: "test"}) + + const iterations = 10 + var wg sync.WaitGroup + errs := make(chan error, iterations*2) + + for i := 0; i < iterations; i++ { + wg.Add(2) + go func() { + defer wg.Done() + if _, err := agent.Chat(context.Background(), "hello"); err != nil { + errs <- err + } + }() + go func() { + defer wg.Done() + stream, err := agent.ChatStream(context.Background(), "stream hello") + if err != nil { + errs <- err + return + } + defer stream.Close() + // Consume the whole stream while Chat calls mutate sessionID. 
+ if _, err := stream.CollectText(context.Background()); err != nil { + errs <- err + } + }() + } + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("concurrent Chat/ChatStream error: %v", err) + } + + if got := agent.SessionID(); got != "race-sess" { + t.Errorf("SessionID = %q, want %q", got, "race-sess") + } +} + func TestNewAgent_Defaults(t *testing.T) { c := New() agent := NewAgent(c, AgentConfig{}) diff --git a/client.go b/client.go index 123b759..478f91e 100644 --- a/client.go +++ b/client.go @@ -42,6 +42,10 @@ func WithAPIKey(key string) ClientOption { } // New creates a new hawk SDK client. +// +// Note: the client performs no retries by default. Pass +// WithRetry(DefaultRetryConfig()) for production use to enable automatic +// retries with exponential backoff on transient failures. func New(opts ...ClientOption) *Client { c := &Client{ baseURL: defaultBaseURL, @@ -144,7 +148,10 @@ func (c *Client) DeleteSession(ctx context.Context, id string) error { } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + // The daemon returns 204 No Content on delete, but older daemon versions + // and intermediary proxies may respond with 200 OK instead. Accepting any + // 2xx keeps this defensive and consistent with post()'s success handling. + if resp.StatusCode/100 != 2 { return parseAPIError(resp) } return nil @@ -226,7 +233,9 @@ func (c *Client) post(ctx context.Context, path string, body interface{}, out in } defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { + // Accept any 2xx status: creation endpoints may return 201 Created + // and future endpoints may use other success codes. 
+ if resp.StatusCode/100 != 2 { return parseAPIError(resp) } diff --git a/docs/architecture.md b/docs/architecture.md index e96458e..aed5f7f 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -52,16 +52,16 @@ c := hawksdk.New( health, err := c.Health(ctx) // 💬 Non-streaming chat -resp, err := c.Chat(ctx, hawksdk.ChatRequest{Message: "list files"}) +resp, err := c.Chat(ctx, hawksdk.ChatRequest{Prompt: "list files"}) // 📡 Streaming chat -stream, err := c.ChatStream(ctx, hawksdk.ChatRequest{Message: "explain this code"}) +stream, err := c.ChatStream(ctx, hawksdk.ChatRequest{Prompt: "explain this code"}) defer stream.Close() for { ev, err := stream.Next(); if err != nil { break }; fmt.Print(ev.Data) } // 📋 Sessions -sessions, _ := c.Sessions(ctx, hawksdk.ListOptions{Limit: 10}) -msgs, _ := c.Messages(ctx, sessionID, hawksdk.ListOptions{}) +sessions, _ := c.Sessions(ctx, &hawksdk.ListOptions{Limit: 10}) +msgs, _ := c.Messages(ctx, sessionID, nil) _ = c.DeleteSession(ctx, sessionID) // 📊 Stats @@ -73,7 +73,7 @@ stats, _ := c.Stats(ctx) ## 🤖 Agent (Higher-Level) ```go -agent := hawksdk.NewAgent(c, hawksdk.AgentConfig{SystemPrompt: "You are a Go expert"}) +agent := hawksdk.NewAgent(c, hawksdk.AgentConfig{Model: "claude-sonnet-4-5", MaxRounds: 5}) resp, _ := agent.Chat(ctx, "refactor this function") // Subsequent calls automatically continue the same session ``` diff --git a/errors.go b/errors.go index 84d4e5e..119afed 100644 --- a/errors.go +++ b/errors.go @@ -35,9 +35,6 @@ type BadRequestError struct { APIError } -// Error implements the error interface. -func (e *BadRequestError) Error() string { return e.APIError.Error() } - // Unwrap allows errors.Is/As to match the underlying APIError. func (e *BadRequestError) Unwrap() error { return &e.APIError } @@ -46,9 +43,6 @@ type AuthenticationError struct { APIError } -// Error implements the error interface. 
-func (e *AuthenticationError) Error() string { return e.APIError.Error() } - // Unwrap allows errors.Is/As to match the underlying APIError. func (e *AuthenticationError) Unwrap() error { return &e.APIError } @@ -57,9 +51,6 @@ type ForbiddenError struct { APIError } -// Error implements the error interface. -func (e *ForbiddenError) Error() string { return e.APIError.Error() } - // Unwrap allows errors.Is/As to match the underlying APIError. func (e *ForbiddenError) Unwrap() error { return &e.APIError } @@ -68,9 +59,6 @@ type NotFoundError struct { APIError } -// Error implements the error interface. -func (e *NotFoundError) Error() string { return e.APIError.Error() } - // Unwrap allows errors.Is/As to match the underlying APIError. func (e *NotFoundError) Unwrap() error { return &e.APIError } @@ -98,9 +86,6 @@ type InternalServerError struct { APIError } -// Error implements the error interface. -func (e *InternalServerError) Error() string { return e.APIError.Error() } - // Unwrap allows errors.Is/As to match the underlying APIError. func (e *InternalServerError) Unwrap() error { return &e.APIError } @@ -109,9 +94,6 @@ type ServiceUnavailableError struct { APIError } -// Error implements the error interface. -func (e *ServiceUnavailableError) Error() string { return e.APIError.Error() } - // Unwrap allows errors.Is/As to match the underlying APIError. func (e *ServiceUnavailableError) Unwrap() error { return &e.APIError } diff --git a/sessions_test.go b/sessions_test.go index 88f6fe9..2797c76 100644 --- a/sessions_test.go +++ b/sessions_test.go @@ -336,6 +336,29 @@ func TestCreateSession(t *testing.T) { } } +// TestCreateSession201 verifies that post() accepts any 2xx status, not +// just 200 OK — creation endpoints commonly return 201 Created. 
+func TestCreateSession201(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(SessionSummary{ + ID: "created-sess", + CWD: "/tmp", + }) + })) + defer srv.Close() + + c := New(WithBaseURL(srv.URL)) + resp, err := c.CreateSession(context.Background(), CreateSessionRequest{Name: "n"}) + if err != nil { + t.Fatalf("CreateSession() with 201 response error: %v", err) + } + if resp.ID != "created-sess" { + t.Errorf("ID = %q, want %q", resp.ID, "created-sess") + } +} + func TestCreateSessionEmptyBody(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req CreateSessionRequest diff --git a/stream_helpers.go b/stream_helpers.go index a816e74..38a3bf0 100644 --- a/stream_helpers.go +++ b/stream_helpers.go @@ -52,6 +52,12 @@ type ToolCallDelta struct { // CollectText consumes the entire stream and returns the concatenated text content. // It blocks until the stream ends or the context is cancelled. +// +// When the returned error is non-nil, the returned string may be a partial +// result: it contains all text collected up to the point the error occurred, +// which callers may use or discard as appropriate. The error returned is the +// first error encountered while consuming the stream — if the stream emitted +// an "error" event before a later read failure, the "error" event wins. 
func (sr *StreamReader) CollectText(ctx context.Context) (string, error) { var sb strings.Builder var firstErr error @@ -59,6 +65,9 @@ func (sr *StreamReader) CollectText(ctx context.Context) (string, error) { for { select { case <-ctx.Done(): + if firstErr != nil { + return sb.String(), firstErr + } return sb.String(), ctx.Err() default: } @@ -68,6 +77,11 @@ func (sr *StreamReader) CollectText(ctx context.Context) (string, error) { return sb.String(), firstErr } if err != nil { + // Preserve first-error semantics: an "error" event seen earlier + // takes precedence over a subsequent read failure. + if firstErr != nil { + return sb.String(), firstErr + } return sb.String(), err } diff --git a/workflow_test.go b/workflow_test.go index a7992aa..b4d2e3e 100644 --- a/workflow_test.go +++ b/workflow_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "testing" "time" ) @@ -142,6 +143,131 @@ func TestWorkflow_StepFailure(t *testing.T) { } } +// TestWorkflow_RetryBackoffContextCancelled verifies that cancelling the +// context while a step is sleeping between retries aborts the retry loop +// and returns the last step error (not a hang or a nil error). +func TestWorkflow_RetryBackoffContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + attempts := 0 + + wf, err := NewWorkflow(). + Step("always-fails", func(ctx context.Context, input any) (any, error) { + attempts++ + if attempts == 1 { + // Cancel during the backoff that follows this failure. + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + } + return nil, errors.New("persistent failure") + }). + WithRetry(RetryConfig{ + MaxRetries: 10, + InitialBackoff: 10 * time.Second, // long enough that cancel wins + MaxBackoff: 10 * time.Second, + BackoffMultiplier: 1.0, + }). 
+ Build() + if err != nil { + t.Fatalf("Build() error: %v", err) + } + + start := time.Now() + _, err = wf.Run(ctx, nil) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected error after cancellation during backoff") + } + // The retry loop returns lastErr when the backoff sleep is interrupted. + if !strings.Contains(err.Error(), "persistent failure") { + t.Errorf("error = %v, want it to wrap the step error", err) + } + if attempts != 1 { + t.Errorf("attempts = %d, want 1 (cancel should stop retries)", attempts) + } + if elapsed > 5*time.Second { + t.Errorf("Run() took %v, should abort promptly on cancellation", elapsed) + } +} + +// TestWorkflow_RetryRespectsStepTimeout verifies the interaction between +// per-step Timeout and RetryConfig: when the step timeout expires during +// retry backoff, the loop stops and reports the last step error. +func TestWorkflow_RetryRespectsStepTimeout(t *testing.T) { + attempts := 0 + + wf, err := NewWorkflow(). + Step("flaky-slow", func(ctx context.Context, input any) (any, error) { + attempts++ + return nil, errors.New("still failing") + }). + WithRetry(RetryConfig{ + MaxRetries: 10, + InitialBackoff: 30 * time.Millisecond, + MaxBackoff: 30 * time.Millisecond, + BackoffMultiplier: 1.0, + }). + WithTimeout(50 * time.Millisecond). 
+ Build() + if err != nil { + t.Fatalf("Build() error: %v", err) + } + + start := time.Now() + _, err = wf.Run(context.Background(), nil) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected error when step timeout expires during retries") + } + if !strings.Contains(err.Error(), "still failing") { + t.Errorf("error = %v, want it to wrap the step error", err) + } + if attempts >= 10 { + t.Errorf("attempts = %d, timeout should have stopped retries early", attempts) + } + if elapsed > 2*time.Second { + t.Errorf("Run() took %v, should stop near the 50ms step timeout", elapsed) + } +} + +// TestWorkflow_RetryTimeoutBeforeFirstAttempt verifies that an +// already-expired step context surfaces the context error when the step +// never ran (lastErr == nil path in executeStep's select). +func TestWorkflow_RetryTimeoutBeforeFirstAttempt(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ran := false + wf, err := NewWorkflow(). + Step("never-runs", func(ctx context.Context, input any) (any, error) { + ran = true + return "ok", nil + }). + WithRetry(RetryConfig{ + MaxRetries: 3, + InitialBackoff: time.Millisecond, + MaxBackoff: time.Millisecond, + BackoffMultiplier: 1.0, + }). + Build() + if err != nil { + t.Fatalf("Build() error: %v", err) + } + + // Run checks ctx before the step, so use executeStep directly to hit the + // retry loop's own cancellation check. + _, stepErr := executeStep(ctx, wf.steps[0], nil) + if !errors.Is(stepErr, context.Canceled) { + t.Errorf("executeStep() error = %v, want context.Canceled", stepErr) + } + if ran { + t.Error("step should not run when context is already cancelled") + } +} + func TestWorkflowBuilder_EmptyWorkflow(t *testing.T) { _, err := NewWorkflow().Build() if err == nil {