-
Notifications
You must be signed in to change notification settings - Fork 118
Expand file tree
/
Copy pathconversation.go
More file actions
419 lines (368 loc) · 11.8 KB
/
conversation.go
File metadata and controls
419 lines (368 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
package screentracker
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/coder/agentapi/lib/msgfmt"
"github.com/coder/agentapi/lib/util"
"github.com/danielgtaylor/huma/v2"
"golang.org/x/xerrors"
)
// screenSnapshot is one recorded capture of the agent's screen, as stored in
// Conversation's snapshot ring buffer.
type screenSnapshot struct {
	// timestamp is when the snapshot was recorded (taken from
	// ConversationConfig.GetTime by addSnapshotInner).
	timestamp time.Time
	// screen is the full screen text that was recorded.
	screen string
}
// AgentIO is the terminal-like channel to the agent: user input is written to
// it, and the agent's rendered screen is read back from it as text.
type AgentIO interface {
	// ReadScreen returns the agent's current screen contents.
	ReadScreen() string
	// Write sends raw input bytes to the agent; the results follow the
	// io.Writer convention (bytes written, error).
	Write(data []byte) (int, error)
}
// ConversationConfig holds the dependencies and tuning knobs for a Conversation.
type ConversationConfig struct {
	// AgentType is forwarded to FindNewMessage when the agent's reply is
	// extracted from the screen.
	AgentType msgfmt.AgentType
	// AgentIO is the terminal the conversation reads screens from and writes
	// user input to.
	AgentIO AgentIO
	// GetTime returns the current time; it timestamps snapshots and messages.
	GetTime func() time.Time
	// How often to take a snapshot for the stability check
	SnapshotInterval time.Duration
	// How long the screen should not change to be considered stable
	ScreenStabilityLength time.Duration
	// Function to format the messages received from the agent.
	// When nil, the text extracted from the screen is stored as-is.
	// userInput is the last user message
	FormatMessage func(message string, userInput string) string
	// SkipWritingMessage skips the writing of a message to the agent.
	// This is used in tests
	SkipWritingMessage bool
	// SkipSendMessageStatusCheck skips the check for whether the message can be sent.
	// This is used in tests
	SkipSendMessageStatusCheck bool
}
// ConversationRole identifies who authored a ConversationMessage.
type ConversationRole string

// The roles a ConversationMessage can carry.
const (
	// ConversationRoleUser marks a message sent through SendMessage.
	ConversationRoleUser = ConversationRole("user")
	// ConversationRoleAgent marks a message extracted from the agent's screen.
	ConversationRoleAgent = ConversationRole("agent")
)

// ConversationRoleValues lists every ConversationRole: user first, then agent.
var ConversationRoleValues = []ConversationRole{ConversationRoleUser, ConversationRoleAgent}
// Schema describes ConversationRole to the OpenAPI registry by delegating to
// util.OpenAPISchema under the name "ConversationRole", with
// ConversationRoleValues as the allowed values.
func (ConversationRole) Schema(registry huma.Registry) *huma.Schema {
	return util.OpenAPISchema(registry, "ConversationRole", ConversationRoleValues)
}
// ConversationMessage is one entry in a Conversation's history.
type ConversationMessage struct {
	// Id is the message's index in the conversation history.
	Id int
	// Message is the message text: the user's input for user messages, or
	// the (optionally formatted) text extracted from the screen for agent
	// messages.
	Message string
	// Role records whether the user or the agent authored the message.
	Role ConversationRole
	// Time is when SendMessage recorded the message (user messages), or the
	// time passed with the latest update of the message (agent messages).
	Time time.Time
}
// Conversation tracks the message history of a session with an agent behind an
// AgentIO, inferring the agent's replies from periodic screen snapshots.
//
// The methods in this file read and write cfg-driven state (snapshotBuffer,
// messages, screenBeforeLastUserMessage) while holding lock.
// NOTE(review): InitialPrompt and InitialPromptSent are not touched here;
// confirm how callers synchronize access to them.
type Conversation struct {
	cfg ConversationConfig
	// How many stable snapshots are required to consider the screen stable
	stableSnapshotsThreshold int
	// snapshotBuffer is a ring buffer created with capacity
	// stableSnapshotsThreshold that holds recent screen snapshots.
	snapshotBuffer *RingBuffer[screenSnapshot]
	// messages is the conversation history; each message's Id equals its index.
	messages []ConversationMessage
	// screenBeforeLastUserMessage is the screen read just before the most
	// recent user message was written; agent replies are extracted relative
	// to it.
	screenBeforeLastUserMessage string
	lock                        sync.Mutex
	// InitialPrompt is the initial prompt passed to the agent
	InitialPrompt string
	// InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents
	InitialPromptSent bool
}
// ConversationStatus summarizes whether the agent's screen has settled.
type ConversationStatus string

