|
5 | 5 | "encoding/json" |
6 | 6 | "os" |
7 | 7 | "path/filepath" |
8 | | - "strings" |
9 | 8 |
|
10 | 9 | "github.com/docker/cagent/pkg/chat" |
11 | 10 | "github.com/docker/cagent/pkg/runtime" |
@@ -56,7 +55,7 @@ func Evaluate(ctx context.Context, t *team.Team, evalsDir string) ([]Result, err |
56 | 55 | return nil, err |
57 | 56 | } |
58 | 57 |
|
59 | | - score := evaluate(evals[i].GetAllMessages(), actualMessages) |
| 58 | + score := score(evals[i].GetAllMessages(), actualMessages) |
60 | 59 |
|
61 | 60 | results = append(results, Result{ |
62 | 61 | Score: score, |
@@ -92,80 +91,3 @@ func runLoop(ctx context.Context, rt *runtime.LocalRuntime, eval *session.Sessio |
92 | 91 |
|
93 | 92 | return sess.GetAllMessages(), nil |
94 | 93 | } |
95 | | - |
96 | | -func evaluate(expectedMessages, actualMessages []session.Message) Score { |
97 | | - var expectedToolMessages []session.Message |
98 | | - for i := range expectedMessages { |
99 | | - if len(expectedMessages[i].Message.ToolCalls) != 0 { |
100 | | - expectedToolMessages = append(expectedToolMessages, expectedMessages[i]) |
101 | | - } |
102 | | - } |
103 | | - |
104 | | - var actualToolMessages []session.Message |
105 | | - for i := range actualMessages { |
106 | | - if len(actualMessages[i].Message.ToolCalls) != 0 { |
107 | | - actualToolMessages = append(actualToolMessages, actualMessages[i]) |
108 | | - } |
109 | | - } |
110 | | - |
111 | | - toolTrajectoryScore := toolTrajectoryScore(expectedToolMessages, actualToolMessages) |
112 | | - rouge1Score := rouge1(expectedMessages[len(expectedMessages)-1].Message.Content, actualMessages[len(actualMessages)-1].Message.Content) |
113 | | - |
114 | | - return Score{ |
115 | | - ToolTrajectoryScore: toolTrajectoryScore, |
116 | | - Rouge1Score: rouge1Score, |
117 | | - } |
118 | | -} |
119 | | - |
120 | | -// https://medium.com/nlplanet/two-minutes-nlp-learn-the-rouge-metric-by-examples-f179cc285499 |
121 | | -func rouge1(expected, actual string) float64 { |
122 | | - expectedWords := strings.Fields(strings.ToLower(expected)) |
123 | | - actualWords := strings.Fields(strings.ToLower(actual)) |
124 | | - |
125 | | - expectedSet := make(map[string]int) |
126 | | - for _, word := range expectedWords { |
127 | | - expectedSet[word]++ |
128 | | - } |
129 | | - |
130 | | - actualSet := make(map[string]int) |
131 | | - for _, word := range actualWords { |
132 | | - actualSet[word]++ |
133 | | - } |
134 | | - |
135 | | - overlap := 0 |
136 | | - for word, expectedCount := range expectedSet { |
137 | | - if actualCount, exists := actualSet[word]; exists { |
138 | | - if actualCount < expectedCount { |
139 | | - overlap += actualCount |
140 | | - } else { |
141 | | - overlap += expectedCount |
142 | | - } |
143 | | - } |
144 | | - } |
145 | | - |
146 | | - precision := float64(overlap) / float64(len(actualWords)) |
147 | | - recall := float64(overlap) / float64(len(expectedWords)) |
148 | | - |
149 | | - if precision+recall == 0 { |
150 | | - return 0.0 |
151 | | - } |
152 | | - |
153 | | - return 2 * (precision * recall) / (precision + recall) |
154 | | -} |
155 | | - |
156 | | -func toolTrajectoryScore(expectedToolMessages, actualToolMessages []session.Message) float64 { |
157 | | - score := 0.0 |
158 | | - |
159 | | - for i := range expectedToolMessages { |
160 | | - expected := expectedToolMessages[i] |
161 | | - actual := actualToolMessages[i] |
162 | | - |
163 | | - for j := range actual.Message.ToolCalls { |
164 | | - if actual.Message.ToolCalls[j].Function.Name == expected.Message.ToolCalls[j].Function.Name { |
165 | | - score += 1.0 |
166 | | - } |
167 | | - } |
168 | | - } |
169 | | - |
170 | | - return score / float64(len(expectedToolMessages)) |
171 | | -} |
0 commit comments