Skip to content

Commit 5a41b0e

Browse files
authored
Merge pull request #1091 from stanislavHamara/cagent-record-mode
Add --record flag to record AI API interactions
2 parents d6af657 + ac83ac8 commit 5a41b0e

8 files changed

Lines changed: 260 additions & 21 deletions

File tree

cmd/root/api.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type apiFlags struct {
2121
sessionDB string
2222
pullIntervalMins int
2323
fakeResponses string
24+
recordPath string
2425
runConfig config.RuntimeConfig
2526
}
2627

@@ -40,6 +41,8 @@ func newAPICmd() *cobra.Command {
4041
cmd.PersistentFlags().StringVarP(&flags.sessionDB, "session-db", "s", "session.db", "Path to the session database")
4142
cmd.PersistentFlags().IntVar(&flags.pullIntervalMins, "pull-interval", 0, "Auto-pull OCI reference every N minutes (0 = disabled)")
4243
cmd.PersistentFlags().StringVar(&flags.fakeResponses, "fake", "", "Replay AI responses from cassette file (for testing)")
44+
cmd.PersistentFlags().StringVar(&flags.recordPath, "record", "", "Record AI API interactions to cassette file")
45+
cmd.MarkFlagsMutuallyExclusive("fake", "record")
4346
addRuntimeConfigFlags(cmd, &flags.runConfig)
4447

4548
return cmd
@@ -71,6 +74,13 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {
7174
slog.Info("Fake mode enabled", "cassette", f.fakeResponses, "proxy", proxyURL)
7275
}
7376

77+
// Start recording proxy if --record is specified
78+
if _, cleanup, err := setupRecordingProxy(f.recordPath, &f.runConfig); err != nil {
79+
return err
80+
} else if cleanup != nil {
81+
defer cleanup()
82+
}
83+
7484
if f.pullIntervalMins > 0 && !config.IsOCIReference(agentsPath) {
7585
return fmt.Errorf("--pull-interval flag can only be used with OCI references, not local files")
7686
}

cmd/root/exec.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ func newExecCmd() *cobra.Command {
1717
Example: ` cagent exec ./agent.yaml
1818
cagent exec ./team.yaml --agent root
1919
cagent exec ./echo.yaml "INSTRUCTIONS"
20-
echo "INSTRUCTIONS" | cagent exec ./echo.yaml -`,
20+
echo "INSTRUCTIONS" | cagent exec ./echo.yaml -
21+
cagent exec ./agent.yaml "question" --record # Records to auto-generated file`,
2122
GroupID: "core",
2223
Args: cobra.RangeArgs(1, 2),
2324
RunE: flags.runExecCommand,

cmd/root/record.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package root
2+
3+
import (
4+
"fmt"
5+
"log/slog"
6+
"strings"
7+
"time"
8+
9+
"github.com/docker/cagent/pkg/config"
10+
"github.com/docker/cagent/pkg/fake"
11+
)
12+
13+
// setupRecordingProxy starts a recording proxy if recordPath is non-empty.
14+
// It handles auto-generating a filename when recordPath is "true" (from NoOptDefVal),
15+
// and normalizes the path by stripping any .yaml suffix.
16+
// Returns the cassette path (with .yaml extension) and a cleanup function.
17+
// The cleanup function must be called when done (typically via defer).
18+
func setupRecordingProxy(recordPath string, runConfig *config.RuntimeConfig) (cassettePath string, cleanup func(), err error) {
19+
if recordPath == "" {
20+
return "", func() {}, nil
21+
}
22+
23+
// Handle auto-generated filename (from NoOptDefVal)
24+
if recordPath == "true" {
25+
recordPath = fmt.Sprintf("cagent-recording-%d", time.Now().Unix())
26+
} else {
27+
recordPath = strings.TrimSuffix(recordPath, ".yaml")
28+
}
29+
30+
proxyURL, cleanupFn, err := fake.StartRecordingProxy(recordPath)
31+
if err != nil {
32+
return "", nil, fmt.Errorf("failed to start recording proxy: %w", err)
33+
}
34+
35+
runConfig.ModelsGateway = proxyURL
36+
cassettePath = recordPath + ".yaml"
37+
38+
slog.Info("Recording mode enabled", "cassette", cassettePath, "proxy", proxyURL)
39+
40+
return cassettePath, func() {
41+
if err := cleanupFn(); err != nil {
42+
slog.Error("Failed to cleanup recording proxy", "error", err)
43+
}
44+
}, nil
45+
}

cmd/root/record_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package root
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"strings"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/docker/cagent/pkg/config"
13+
)
14+
15+
func TestSetupRecordingProxy_EmptyPath(t *testing.T) {
16+
var runConfig config.RuntimeConfig
17+
18+
cassettePath, cleanup, err := setupRecordingProxy("", &runConfig)
19+
20+
require.NoError(t, err)
21+
assert.Empty(t, cassettePath)
22+
assert.NotNil(t, cleanup)
23+
assert.Empty(t, runConfig.ModelsGateway, "ModelsGateway should not be set")
24+
25+
cleanup()
26+
}
27+
28+
func TestSetupRecordingProxy_AutoGeneratesFilename(t *testing.T) {
29+
tmpDir := t.TempDir()
30+
originalWd, err := os.Getwd()
31+
require.NoError(t, err)
32+
require.NoError(t, os.Chdir(tmpDir))
33+
t.Cleanup(func() { _ = os.Chdir(originalWd) })
34+
35+
var runConfig config.RuntimeConfig
36+
37+
cassettePath, cleanup, err := setupRecordingProxy("true", &runConfig)
38+
require.NoError(t, err)
39+
defer cleanup()
40+
41+
assert.True(t, strings.HasPrefix(cassettePath, "cagent-recording-"), "should have auto-generated prefix")
42+
assert.True(t, strings.HasSuffix(cassettePath, ".yaml"), "should have .yaml suffix")
43+
assert.NotEmpty(t, runConfig.ModelsGateway, "ModelsGateway should be set")
44+
}
45+
46+
func TestSetupRecordingProxy_CreatesProxy(t *testing.T) {
47+
tmpDir := t.TempDir()
48+
cassettePath := filepath.Join(tmpDir, "test-recording")
49+
50+
var runConfig config.RuntimeConfig
51+
52+
resultPath, cleanup, err := setupRecordingProxy(cassettePath, &runConfig)
53+
require.NoError(t, err)
54+
defer cleanup()
55+
56+
assert.Equal(t, cassettePath+".yaml", resultPath)
57+
assert.True(t, strings.HasPrefix(runConfig.ModelsGateway, "http://"), "ModelsGateway should be HTTP URL")
58+
}

cmd/root/run.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ type runExecFlags struct {
3232
runConfig config.RuntimeConfig
3333
sessionDB string
3434

35+
// Shared between run and exec
36+
recordPath string
37+
3538
// Exec only
3639
hideToolCalls bool
3740
outputJSON bool
@@ -48,7 +51,8 @@ func newRunCmd() *cobra.Command {
4851
cagent run ./team.yaml --agent root
4952
cagent run # built-in default agent
5053
cagent run ./echo.yaml "INSTRUCTIONS"
51-
echo "INSTRUCTIONS" | cagent run ./echo.yaml -`,
54+
echo "INSTRUCTIONS" | cagent run ./echo.yaml -
55+
cagent run ./agent.yaml --record # Records session to auto-generated file`,
5256
GroupID: "core",
5357
Args: cobra.RangeArgs(0, 2),
5458
RunE: flags.runRunCommand,
@@ -68,6 +72,8 @@ func addRunOrExecFlags(cmd *cobra.Command, flags *runExecFlags) {
6872
cmd.PersistentFlags().BoolVar(&flags.dryRun, "dry-run", false, "Initialize the agent without executing anything")
6973
cmd.PersistentFlags().StringVar(&flags.remoteAddress, "remote", "", "Use remote runtime with specified address")
7074
cmd.PersistentFlags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database")
75+
cmd.PersistentFlags().StringVar(&flags.recordPath, "record", "", "Record AI API interactions to cassette file (auto-generates filename if empty)")
76+
cmd.PersistentFlags().Lookup("record").NoOptDefVal = "true"
7177
}
7278

7379
func (f *runExecFlags) runRunCommand(cmd *cobra.Command, args []string) error {
@@ -83,6 +89,14 @@ func (f *runExecFlags) runRunCommand(cmd *cobra.Command, args []string) error {
8389
func (f *runExecFlags) runOrExec(ctx context.Context, out *cli.Printer, args []string, tui bool) error {
8490
slog.Debug("Starting agent", "agent", f.agentName)
8591

92+
// Record AI API interactions to a cassette file if --record flag is specified.
93+
if cassettePath, cleanup, err := setupRecordingProxy(f.recordPath, &f.runConfig); err != nil {
94+
return err
95+
} else if cassettePath != "" {
96+
defer cleanup()
97+
out.Println("Recording mode enabled, cassette: " + cassettePath)
98+
}
99+
86100
var agentFileName string
87101
if len(args) > 0 {
88102
agentFileName = args[0]

e2e/proxy_test.go

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ package e2e_test
22

33
import (
44
"context"
5-
"net/http"
65
"net/http/httptest"
7-
"os"
86
"path/filepath"
97
"testing"
108

@@ -26,27 +24,11 @@ func startRecordingAIProxy(t *testing.T) (*httptest.Server, *config.RuntimeConfi
2624
require.NoError(t, err)
2725
})
2826

29-
// Header updater that adds real API keys for recording
30-
headerUpdater := func(host string, req *http.Request) {
31-
switch host {
32-
case "https://api.openai.com/v1":
33-
req.Header.Set("Authorization", "Bearer "+os.Getenv("OPENAI_API_KEY"))
34-
case "https://api.anthropic.com":
35-
req.Header.Del("Authorization")
36-
req.Header.Set("X-Api-Key", os.Getenv("ANTHROPIC_API_KEY"))
37-
case "https://generativelanguage.googleapis.com":
38-
req.Header.Del("Authorization")
39-
req.Header.Set("X-Goog-Api-Key", os.Getenv("GOOGLE_API_KEY"))
40-
case "https://api.mistral.ai/v1":
41-
req.Header.Set("Authorization", "Bearer "+os.Getenv("MISTRAL_API_KEY"))
42-
}
43-
}
44-
4527
proxyURL, cleanup, err := fake.StartProxyWithOptions(
4628
cassettePath,
4729
recorder.ModeRecordOnce,
4830
matcher,
49-
headerUpdater,
31+
fake.APIKeyHeaderUpdater,
5032
)
5133
require.NoError(t, err)
5234

pkg/fake/proxy.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"maps"
1111
"net/http"
1212
"net/http/httptest"
13+
"os"
1314
"regexp"
1415
"strings"
1516

@@ -24,6 +25,30 @@ func StartProxy(cassettePath string) (string, func() error, error) {
2425
return StartProxyWithOptions(cassettePath, recorder.ModeReplayOnly, nil, nil)
2526
}
2627

28+
// StartRecordingProxy starts a proxy that records AI API interactions to a cassette file.
29+
// It injects API keys from environment variables for the actual API calls.
30+
// The recorded cassette can later be replayed using StartProxy.
31+
func StartRecordingProxy(cassettePath string) (string, func() error, error) {
32+
return StartProxyWithOptions(cassettePath, recorder.ModeRecordOnce, nil, APIKeyHeaderUpdater)
33+
}
34+
35+
// APIKeyHeaderUpdater injects API keys from environment variables into request headers.
36+
// This is used when recording API interactions to ensure real API calls succeed.
37+
func APIKeyHeaderUpdater(host string, req *http.Request) {
38+
switch host {
39+
case "https://api.openai.com/v1":
40+
req.Header.Set("Authorization", "Bearer "+os.Getenv("OPENAI_API_KEY"))
41+
case "https://api.anthropic.com":
42+
req.Header.Del("Authorization")
43+
req.Header.Set("X-Api-Key", os.Getenv("ANTHROPIC_API_KEY"))
44+
case "https://generativelanguage.googleapis.com":
45+
req.Header.Del("Authorization")
46+
req.Header.Set("X-Goog-Api-Key", os.Getenv("GOOGLE_API_KEY"))
47+
case "https://api.mistral.ai/v1":
48+
req.Header.Set("Authorization", "Bearer "+os.Getenv("MISTRAL_API_KEY"))
49+
}
50+
}
51+
2752
// StartProxyWithOptions starts an internal HTTP proxy with configurable options.
2853
// - mode: recorder mode (ModeReplayOnly, ModeRecordOnce, etc.)
2954
// - matcher: custom matcher function (nil uses default CustomMatcher)

pkg/fake/proxy_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package fake
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestAPIKeyHeaderUpdater(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
host string
15+
envKey string
16+
envValue string
17+
expectedHeader string
18+
expectedValue string
19+
}{
20+
{
21+
name: "OpenAI",
22+
host: "https://api.openai.com/v1",
23+
envKey: "OPENAI_API_KEY",
24+
envValue: "test-openai-key",
25+
expectedHeader: "Authorization",
26+
expectedValue: "Bearer test-openai-key",
27+
},
28+
{
29+
name: "Anthropic",
30+
host: "https://api.anthropic.com",
31+
envKey: "ANTHROPIC_API_KEY",
32+
envValue: "test-anthropic-key",
33+
expectedHeader: "X-Api-Key",
34+
expectedValue: "test-anthropic-key",
35+
},
36+
{
37+
name: "Google",
38+
host: "https://generativelanguage.googleapis.com",
39+
envKey: "GOOGLE_API_KEY",
40+
envValue: "test-google-key",
41+
expectedHeader: "X-Goog-Api-Key",
42+
expectedValue: "test-google-key",
43+
},
44+
{
45+
name: "Mistral",
46+
host: "https://api.mistral.ai/v1",
47+
envKey: "MISTRAL_API_KEY",
48+
envValue: "test-mistral-key",
49+
expectedHeader: "Authorization",
50+
expectedValue: "Bearer test-mistral-key",
51+
},
52+
}
53+
54+
for _, tt := range tests {
55+
t.Run(tt.name, func(t *testing.T) {
56+
t.Setenv(tt.envKey, tt.envValue)
57+
58+
req, err := http.NewRequest(http.MethodPost, "https://example.com", http.NoBody)
59+
require.NoError(t, err)
60+
61+
APIKeyHeaderUpdater(tt.host, req)
62+
63+
assert.Equal(t, tt.expectedValue, req.Header.Get(tt.expectedHeader))
64+
})
65+
}
66+
}
67+
68+
func TestAPIKeyHeaderUpdater_UnknownHost(t *testing.T) {
69+
req, err := http.NewRequest(http.MethodPost, "https://example.com", http.NoBody)
70+
require.NoError(t, err)
71+
72+
APIKeyHeaderUpdater("https://unknown.host.com", req)
73+
74+
assert.Empty(t, req.Header.Get("Authorization"))
75+
assert.Empty(t, req.Header.Get("X-Api-Key"))
76+
}
77+
78+
func TestTargetURLForHost(t *testing.T) {
79+
t.Parallel()
80+
81+
tests := []struct {
82+
host string
83+
expected bool
84+
}{
85+
{"https://api.openai.com/v1", true},
86+
{"https://api.anthropic.com", true},
87+
{"https://generativelanguage.googleapis.com", true},
88+
{"https://api.mistral.ai/v1", true},
89+
{"https://unknown.host.com", false},
90+
}
91+
92+
for _, tt := range tests {
93+
t.Run(tt.host, func(t *testing.T) {
94+
t.Parallel()
95+
96+
fn := TargetURLForHost(tt.host)
97+
if tt.expected {
98+
assert.NotNil(t, fn)
99+
} else {
100+
assert.Nil(t, fn)
101+
}
102+
})
103+
}
104+
}

0 commit comments

Comments
 (0)