Skip to content

Commit bdb17ce

Browse files
committed
Make the todo tool thread safe
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent ac768e6 commit bdb17ce

3 files changed

Lines changed: 67 additions & 18 deletions

File tree

pkg/concurrent/map.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package concurrent
2+
3+
import "sync"
4+
5+
type Map[K comparable, V any] struct {
6+
mu sync.RWMutex
7+
values map[K]V
8+
}
9+
10+
func NewMap[K comparable, V any]() *Map[K, V] {
11+
return &Map[K, V]{
12+
values: make(map[K]V),
13+
}
14+
}
15+
16+
func (m *Map[K, V]) Load(key K) (V, bool) {
17+
m.mu.RLock()
18+
defer m.mu.RUnlock()
19+
20+
val, ok := m.values[key]
21+
return val, ok
22+
}
23+
24+
func (m *Map[K, V]) Store(key K, value V) {
25+
m.mu.Lock()
26+
defer m.mu.Unlock()
27+
28+
m.values[key] = value
29+
}
30+
31+
func (m *Map[K, V]) Length() int {
32+
m.mu.RLock()
33+
defer m.mu.RUnlock()
34+
35+
return len(m.values)
36+
}
37+
38+
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
39+
m.mu.RLock()
40+
defer m.mu.RUnlock()
41+
42+
for k, v := range m.values {
43+
if !f(k, v) {
44+
break
45+
}
46+
}
47+
}

pkg/tools/builtin/todo.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88
"sync"
99

10+
"github.com/docker/cagent/pkg/concurrent"
1011
"github.com/docker/cagent/pkg/tools"
1112
)
1213

@@ -38,15 +39,15 @@ type UpdateTodoArgs struct {
3839
}
3940

4041
type todoHandler struct {
41-
todos map[string]Todo
42+
todos *concurrent.Map[string, Todo]
4243
}
4344

4445
var NewSharedTodoTool = sync.OnceValue(NewTodoTool)
4546

4647
func NewTodoTool() *TodoTool {
4748
return &TodoTool{
4849
handler: &todoHandler{
49-
todos: make(map[string]Todo),
50+
todos: concurrent.NewMap[string, Todo](),
5051
},
5152
}
5253
}
@@ -78,12 +79,12 @@ func (h *todoHandler) createTodo(_ context.Context, toolCall tools.ToolCall) (*t
7879
return nil, fmt.Errorf("invalid arguments: %w", err)
7980
}
8081

81-
id := fmt.Sprintf("todo_%d", len(h.todos)+1)
82-
h.todos[id] = Todo{
82+
id := fmt.Sprintf("todo_%d", h.todos.Length()+1)
83+
h.todos.Store(id, Todo{
8384
ID: id,
8485
Description: params.Description,
8586
Status: "pending",
86-
}
87+
})
8788

8889
return &tools.ToolCallResult{
8990
Output: fmt.Sprintf("Created todo [%s]: %s", id, params.Description),
@@ -97,14 +98,14 @@ func (h *todoHandler) createTodos(_ context.Context, toolCall tools.ToolCall) (*
9798
}
9899

99100
ids := make([]string, len(params.Descriptions))
100-
start := len(h.todos)
101+
start := h.todos.Length()
101102
for i, desc := range params.Descriptions {
102103
id := fmt.Sprintf("todo_%d", start+i+1)
103-
h.todos[id] = Todo{
104+
h.todos.Store(id, Todo{
104105
ID: id,
105106
Description: desc,
106107
Status: "pending",
107-
}
108+
})
108109
ids[i] = id
109110
}
110111

@@ -127,13 +128,13 @@ func (h *todoHandler) updateTodo(_ context.Context, toolCall tools.ToolCall) (*t
127128
return nil, fmt.Errorf("invalid arguments: %w", err)
128129
}
129130

130-
todo, exists := h.todos[params.ID]
131+
todo, exists := h.todos.Load(params.ID)
131132
if !exists {
132133
return nil, fmt.Errorf("todo [%s] not found", params.ID)
133134
}
134135

135136
todo.Status = params.Status
136-
h.todos[params.ID] = todo
137+
h.todos.Store(params.ID, todo)
137138

138139
return &tools.ToolCallResult{
139140
Output: fmt.Sprintf("Updated todo [%s] to status: [%s]", params.ID, params.Status),
@@ -144,10 +145,11 @@ func (h *todoHandler) listTodos(context.Context, tools.ToolCall) (*tools.ToolCal
144145
var output strings.Builder
145146
output.WriteString("Current todos:\n")
146147

147-
for _, todo := range h.todos {
148+
h.todos.Range(func(_ string, todo Todo) bool {
148149
output.WriteString(fmt.Sprintf("- [%s] %s (Status: %s)\n",
149150
todo.ID, todo.Description, todo.Status))
150-
}
151+
return true
152+
})
151153

152154
return &tools.ToolCallResult{
153155
Output: output.String(),

pkg/tools/builtin/todo_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func TestNewTodoTool(t *testing.T) {
1515

1616
assert.NotNil(t, tool)
1717
assert.NotNil(t, tool.handler)
18-
assert.Empty(t, tool.handler.todos)
18+
assert.Zero(t, tool.handler.todos.Length())
1919
}
2020

2121
func TestTodoTool_Instructions(t *testing.T) {
@@ -152,8 +152,8 @@ func TestTodoTool_CreateTodo(t *testing.T) {
152152
assert.Contains(t, result.Output, "Created todo [todo_1]: Test todo item")
153153

154154
// Verify todo was added to the handler's todos map
155-
assert.Len(t, tool.handler.todos, 1)
156-
todo, exists := tool.handler.todos["todo_1"]
155+
assert.Equal(t, 1, tool.handler.todos.Length())
156+
todo, exists := tool.handler.todos.Load("todo_1")
157157
assert.True(t, exists)
158158
assert.Equal(t, "Test todo item", todo.Description)
159159
assert.Equal(t, "pending", todo.Status)
@@ -198,7 +198,7 @@ func TestTodoTool_CreateTodos(t *testing.T) {
198198
assert.Contains(t, result.Output, "todo_3")
199199

200200
// Verify todos were added to the handler's todos map
201-
assert.Len(t, tool.handler.todos, 3)
201+
assert.Equal(t, 3, tool.handler.todos.Length())
202202

203203
// Create multiple todos
204204
args = CreateTodosArgs{
@@ -222,7 +222,7 @@ func TestTodoTool_CreateTodos(t *testing.T) {
222222
require.NoError(t, err)
223223
assert.Contains(t, result.Output, "Created 1 todos:")
224224
assert.Contains(t, result.Output, "todo_4")
225-
assert.Len(t, tool.handler.todos, 4)
225+
assert.Equal(t, 4, tool.handler.todos.Length())
226226
}
227227

228228
func TestTodoTool_UpdateTodo(t *testing.T) {
@@ -276,7 +276,7 @@ func TestTodoTool_UpdateTodo(t *testing.T) {
276276
assert.Contains(t, result.Output, "Updated todo [todo_1] to status: [completed]")
277277

278278
// Verify todo status was updated
279-
todo, exists := tool.handler.todos["todo_1"]
279+
todo, exists := tool.handler.todos.Load("todo_1")
280280
assert.True(t, exists)
281281
assert.Equal(t, "completed", todo.Status)
282282
}

0 commit comments

Comments
 (0)