Skip to content

Commit bd2a487

Browse files
committed
Add the criteria to the session
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent bb8b0d1 commit bd2a487

5 files changed

Lines changed: 59 additions & 35 deletions

File tree

pkg/evaluation/eval.go

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/docker/cagent/pkg/environment"
2525
"github.com/docker/cagent/pkg/model/provider"
2626
"github.com/docker/cagent/pkg/model/provider/options"
27+
"github.com/docker/cagent/pkg/session"
2728
)
2829

2930
// Runner runs evaluations against an agent.
@@ -98,7 +99,7 @@ func Evaluate(ctx context.Context, ttyOut, out io.Writer, isTTY bool, runName st
9899
// workItem represents a single evaluation to be processed.
99100
type workItem struct {
100101
index int
101-
eval *EvalSession
102+
eval *InputSession
102103
}
103104

104105
// Run executes all evaluations concurrently and returns results.
@@ -163,13 +164,13 @@ func (r *Runner) Run(ctx context.Context, ttyOut, out io.Writer, isTTY bool) ([]
163164
return results, nil
164165
}
165166

166-
func (r *Runner) loadEvalSessions(ctx context.Context) ([]EvalSession, error) {
167+
func (r *Runner) loadEvalSessions(ctx context.Context) ([]InputSession, error) {
167168
entries, err := os.ReadDir(r.EvalsDir)
168169
if err != nil {
169170
return nil, err
170171
}
171172

172-
var evals []EvalSession
173+
var evals []InputSession
173174
for _, entry := range entries {
174175
if ctx.Err() != nil {
175176
return nil, ctx.Err()
@@ -190,22 +191,19 @@ func (r *Runner) loadEvalSessions(ctx context.Context) ([]EvalSession, error) {
190191
return nil, err
191192
}
192193

193-
var evalSess EvalSession
194+
var evalSess session.Session
194195
if err := json.Unmarshal(data, &evalSess); err != nil {
195196
return nil, err
196197
}
197198

198-
evalSess.SourcePath = filepath.Join(r.EvalsDir, fileName)
199-
200-
if evalSess.Title == "" {
201-
evalSess.Title = strings.TrimSuffix(fileName, ".json")
202-
}
203-
204-
evals = append(evals, evalSess)
199+
evals = append(evals, InputSession{
200+
Session: &evalSess,
201+
SourcePath: filepath.Join(r.EvalsDir, fileName),
202+
})
205203
}
206204

207205
// Sort by duration (longest first) so the slowest evals start first and don't leave a long tail
208-
slices.SortFunc(evals, func(a, b EvalSession) int {
206+
slices.SortFunc(evals, func(a, b InputSession) int {
209207
return cmp.Compare(b.Duration(), a.Duration())
210208
})
211209

@@ -214,11 +212,13 @@ func (r *Runner) loadEvalSessions(ctx context.Context) ([]EvalSession, error) {
214212

215213
// preBuildImages pre-builds all unique Docker images needed for the evaluations.
216214
// This is done in parallel to avoid serialized builds during evaluation.
217-
func (r *Runner) preBuildImages(ctx context.Context, out io.Writer, evals []EvalSession) error {
215+
func (r *Runner) preBuildImages(ctx context.Context, out io.Writer, evals []InputSession) error {
218216
// Collect unique working directories
219217
workingDirs := make(map[string]struct{})
220218
for _, eval := range evals {
221-
workingDirs[eval.Evals.WorkingDir] = struct{}{}
219+
if eval.Evals != nil {
220+
workingDirs[eval.Evals.WorkingDir] = struct{}{}
221+
}
222222
}
223223

224224
if len(workingDirs) == 0 {
@@ -278,24 +278,31 @@ func (r *Runner) preBuildImages(ctx context.Context, out io.Writer, evals []Eval
278278
return nil
279279
}
280280

281-
func (r *Runner) runSingleEval(ctx context.Context, evalSess *EvalSession) (Result, error) {
281+
func (r *Runner) runSingleEval(ctx context.Context, evalSess *InputSession) (Result, error) {
282282
startTime := time.Now()
283283
slog.Debug("Starting evaluation", "title", evalSess.Title)
284284

285+
var evals *session.EvalCriteria
286+
if evalSess.Evals != nil {
287+
evals = evalSess.Evals
288+
} else {
289+
evals = &session.EvalCriteria{}
290+
}
291+
285292
result := Result{
286293
InputPath: evalSess.SourcePath,
287294
Title: evalSess.Title,
288-
Question: getFirstUserMessage(&evalSess.Session),
289-
SizeExpected: evalSess.Evals.Size,
290-
RelevanceExpected: float64(len(evalSess.Evals.Relevance)),
295+
Question: getFirstUserMessage(evalSess.Session),
296+
SizeExpected: evals.Size,
297+
RelevanceExpected: float64(len(evals.Relevance)),
291298
}
292299

293300
expectedToolCalls := extractToolCalls(evalSess.Messages)
294301
if len(expectedToolCalls) > 0 {
295302
result.ToolCallsExpected = 1.0
296303
}
297304

298-
workingDir := evalSess.Evals.WorkingDir
305+
workingDir := evals.WorkingDir
299306

300307
imageID, err := r.getOrBuildImage(ctx, workingDir)
301308
if err != nil {
@@ -316,15 +323,16 @@ func (r *Runner) runSingleEval(ctx context.Context, evalSess *EvalSession) (Resu
316323

317324
// Build session from events for database storage
318325
result.Session = SessionFromEvents(events, evalSess.Title, result.Question)
326+
result.Session.Evals = evals
319327

320328
if len(expectedToolCalls) > 0 || len(actualToolCalls) > 0 {
321329
result.ToolCallsScore = toolCallF1Score(expectedToolCalls, actualToolCalls)
322330
}
323331

324332
result.HandoffsMatch = countHandoffs(expectedToolCalls) == countHandoffs(actualToolCalls)
325333

326-
if r.judge != nil && len(evalSess.Evals.Relevance) > 0 {
327-
passed, failed, errs := r.judge.CheckRelevance(ctx, result.Response, evalSess.Evals.Relevance)
334+
if r.judge != nil && len(evals.Relevance) > 0 {
335+
passed, failed, errs := r.judge.CheckRelevance(ctx, result.Response, evals.Relevance)
328336
result.RelevancePassed = float64(passed)
329337
result.FailedRelevance = failed
330338
for _, e := range errs {

pkg/evaluation/save.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func SaveRunJSON(run *EvalRun, outputDir string) (string, error) {
308308
}
309309

310310
// SaveRunSessionsJSON saves all eval sessions to a single JSON file.
311-
// Each session is saved in the same format as /eval produces (session.Session).
311+
// Each session includes its eval criteria in the "evals" field.
312312
// This complements SaveRunSessions which saves to SQLite, providing a
313313
// human-readable format for inspection.
314314
func SaveRunSessionsJSON(run *EvalRun, outputDir string) (string, error) {
@@ -336,6 +336,11 @@ func Save(sess *session.Session, filename string) (string, error) {
336336
evalFile = filepath.Join("evals", fmt.Sprintf("%s_%d.json", baseName, number))
337337
}
338338

339+
// Default to empty eval criteria so the saved file always contains an "evals" field with an empty "relevance" list to fill in
340+
if sess.Evals == nil {
341+
sess.Evals = &session.EvalCriteria{Relevance: []string{}}
342+
}
343+
339344
return saveJSON(sess, evalFile)
340345
}
341346

pkg/evaluation/save_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ func TestSaveWithCustomFilename(t *testing.T) {
2929
require.Equal(t, filepath.Join("evals", "my-custom-eval.json"), evalFile)
3030
require.FileExists(t, evalFile)
3131

32+
// Verify the saved file contains the evals field
33+
data, err := os.ReadFile(evalFile)
34+
require.NoError(t, err)
35+
var savedSession session.Session
36+
err = json.Unmarshal(data, &savedSession)
37+
require.NoError(t, err)
38+
assert.NotNil(t, savedSession.Evals)
39+
assert.Empty(t, savedSession.Evals.Relevance)
40+
3241
// Test 2: Save without filename (should use session ID)
3342
evalFile2, err := Save(sess, "")
3443
require.NoError(t, err)
@@ -131,7 +140,7 @@ func TestSaveRunSessionsJSON(t *testing.T) {
131140
sess2.OutputTokens = 30
132141
sess2.Cost = 0.005
133142

134-
// Create an eval run with sessions
143+
// Create an eval run with sessions and eval criteria
135144
run := &EvalRun{
136145
Name: "test-json-001",
137146
Timestamp: time.Now(),

pkg/evaluation/types.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,10 @@ import (
77
"github.com/docker/cagent/pkg/session"
88
)
99

10-
// EvalCriteria contains the evaluation criteria for a test case.
11-
type EvalCriteria struct {
12-
Relevance []string `json:"relevance,omitempty"` // Statements that should be true about the response
13-
WorkingDir string `json:"working_dir,omitempty"` // Subdirectory under evals/working_dirs/
14-
Size string `json:"size,omitempty"` // Expected response size: S, M, L, XL
15-
}
16-
17-
// EvalSession extends session.Session with evaluation criteria.
18-
type EvalSession struct {
19-
session.Session
20-
Evals EvalCriteria `json:"evals"`
21-
SourcePath string `json:"-"` // Path to the source eval file (not serialized)
10+
// InputSession wraps a session with its source path for evaluation loading.
11+
type InputSession struct {
12+
*session.Session
13+
SourcePath string // Path to the source eval file; set by the loader, not decoded from the eval JSON
2214
}
2315

2416
// Result contains the evaluation results for a single test case.

pkg/session/session.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ type Session struct {
5454
// Title is the title of the session, set by the runtime
5555
Title string `json:"title"`
5656

57+
// Evals contains the evaluation criteria for this session (used by the eval framework)
58+
Evals *EvalCriteria `json:"evals,omitempty"`
59+
5760
// Messages holds the conversation history (messages and sub-sessions)
5861
Messages []Item `json:"messages"`
5962

@@ -189,6 +192,13 @@ func NewSubSessionItem(subSession *Session) Item {
189192
return Item{SubSession: subSession}
190193
}
191194

195+
// EvalCriteria contains the evaluation criteria for a session.
196+
type EvalCriteria struct {
197+
Relevance []string `json:"relevance"` // Statements that should be true about the response
198+
WorkingDir string `json:"working_dir,omitempty"` // Subdirectory under evals/working_dirs/
199+
Size string `json:"size,omitempty"` // Expected response size: S, M, L, XL
200+
}
201+
192202
// Session helper methods
193203

194204
// AddMessage adds a message to the session

0 commit comments

Comments
 (0)