// The statuses reported by Conversation.Status.
const (
	// ConversationStatusChanging: the screen is still changing, or a user
	// message has not yet been followed by a snapshot.
	ConversationStatusChanging = ConversationStatus("changing")
	// ConversationStatusStable: every buffered snapshot shows the same screen.
	ConversationStatusStable = ConversationStatus("stable")
	// ConversationStatusInitializing: not enough snapshots have been taken yet
	// to judge stability.
	ConversationStatusInitializing = ConversationStatus("initializing")
)
// getStableSnapshotsThreshold returns how many consecutive snapshots must show
// the same screen before the screen counts as stable.
//
// That is the number of SnapshotInterval steps needed to cover
// ScreenStabilityLength, rounded up, plus one: N intervals are bounded by
// N+1 snapshots.
//
// The division is done on full-precision durations rather than on millisecond
// counts. Sub-millisecond intervals therefore no longer truncate to zero and
// crash with an integer divide-by-zero. When both durations are whole
// milliseconds, the result is the same as before.
//
// A negative ScreenStabilityLength is treated as zero, so the result is always
// at least 1. That is the minimum the ring buffer and the stability check need.
//
// It panics if SnapshotInterval is not positive, because no number of
// snapshots can cover a time span in that case.
func getStableSnapshotsThreshold(cfg ConversationConfig) int {
	interval := cfg.SnapshotInterval
	if interval <= 0 {
		panic(fmt.Sprintf("snapshot interval must be positive, got %s", interval))
	}
	length := cfg.ScreenStabilityLength
	if length < 0 {
		length = 0
	}
	threshold := int(length / interval)
	if length%interval != 0 {
		threshold++
	}
	return threshold + 1
}
// NewConversation builds a Conversation from cfg.
//
// The snapshot buffer is sized to the stability threshold derived from
// cfg.SnapshotInterval and cfg.ScreenStabilityLength. The history starts with
// a single empty agent message (Id 0, stamped with cfg.GetTime()); later
// snapshots update it in place as the agent's screen appears.
// InitialPromptSent starts out true only when initialPrompt is empty, since
// then there is nothing to send.
//
// ctx is currently unused.
func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation {
	threshold := getStableSnapshotsThreshold(cfg)
	buffer := NewRingBuffer[screenSnapshot](threshold)
	placeholder := ConversationMessage{
		Message: "",
		Role:    ConversationRoleAgent,
		Time:    cfg.GetTime(),
	}
	return &Conversation{
		cfg:                      cfg,
		stableSnapshotsThreshold: threshold,
		snapshotBuffer:           buffer,
		messages:                 []ConversationMessage{placeholder},
		InitialPrompt:            initialPrompt,
		InitialPromptSent:        initialPrompt == "",
	}
}
// StartSnapshotLoop starts a goroutine that reads the agent's screen every
// cfg.SnapshotInterval and records it as a snapshot. The goroutine runs until
// ctx is cancelled.
//
// Each interval is measured from the end of the previous snapshot, as before.
// A single timer is reused across iterations instead of calling time.After on
// every pass. time.After allocates a new timer each iteration and leaves the
// last one pending after ctx is cancelled. The timer is stopped when the
// goroutine exits.
func (c *Conversation) StartSnapshotLoop(ctx context.Context) {
	go func() {
		timer := time.NewTimer(c.cfg.SnapshotInterval)
		defer timer.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-timer.C:
				// It's important that we hold the lock while reading the screen.
				// There's a race condition that occurs without it:
				// 1. The screen is read
				// 2. Independently, SendMessage is called and takes the lock.
				// 3. The snapshot is about to be added and waits on the lock.
				// 4. SendMessage modifies the terminal state, releases the lock
				// 5. A snapshot from a stale screen is added
				c.lock.Lock()
				screen := c.cfg.AgentIO.ReadScreen()
				c.addSnapshotInner(screen)
				c.lock.Unlock()
				// The channel was just drained, so Reset is safe here.
				timer.Reset(c.cfg.SnapshotInterval)
			}
		}
	}()
}
// FindNewMessage extracts the part of newScreen that is new relative to
// oldScreen.
//
// The algorithm works as follows:
//  1. Leading lines of newScreen that also occur anywhere in oldScreen are
//     skipped. Lines are compared as exact, whole lines.
//  2. Everything from the first line that does not occur in oldScreen to the
//     end of newScreen is kept. Later lines are kept even if they also occur
//     in oldScreen.
//  3. Blank and whitespace-only lines are trimmed from both ends of that
//     section.
//
// If the section has no non-blank line, the result is "".
//
// agentType is currently unused.
func FindNewMessage(oldScreen, newScreen string, agentType msgfmt.AgentType) string {
	oldLines := strings.Split(oldScreen, "\n")
	newLines := strings.Split(newScreen, "\n")

	seen := make(map[string]struct{}, len(oldLines))
	for _, line := range oldLines {
		seen[line] = struct{}{}
	}

	firstNew := len(newLines)
	for i, line := range newLines {
		if _, ok := seen[line]; !ok {
			firstNew = i
			break
		}
	}
	section := newLines[firstNew:]

	// Trim leading and trailing lines that are empty or whitespace-only.
	// If every line is blank, start runs to the end and end stops at start,
	// so the result is "". Previously the all-blank section was returned
	// untrimmed.
	start := 0
	for start < len(section) && strings.TrimSpace(section[start]) == "" {
		start++
	}
	end := len(section)
	for end > start && strings.TrimSpace(section[end-1]) == "" {
		end--
	}
	return strings.Join(section[start:end], "\n")
}
// lastMessage returns the most recent message in c.messages with the given
// role, or the zero ConversationMessage when no message has that role.
// Callers in this file invoke it while holding c.lock.
func (c *Conversation) lastMessage(role ConversationRole) ConversationMessage {
	total := len(c.messages)
	for offset := range c.messages {
		candidate := c.messages[total-1-offset]
		if candidate.Role == role {
			return candidate
		}
	}
	return ConversationMessage{}
}
// updateLastAgentMessage recomputes the agent's current reply from screen and
// stores it as the latest agent message.
//
// The reply is the text FindNewMessage finds in screen relative to
// c.screenBeforeLastUserMessage. When cfg.FormatMessage is set, the reply is
// passed through it together with the last user message's text.
//
// How the reply is stored:
//   - If the latest message is from the user, or there are no messages, the
//     reply is appended as a new agent message.
//   - Otherwise it overwrites the latest message in place, unless its text is
//     unchanged.
//
// The stored message's Id is its index in c.messages and its Time is
// timestamp.
//
// This function assumes that the caller holds the lock
func (c *Conversation) updateLastAgentMessage(screen string, timestamp time.Time) {
	agentMessage := FindNewMessage(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType)
	if c.cfg.FormatMessage != nil {
		lastUserMessage := c.lastMessage(ConversationRoleUser)
		agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message)
	}

	shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser
	// Only an in-place update may be skipped when nothing changed. A reply to
	// a user message must always be appended, even if its text equals the
	// previous agent reply. statusInner reports "changing" for as long as a
	// user message is last, so skipping the append would leave the
	// conversation unable to become stable again.
	if !shouldCreateNewMessage && c.messages[len(c.messages)-1].Message == agentMessage {
		return
	}

	conversationMessage := ConversationMessage{
		Message: agentMessage,
		Role:    ConversationRoleAgent,
		Time:    timestamp,
	}
	if shouldCreateNewMessage {
		c.messages = append(c.messages, conversationMessage)
	} else {
		c.messages[len(c.messages)-1] = conversationMessage
	}
	c.messages[len(c.messages)-1].Id = len(c.messages) - 1
}
// addSnapshotInner records screen in the snapshot buffer, stamped with a single
// cfg.GetTime() reading. It then refreshes the latest agent message from that
// screen using the same timestamp.
//
// The caller must already hold c.lock.
func (c *Conversation) addSnapshotInner(screen string) {
	capturedAt := c.cfg.GetTime()
	c.snapshotBuffer.Add(screenSnapshot{screen: screen, timestamp: capturedAt})
	c.updateLastAgentMessage(screen, capturedAt)
}
// AddSnapshot records screen as a new snapshot and updates the latest agent
// message from it. It is the lock-acquiring entry point around
// addSnapshotInner; c.lock is released on return, including on panic.
func (c *Conversation) AddSnapshot(screen string) {
	c.lock.Lock()
	defer c.lock.Unlock()

	c.addSnapshotInner(screen)
}
// MessagePart is one piece of a user message: String gives the text recorded
// in the conversation history, and Do performs the piece against the agent.
type MessagePart interface {
	// String returns the part's contribution to the recorded message text.
	String() string
	// Do applies the part to writer (for example by typing text into it).
	Do(writer AgentIO) error
}
// MessagePartText is a MessagePart that types literal text into the agent.
//
// Content is what Do writes to the agent. String controls how the part
// appears in the recorded message:
//   - Alias, if non-empty, replaces Content.
//   - Hidden removes the part entirely and takes precedence over Alias.
type MessagePartText struct {
	Content string
	Alias   string
	Hidden  bool
}

