Skip to content

Commit 52a5bb4

Browse files
authored
Merge pull request #941 from krissetto/escape-js-template-literals
Escape js template literals
2 parents a91cff5 + c51dcb9 commit 52a5bb4

2 files changed

Lines changed: 84 additions & 3 deletions

File tree

pkg/js/expand.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,24 @@ package js
33
import (
44
"context"
55
"fmt"
6+
"strings"
67
"sync"
78

89
"github.com/dop251/goja"
910

1011
"github.com/docker/cagent/pkg/environment"
1112
)
1213

14+
// escapeForTemplateLiteral escapes characters that have special meaning in
15+
// JavaScript template literals
16+
func escapeForTemplateLiteral(s string) string {
17+
// Escape backticks so they don't terminate the template literal.
18+
// Also escape backslashes that precede backticks to avoid double-escaping issues.
19+
s = strings.ReplaceAll(s, "\\`", "\\\\`") // First escape already-escaped backticks
20+
s = strings.ReplaceAll(s, "`", "\\`") // Then escape remaining backticks
21+
return s
22+
}
23+
1324
type jsEnv func(string) goja.Value
1425

1526
func (e jsEnv) Get(k string) goja.Value { return e(k) }
@@ -50,7 +61,7 @@ func (exp *Expander) ExpandMap(ctx context.Context, kv map[string]string) map[st
5061
vm := exp.jsRuntime(ctx)
5162

5263
for k, v := range kv {
53-
result, err := vm.RunString("`" + v + "`")
64+
result, err := vm.RunString("`" + escapeForTemplateLiteral(v) + "`")
5465
if err != nil {
5566
expanded[k] = v
5667
continue
@@ -65,7 +76,7 @@ func (exp *Expander) ExpandMap(ctx context.Context, kv map[string]string) map[st
6576
func (exp *Expander) Expand(ctx context.Context, text string) string {
6677
vm := exp.jsRuntime(ctx)
6778

68-
result, err := vm.RunString("`" + text + "`")
79+
result, err := vm.RunString("`" + escapeForTemplateLiteral(text) + "`")
6980
if err != nil {
7081
return text
7182
}
@@ -80,7 +91,7 @@ func ExpandString(ctx context.Context, str string, values map[string]string) (st
8091
_ = vm.Set(k, v)
8192
}
8293

83-
expanded, err := vm.RunString("`" + str + "`")
94+
expanded, err := vm.RunString("`" + escapeForTemplateLiteral(str) + "`")
8495
if err != nil {
8596
return "", err
8697
}

pkg/js/expand_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
89
)
910

1011
func TestExpand(t *testing.T) {
@@ -58,6 +59,18 @@ func TestExpand(t *testing.T) {
5859
envVars: map[string]string{},
5960
expected: "UNKNOWN",
6061
},
62+
{
63+
name: "backticks in template (markdown code fence)",
64+
commands: "Here is code:\n```\n${env.CODE}\n```\nEnd.",
65+
envVars: map[string]string{"CODE": "fmt.Println()"},
66+
expected: "Here is code:\n```\nfmt.Println()\n```\nEnd.",
67+
},
68+
{
69+
name: "multiple backticks",
70+
commands: "Use `inline` and ```block``` code",
71+
envVars: map[string]string{},
72+
expected: "Use `inline` and ```block``` code",
73+
},
6174
}
6275

6376
for _, tt := range tests {
@@ -128,6 +141,63 @@ func TestExpandMap_Empty(t *testing.T) {
128141
assert.Empty(t, result)
129142
}
130143

144+
func TestExpandString(t *testing.T) {
145+
t.Parallel()
146+
147+
tests := []struct {
148+
name string
149+
template string
150+
values map[string]string
151+
expected string
152+
wantErr bool
153+
}{
154+
{
155+
name: "simple substitution",
156+
template: "Hello ${name}!",
157+
values: map[string]string{"name": "World"},
158+
expected: "Hello World!",
159+
},
160+
{
161+
name: "multiple values",
162+
template: "File: ${path} (chunk ${index})",
163+
values: map[string]string{"path": "/foo/bar.go", "index": "0"},
164+
expected: "File: /foo/bar.go (chunk 0)",
165+
},
166+
{
167+
name: "backticks in template are preserved",
168+
template: "Code:\n```\n${content}\n```",
169+
values: map[string]string{"content": "func main() {}"},
170+
expected: "Code:\n```\nfunc main() {}\n```",
171+
},
172+
{
173+
name: "backticks in value are preserved",
174+
template: "The code is: ${code}",
175+
values: map[string]string{"code": "use `fmt.Println()`"},
176+
expected: "The code is: use `fmt.Println()`",
177+
},
178+
{
179+
name: "semantic prompt with code fence",
180+
template: "Summarize:\n```\n${content}\n```\nBe concise.",
181+
values: map[string]string{"content": "package main\n\nfunc main() {\n\tfmt.Println(`hello`)\n}"},
182+
expected: "Summarize:\n```\npackage main\n\nfunc main() {\n\tfmt.Println(`hello`)\n}\n```\nBe concise.",
183+
},
184+
}
185+
186+
for _, tt := range tests {
187+
t.Run(tt.name, func(t *testing.T) {
188+
t.Parallel()
189+
190+
result, err := ExpandString(t.Context(), tt.template, tt.values)
191+
if tt.wantErr {
192+
require.Error(t, err)
193+
return
194+
}
195+
require.NoError(t, err)
196+
assert.Equal(t, tt.expected, result)
197+
})
198+
}
199+
}
200+
131201
type testEnvProvider map[string]string
132202

133203
func (p *testEnvProvider) Get(_ context.Context, name string) string {

0 commit comments

Comments
 (0)