Skip to content

Commit 6b030c9

Browse files
committed
Extract scoring
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent 4dd8c2e commit 6b030c9

2 files changed

Lines changed: 85 additions & 79 deletions

File tree

pkg/evaluation/evaluation.go

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"encoding/json"
66
"os"
77
"path/filepath"
8-
"strings"
98

109
"github.com/docker/cagent/pkg/chat"
1110
"github.com/docker/cagent/pkg/runtime"
@@ -56,7 +55,7 @@ func Evaluate(ctx context.Context, t *team.Team, evalsDir string) ([]Result, err
5655
return nil, err
5756
}
5857

59-
score := evaluate(evals[i].GetAllMessages(), actualMessages)
58+
score := score(evals[i].GetAllMessages(), actualMessages)
6059

6160
results = append(results, Result{
6261
Score: score,
@@ -92,80 +91,3 @@ func runLoop(ctx context.Context, rt *runtime.LocalRuntime, eval *session.Sessio
9291

9392
return sess.GetAllMessages(), nil
9493
}
95-
96-
func evaluate(expectedMessages, actualMessages []session.Message) Score {
97-
var expectedToolMessages []session.Message
98-
for i := range expectedMessages {
99-
if len(expectedMessages[i].Message.ToolCalls) != 0 {
100-
expectedToolMessages = append(expectedToolMessages, expectedMessages[i])
101-
}
102-
}
103-
104-
var actualToolMessages []session.Message
105-
for i := range actualMessages {
106-
if len(actualMessages[i].Message.ToolCalls) != 0 {
107-
actualToolMessages = append(actualToolMessages, actualMessages[i])
108-
}
109-
}
110-
111-
toolTrajectoryScore := toolTrajectoryScore(expectedToolMessages, actualToolMessages)
112-
rouge1Score := rouge1(expectedMessages[len(expectedMessages)-1].Message.Content, actualMessages[len(actualMessages)-1].Message.Content)
113-
114-
return Score{
115-
ToolTrajectoryScore: toolTrajectoryScore,
116-
Rouge1Score: rouge1Score,
117-
}
118-
}
119-
120-
// rouge1 returns the ROUGE-1 F1 score between expected and actual: the
// harmonic mean of unigram precision and recall.
//
// Both texts are lower-cased and split on whitespace. A word that appears
// several times contributes to the overlap at most as many times as it appears
// in the other text (clipped counts).
//
// The result is 0 when either text contains no words or when the texts share
// no word; this avoids the 0/0 division that would otherwise produce NaN.
//
// https://medium.com/nlplanet/two-minutes-nlp-learn-the-rouge-metric-by-examples-f179cc285499
func rouge1(expected, actual string) float64 {
	expectedWords := strings.Fields(strings.ToLower(expected))
	actualWords := strings.Fields(strings.ToLower(actual))
	if len(expectedWords) == 0 || len(actualWords) == 0 {
		return 0.0
	}

	expectedCounts := make(map[string]int, len(expectedWords))
	for _, word := range expectedWords {
		expectedCounts[word]++
	}

	actualCounts := make(map[string]int, len(actualWords))
	for _, word := range actualWords {
		actualCounts[word]++
	}

	// Clipped overlap: each word counts min(expectedCount, actualCount) times.
	// A word missing from actualCounts reads as 0 and adds nothing.
	overlap := 0
	for word, expectedCount := range expectedCounts {
		actualCount := actualCounts[word]
		if actualCount < expectedCount {
			overlap += actualCount
		} else {
			overlap += expectedCount
		}
	}
	if overlap == 0 {
		return 0.0
	}

	precision := float64(overlap) / float64(len(actualWords))
	recall := float64(overlap) / float64(len(expectedWords))

	return 2 * (precision * recall) / (precision + recall)
}
155-
156-
func toolTrajectoryScore(expectedToolMessages, actualToolMessages []session.Message) float64 {
157-
score := 0.0
158-
159-
for i := range expectedToolMessages {
160-
expected := expectedToolMessages[i]
161-
actual := actualToolMessages[i]
162-
163-
for j := range actual.Message.ToolCalls {
164-
if actual.Message.ToolCalls[j].Function.Name == expected.Message.ToolCalls[j].Function.Name {
165-
score += 1.0
166-
}
167-
}
168-
}
169-
170-
return score / float64(len(expectedToolMessages))
171-
}

pkg/evaluation/score.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package evaluation
2+
3+
import (
4+
"strings"
5+
6+
"github.com/docker/cagent/pkg/session"
7+
)
8+
9+
func score(expectedMessages, actualMessages []session.Message) Score {
10+
var expectedToolMessages []session.Message
11+
for i := range expectedMessages {
12+
if len(expectedMessages[i].Message.ToolCalls) != 0 {
13+
expectedToolMessages = append(expectedToolMessages, expectedMessages[i])
14+
}
15+
}
16+
17+
var actualToolMessages []session.Message
18+
for i := range actualMessages {
19+
if len(actualMessages[i].Message.ToolCalls) != 0 {
20+
actualToolMessages = append(actualToolMessages, actualMessages[i])
21+
}
22+
}
23+
24+
toolTrajectoryScore := toolTrajectoryScore(expectedToolMessages, actualToolMessages)
25+
rouge1Score := rouge1(expectedMessages[len(expectedMessages)-1].Message.Content, actualMessages[len(actualMessages)-1].Message.Content)
26+
27+
return Score{
28+
ToolTrajectoryScore: toolTrajectoryScore,
29+
Rouge1Score: rouge1Score,
30+
}
31+
}
32+
33+
// rouge1 returns the ROUGE-1 F1 score between expected and actual: the
// harmonic mean of unigram precision and recall.
//
// Both texts are lower-cased and split on whitespace. A word that appears
// several times contributes to the overlap at most as many times as it appears
// in the other text (clipped counts).
//
// The result is 0 when either text contains no words or when the texts share
// no word; this avoids the 0/0 division that would otherwise produce NaN.
//
// https://medium.com/nlplanet/two-minutes-nlp-learn-the-rouge-metric-by-examples-f179cc285499
func rouge1(expected, actual string) float64 {
	expectedWords := strings.Fields(strings.ToLower(expected))
	actualWords := strings.Fields(strings.ToLower(actual))
	if len(expectedWords) == 0 || len(actualWords) == 0 {
		return 0.0
	}

	expectedCounts := make(map[string]int, len(expectedWords))
	for _, word := range expectedWords {
		expectedCounts[word]++
	}

	actualCounts := make(map[string]int, len(actualWords))
	for _, word := range actualWords {
		actualCounts[word]++
	}

	// Clipped overlap: each word counts min(expectedCount, actualCount) times.
	// A word missing from actualCounts reads as 0 and adds nothing.
	overlap := 0
	for word, expectedCount := range expectedCounts {
		actualCount := actualCounts[word]
		if actualCount < expectedCount {
			overlap += actualCount
		} else {
			overlap += expectedCount
		}
	}
	if overlap == 0 {
		return 0.0
	}

	precision := float64(overlap) / float64(len(actualWords))
	recall := float64(overlap) / float64(len(expectedWords))

	return 2 * (precision * recall) / (precision + recall)
}
68+
69+
func toolTrajectoryScore(expectedToolMessages, actualToolMessages []session.Message) float64 {
70+
score := 0.0
71+
72+
for i := range expectedToolMessages {
73+
expected := expectedToolMessages[i]
74+
actual := actualToolMessages[i]
75+
76+
for j := range actual.Message.ToolCalls {
77+
if actual.Message.ToolCalls[j].Function.Name == expected.Message.ToolCalls[j].Function.Name {
78+
score += 1.0
79+
}
80+
}
81+
}
82+
83+
return score / float64(len(expectedToolMessages))
84+
}

0 commit comments

Comments
 (0)