Skip to content

Commit 25f6b70

Browse files
committed
Validate llm judge
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent f5d618e commit 25f6b70

4 files changed

Lines changed: 128 additions & 18 deletions

File tree

pkg/evaluation/eval.go

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,20 @@ func (r *Runner) Run(ctx context.Context, ttyOut, out io.Writer, isTTY bool) ([]
117117
return nil, fmt.Errorf("loading evaluations: %w", err)
118118
}
119119

120+
// Check whether any evaluations require relevance checking.
121+
// If so, the judge must be configured and working; validate eagerly
122+
// to fail fast on configuration issues (bad API key, wrong model, etc.)
123+
// instead of silently producing zero-relevance results.
124+
if needsJudge(evals) {
125+
if r.judge == nil {
126+
return nil, errors.New("some evaluations have relevance criteria but no judge model is configured (use --judge-model)")
127+
}
128+
fmt.Fprintln(out, "Validating judge model...")
129+
if err := r.judge.Validate(ctx); err != nil {
130+
return nil, fmt.Errorf("%w", err)
131+
}
132+
}
133+
120134
// Pre-build all unique Docker images in parallel before running evaluations.
121135
// This avoids serialized builds when multiple workers need the same image.
122136
if err := r.preBuildImages(ctx, out, evals); err != nil {
@@ -341,12 +355,12 @@ func (r *Runner) runSingleEval(ctx context.Context, evalSess *InputSession) (Res
341355
if r.judge != nil && len(evals.Relevance) > 0 {
342356
// Use transcript for relevance checking to preserve temporal ordering
343357
transcript := buildTranscript(events)
344-
passed, failed, errs := r.judge.CheckRelevance(ctx, transcript, evals.Relevance)
358+
passed, failed, err := r.judge.CheckRelevance(ctx, transcript, evals.Relevance)
359+
if err != nil {
360+
return result, fmt.Errorf("relevance check failed: %w", err)
361+
}
345362
result.RelevancePassed = float64(passed)
346363
result.FailedRelevance = failed
347-
for _, e := range errs {
348-
slog.Warn("Relevance check error", "title", evalSess.Title, "error", e)
349-
}
350364
}
351365

352366
slog.Debug("Evaluation complete", "title", evalSess.Title, "duration", time.Since(startTime))
@@ -590,6 +604,14 @@ func matchesAnyPattern(name string, patterns []string) bool {
590604
})
591605
}
592606

607+
// needsJudge returns true if any evaluation session has relevance criteria,
608+
// meaning a judge model is required to evaluate them.
609+
func needsJudge(evals []InputSession) bool {
610+
return slices.ContainsFunc(evals, func(s InputSession) bool {
611+
return s.Evals != nil && len(s.Evals.Relevance) > 0
612+
})
613+
}
614+
593615
// createJudgeModel creates a provider.Provider from a model string (format: provider/model).
594616
// Returns nil if judgeModel is empty.
595617
func createJudgeModel(ctx context.Context, judgeModel string, runConfig *config.RuntimeConfig) (provider.Provider, error) {

pkg/evaluation/eval_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212

1313
"github.com/stretchr/testify/assert"
1414
"github.com/stretchr/testify/require"
15+
16+
"github.com/docker/docker-agent/pkg/session"
1517
)
1618

1719
func TestToolCallF1Score(t *testing.T) {
@@ -1009,3 +1011,58 @@ func TestMatchesAnyPattern(t *testing.T) {
10091011
})
10101012
}
10111013
}
1014+
1015+
func TestNeedsJudge(t *testing.T) {
1016+
t.Parallel()
1017+
1018+
tests := []struct {
1019+
name string
1020+
evals []InputSession
1021+
want bool
1022+
}{
1023+
{
1024+
name: "no evals",
1025+
evals: nil,
1026+
want: false,
1027+
},
1028+
{
1029+
name: "evals without relevance criteria",
1030+
evals: []InputSession{
1031+
{Session: &session.Session{Evals: &session.EvalCriteria{Size: "M"}}},
1032+
{Session: &session.Session{Evals: &session.EvalCriteria{}}},
1033+
},
1034+
want: false,
1035+
},
1036+
{
1037+
name: "evals with nil Evals field",
1038+
evals: []InputSession{
1039+
{Session: &session.Session{}},
1040+
},
1041+
want: false,
1042+
},
1043+
{
1044+
name: "some evals with relevance criteria",
1045+
evals: []InputSession{
1046+
{Session: &session.Session{Evals: &session.EvalCriteria{}}},
1047+
{Session: &session.Session{Evals: &session.EvalCriteria{Relevance: []string{"criterion1"}}}},
1048+
},
1049+
want: true,
1050+
},
1051+
{
1052+
name: "all evals with relevance criteria",
1053+
evals: []InputSession{
1054+
{Session: &session.Session{Evals: &session.EvalCriteria{Relevance: []string{"a", "b"}}}},
1055+
{Session: &session.Session{Evals: &session.EvalCriteria{Relevance: []string{"c"}}}},
1056+
},
1057+
want: true,
1058+
},
1059+
}
1060+
1061+
for _, tt := range tests {
1062+
t.Run(tt.name, func(t *testing.T) {
1063+
t.Parallel()
1064+
got := needsJudge(tt.evals)
1065+
assert.Equal(t, tt.want, got)
1066+
})
1067+
}
1068+
}