// Do writes Content to writer and returns any write error.
func (p MessagePartText) Do(writer AgentIO) error {
	if _, err := writer.Write([]byte(p.Content)); err != nil {
		return err
	}
	return nil
}

// String returns "" when the part is hidden, otherwise Alias if set, otherwise
// Content.
func (p MessagePartText) String() string {
	switch {
	case p.Hidden:
		return ""
	case p.Alias != "":
		return p.Alias
	default:
		return p.Content
	}
}
func PartsToString(parts ...MessagePart) string {
var sb strings.Builder
for _, part := range parts {
sb.WriteString(part.String())
}
return sb.String()
}
// ExecuteParts runs each part's Do against writer, in order. It stops at the
// first failing part and returns that error wrapped; later parts are not run.
func ExecuteParts(writer AgentIO, parts ...MessagePart) error {
	for i := range parts {
		err := parts[i].Do(writer)
		if err == nil {
			continue
		}
		return xerrors.Errorf("failed to write message part: %w", err)
	}
	return nil
}
// writeMessageWithConfirmation types messageParts into the agent and then
// submits them with a carriage return, confirming each step by watching the
// screen.
//
// It runs in two phases, each bounded by a 15-second util.WaitFor timeout:
//  1. Execute the parts. Then wait until the screen differs from what it was
//     before writing and is unchanged across a one-second re-read, i.e. the
//     typed text has been echoed and has settled.
//  2. Write "\r" and wait until the screen changes. The carriage return is
//     re-sent at most once every 3 seconds while nothing happens.
//
// When cfg.SkipWritingMessage is set it returns nil without writing anything.
// Errors from writing or from either wait are wrapped and returned.
//
// NOTE(review): presumably util.WaitFor polls the callback until it returns
// true or an error, or until the timeout/ctx expires. Confirm against
// lib/util. The time.Sleep calls inside the callbacks do not observe ctx
// cancellation.
func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, messageParts ...MessagePart) error {
	if c.cfg.SkipWritingMessage {
		return nil
	}
	screenBeforeMessage := c.cfg.AgentIO.ReadScreen()
	if err := ExecuteParts(c.cfg.AgentIO, messageParts...); err != nil {
		return xerrors.Errorf("failed to write message: %w", err)
	}
	// wait for the screen to stabilize after the message is written:
	// it must first differ from the pre-write screen, then stay the same
	// across a one-second pause
	if err := util.WaitFor(ctx, util.WaitTimeout{
		Timeout:     15 * time.Second,
		MinInterval: 50 * time.Millisecond,
		InitialWait: true,
	}, func() (bool, error) {
		screen := c.cfg.AgentIO.ReadScreen()
		if screen != screenBeforeMessage {
			time.Sleep(1 * time.Second)
			newScreen := c.cfg.AgentIO.ReadScreen()
			return newScreen == screen, nil
		}
		return false, nil
	}); err != nil {
		return xerrors.Errorf("failed to wait for screen to stabilize: %w", err)
	}
	// wait for the screen to change after the carriage return is written
	screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen()
	// the zero time guarantees the first callback invocation sends "\r" immediately
	lastCarriageReturnTime := time.Time{}
	if err := util.WaitFor(ctx, util.WaitTimeout{
		Timeout:     15 * time.Second,
		MinInterval: 25 * time.Millisecond,
	}, func() (bool, error) {
		// we don't want to spam additional carriage returns because the agent may process them
		// (aider does this), but we do want to retry sending one if nothing's
		// happening for a while
		if time.Since(lastCarriageReturnTime) >= 3*time.Second {
			lastCarriageReturnTime = time.Now()
			if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil {
				return false, xerrors.Errorf("failed to write carriage return: %w", err)
			}
		}
		// give the agent a moment to react before comparing screens
		time.Sleep(25 * time.Millisecond)
		screen := c.cfg.AgentIO.ReadScreen()
		return screen != screenBeforeCarriageReturn, nil
	}); err != nil {
		return xerrors.Errorf("failed to wait for processing to start: %w", err)
	}
	return nil
}
// Validation errors returned by SendMessage.
var (
	// MessageValidationErrorWhitespace: the message text has leading or
	// trailing whitespace.
	MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace")
	// MessageValidationErrorEmpty: the message text is empty.
	MessageValidationErrorEmpty = xerrors.New("message must not be empty")
	// MessageValidationErrorChanging: the conversation status was not stable
	// when sending was attempted.
	MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input")
)
// SendMessage types messageParts into the agent and records them as a user
// message.
//
// Validation, in order:
//   - Unless cfg.SkipSendMessageStatusCheck is set, it returns
//     MessageValidationErrorChanging when the status is not stable.
//   - The message text (PartsToString of the parts) must already be trimmed
//     according to msgfmt.TrimWhitespace, or MessageValidationErrorWhitespace
//     is returned.
//   - The text must be non-empty, or MessageValidationErrorEmpty is returned.
//
// Before writing, the latest agent message is refreshed from the current
// screen. That screen then becomes the baseline for extracting the agent's
// next reply. The user message is appended only if writing succeeds.
// c.lock is held throughout, which also blocks the snapshot loop while the
// message is being typed.
func (c *Conversation) SendMessage(messageParts ...MessagePart) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	checkStatus := !c.cfg.SkipSendMessageStatusCheck
	if checkStatus && c.statusInner() != ConversationStatusStable {
		return MessageValidationErrorChanging
	}

	message := PartsToString(messageParts...)
	switch {
	case msgfmt.TrimWhitespace(message) != message:
		// msgfmt formatting functions assume this
		return MessageValidationErrorWhitespace
	case message == "":
		// writeMessageWithConfirmation requires a non-empty message
		return MessageValidationErrorEmpty
	}

	baseline := c.cfg.AgentIO.ReadScreen()
	sentAt := c.cfg.GetTime()
	c.updateLastAgentMessage(baseline, sentAt)

	if err := c.writeMessageWithConfirmation(context.Background(), messageParts...); err != nil {
		return xerrors.Errorf("failed to send message: %w", err)
	}

	c.screenBeforeLastUserMessage = baseline
	c.messages = append(c.messages, ConversationMessage{
		Id:      len(c.messages),
		Message: message,
		Role:    ConversationRoleUser,
		Time:    sentAt,
	})
	return nil
}
// statusInner classifies the conversation's current state. The caller must
// hold c.lock.
//
// It returns, in order of precedence:
//   - ConversationStatusChanging if the latest message is from the user (the
//     snapshot loop has not run since that message was sent);
//   - ConversationStatusInitializing if fewer than stableSnapshotsThreshold
//     snapshots are buffered;
//   - ConversationStatusChanging if the buffered snapshots do not all show
//     the same screen;
//   - ConversationStatusStable otherwise.
//
// It panics if the buffer capacity and the threshold disagree, or if the
// threshold is zero, since stability cannot be judged in either case.
func (c *Conversation) statusInner() ConversationStatus {
	capacity := c.snapshotBuffer.Capacity()
	if capacity != c.stableSnapshotsThreshold {
		panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", capacity, c.stableSnapshotsThreshold))
	}
	if c.stableSnapshotsThreshold == 0 {
		panic("stable snapshots threshold is 0. can't check stability")
	}

	snapshots := c.snapshotBuffer.GetAll()

	if n := len(c.messages); n > 0 && c.messages[n-1].Role == ConversationRoleUser {
		return ConversationStatusChanging
	}
	if len(snapshots) != c.stableSnapshotsThreshold {
		return ConversationStatusInitializing
	}

	reference := snapshots[0].screen
	for _, snapshot := range snapshots[1:] {
		if snapshot.screen != reference {
			return ConversationStatusChanging
		}
	}
	return ConversationStatusStable
}
// Status reports the conversation's current ConversationStatus, holding c.lock
// while it is computed.
func (c *Conversation) Status() ConversationStatus {
	c.lock.Lock()
	defer c.lock.Unlock()

	return c.statusInner()
}
// Messages returns a copy of the conversation history, taken under c.lock, so
// callers may modify the result freely. The result is never nil; an empty
// history yields an empty slice.
func (c *Conversation) Messages() []ConversationMessage {
	c.lock.Lock()
	defer c.lock.Unlock()

	return append(make([]ConversationMessage, 0, len(c.messages)), c.messages...)
}
// Screen returns the screen text of the last entry reported by
// snapshotBuffer.GetAll (presumably the newest snapshot — confirm against
// RingBuffer), or "" when no snapshot has been recorded. It holds c.lock while
// reading.
func (c *Conversation) Screen() string {
	c.lock.Lock()
	defer c.lock.Unlock()

	all := c.snapshotBuffer.GetAll()
	if n := len(all); n > 0 {
		return all[n-1].screen
	}
	return ""
}