pkg/evaluation/judge.go

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,41 @@ func NewJudge(model provider.Provider, runConfig *config.RuntimeConfig, concurre
8282
}
8383
}
8484

85+
// Validate performs an end-to-end check of the judge model by sending a
86+
// trivial relevance prompt and verifying the response is valid structured
87+
// JSON. This catches configuration errors (bad API key, unsupported model,
88+
// missing structured-output support, etc.) before running any evaluations,
89+
// allowing the framework to fail fast.
90+
func (j *Judge) Validate(ctx context.Context) error {
91+
const (
92+
testResponse = "The sky is blue."
93+
testCriterion = "The response mentions a color."
94+
)
95+
96+
passed, _, err := j.checkSingle(ctx, testResponse, testCriterion)
97+
if err != nil {
98+
return fmt.Errorf("judge model validation failed: %w", err)
99+
}
100+
101+
if !passed {
102+
return errors.New("judge model validation failed: expected the test criterion to pass but the judge returned 'fail'")
103+
}
104+
105+
return nil
106+
}
107+
85108
// RelevanceResult contains the result of a single relevance check.
86109
type RelevanceResult struct {
87110
Criterion string `json:"criterion"`
88111
Reason string `json:"reason"`
89112
}
90113

91114
// CheckRelevance runs all relevance checks concurrently with the configured concurrency.
92-
// It returns the number of passed checks, a slice of failed results with reasons, and any errors encountered.
93-
func (j *Judge) CheckRelevance(ctx context.Context, response string, criteria []string) (passed int, failed []RelevanceResult, errs []string) {
115+
// It returns the number of passed checks, a slice of failed results with reasons, and an error
116+
// if any check encountered an error (e.g. judge model misconfiguration). Errors cause a hard
117+
// failure so that configuration issues are surfaced immediately rather than silently producing
118+
// zero-relevance results.
119+
func (j *Judge) CheckRelevance(ctx context.Context, response string, criteria []string) (passed int, failed []RelevanceResult, err error) {
94120
if len(criteria) == 0 {
95121
return 0, nil, nil
96122
}
@@ -122,17 +148,19 @@ func (j *Judge) CheckRelevance(ctx context.Context, response string, criteria []
122148
results[item.index] = result{err: fmt.Errorf("context cancelled: %w", ctx.Err())}
123149
continue
124150
}
125-
pass, reason, err := j.checkSingle(ctx, response, item.criterion)
126-
results[item.index] = result{passed: pass, reason: reason, err: err}
151+
pass, reason, checkErr := j.checkSingle(ctx, response, item.criterion)
152+
results[item.index] = result{passed: pass, reason: reason, err: checkErr}
127153
}
128154
})
129155
}
130156
wg.Wait()
131157

132-
// Aggregate results
158+
// Aggregate results. Any error is fatal — return it immediately so the
159+
// caller can fail fast on judge misconfiguration.
160+
var errs []error
133161
for i, r := range results {
134162
if r.err != nil {
135-
errs = append(errs, fmt.Sprintf("error checking %q: %v", criteria[i], r.err))
163+
errs = append(errs, fmt.Errorf("checking %q: %w", criteria[i], r.err))
136164
continue
137165
}
138166
if r.passed {
@@ -145,7 +173,11 @@ func (j *Judge) CheckRelevance(ctx context.Context, response string, criteria []
145173
}
146174
}
147175

148-
return passed, failed, errs
176+
if len(errs) > 0 {
177+
return passed, failed, errors.Join(errs...)
178+
}
179+
180+
return passed, failed, nil
149181
}
150182

151183
// getOrCreateJudgeWithSchema returns a provider pre-configured with structured output.

pkg/evaluation/judge_test.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
89
)
910

1011
func TestNewJudge(t *testing.T) {
@@ -46,11 +47,11 @@ func TestJudge_CheckRelevance_EmptyCriteria(t *testing.T) {
4647
t.Parallel()
4748

4849
judge := NewJudge(nil, nil, 1)
49-
passed, failed, errs := judge.CheckRelevance(t.Context(), "some response", nil)
50+
passed, failed, err := judge.CheckRelevance(t.Context(), "some response", nil)
5051

5152
assert.Equal(t, 0, passed)
5253
assert.Empty(t, failed)
53-
assert.Empty(t, errs)
54+
assert.NoError(t, err)
5455
}
5556

5657
func TestJudge_CheckRelevance_ContextCanceled(t *testing.T) {
@@ -62,13 +63,11 @@ func TestJudge_CheckRelevance_ContextCanceled(t *testing.T) {
6263
cancel() // Cancel immediately
6364

6465
criteria := []string{"criterion1", "criterion2", "criterion3"}
65-
passed, failed, errs := judge.CheckRelevance(ctx, "some response", criteria)
66+
passed, failed, err := judge.CheckRelevance(ctx, "some response", criteria)
6667

6768
// All should have errors due to context cancellation
6869
assert.Equal(t, 0, passed)
6970
assert.Empty(t, failed)
70-
assert.Len(t, errs, 3)
71-
for _, err := range errs {
72-
assert.Contains(t, err, "context cancelled")
73-
}
71+
require.Error(t, err)
72+
assert.Contains(t, err.Error(), "context cancelled")
7473
}

0 commit comments

Comments
 (0)