Skip to content

Commit b3f11ef

Browse files
committed
options: validate filter strings
just on the offchance that someone sticks garbage in the config.
1 parent d0f37dd commit b3f11ef

2 files changed

Lines changed: 186 additions & 6 deletions

File tree

cmd/multibuild/options.go

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,84 @@ func validateTemplate(s string) (outputTemplate, error) {
175175
return outputTemplate(s), nil
176176
}
177177

178+
func validateFilterString(s string) ([]filter, error) {
179+
isAlphaNum := func(b byte) bool {
180+
return (b >= 'a' && b <= 'z') ||
181+
(b >= 'A' && b <= 'Z') ||
182+
(b >= '0' && b <= '9')
183+
}
184+
185+
var out []filter
186+
187+
i := 0
188+
for i < len(s) {
189+
start := i
190+
191+
// parse GOOS
192+
osStart := i
193+
if i < len(s) {
194+
if s[i] == '*' {
195+
i++
196+
} else {
197+
for i < len(s) && isAlphaNum(s[i]) {
198+
i++
199+
}
200+
}
201+
}
202+
if osStart == i {
203+
return nil, fmt.Errorf("at %d: expected GOOS", i)
204+
}
205+
if i >= len(s) || s[i] != '/' {
206+
if i < len(s) {
207+
return nil, fmt.Errorf("at %d: unexpected character: %c", i, s[i])
208+
}
209+
return nil, fmt.Errorf("at %d: expected '/'", i)
210+
}
211+
goos := s[osStart:i]
212+
i++ // skip '/'
213+
214+
// parse GOARCH
215+
archStart := i
216+
if i < len(s) {
217+
if s[i] == '*' {
218+
i++
219+
} else {
220+
for i < len(s) && isAlphaNum(s[i]) {
221+
i++
222+
}
223+
}
224+
}
225+
if archStart == i {
226+
return nil, fmt.Errorf("at %d: expected GOARCH", i)
227+
}
228+
goarch := s[archStart:i]
229+
230+
out = append(out, filter(fmt.Sprintf("%s/%s", goos, goarch)))
231+
232+
// end or comma
233+
if i == len(s) {
234+
break
235+
}
236+
237+
if s[i] != ',' {
238+
return nil, fmt.Errorf("at %d: unexpected character: %c", i, s[i])
239+
}
240+
i++ // skip ','
241+
242+
if i == len(s) {
243+
return nil, fmt.Errorf("at %d: trailing comma", i-1)
244+
}
245+
246+
_ = start
247+
}
248+
249+
if len(out) == 0 {
250+
return nil, fmt.Errorf("empty filter list")
251+
}
252+
253+
return out, nil
254+
}
255+
178256
// Reads from 'io' on behalf of a path, and returns parsed options.
179257
func scanBuildPath(reader io.Reader, path string) (options, error) {
180258
var opts options
@@ -204,18 +282,22 @@ func scanBuildPath(reader io.Reader, path string) (options, error) {
204282
if dlog {
205283
log.Printf("Found include: %s:%d: %s", path, i, line)
206284
}
207-
rest := strings.Split(strings.TrimPrefix(line, "//go:multibuild:include="), ",")
208-
for _, v := range rest {
209-
opts.Include = append(opts.Include, filter(v))
285+
rest := strings.TrimPrefix(line, "//go:multibuild:include=")
286+
filters, err := validateFilterString(rest)
287+
if err != nil {
288+
return options{}, fmt.Errorf("%s:%d: go:multibuild:include=%s is invalid: %s", path, i, rest, err)
210289
}
290+
opts.Include = filters
211291
} else if strings.HasPrefix(line, "//go:multibuild:exclude=") {
212292
if dlog {
213293
log.Printf("Found exclude: %s:%d: %s", path, i, line)
214294
}
215-
rest := strings.Split(strings.TrimPrefix(line, "//go:multibuild:exclude="), ",")
216-
for _, v := range rest {
217-
opts.Exclude = append(opts.Exclude, filter(v))
295+
rest := strings.TrimPrefix(line, "//go:multibuild:exclude=")
296+
filters, err := validateFilterString(rest)
297+
if err != nil {
298+
return options{}, fmt.Errorf("%s:%d: go:multibuild:exclude=%s is invalid: %s", path, i, rest, err)
218299
}
300+
opts.Exclude = filters
219301
} else {
220302
return options{}, fmt.Errorf("%s:%d: bad go:multibuild instruction: %q", path, i, line)
221303
}

cmd/multibuild/options_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,12 @@ func TestScanBuildPath(t *testing.T) {
175175
want: options{},
176176
wantError: true,
177177
},
178+
{
179+
name: "invalid filter",
180+
input: `//go:multibuild:include=linux/amd*`,
181+
want: options{},
182+
wantError: true,
183+
},
178184
}
179185

180186
equalOptions := func(a, b options) bool {
@@ -423,3 +429,95 @@ func TestValidateTemplate(t *testing.T) {
423429
})
424430
}
425431
}
432+
433+
func TestValidateFilters_Valid(t *testing.T) {
434+
tests := []struct {
435+
name string
436+
in string
437+
want []filter
438+
}{
439+
{
440+
name: "single entry",
441+
in: "linux/amd64",
442+
want: []filter{filter("linux/amd64")},
443+
},
444+
{
445+
name: "multiple entries",
446+
in: "linux/amd64,darwin/arm64",
447+
want: []filter{
448+
filter("linux/amd64"),
449+
filter("darwin/arm64"),
450+
},
451+
},
452+
{
453+
name: "wildcard os",
454+
in: "*/amd64",
455+
want: []filter{filter("*/amd64")},
456+
},
457+
{
458+
name: "wildcard arch",
459+
in: "linux/*",
460+
want: []filter{filter("linux/*")},
461+
},
462+
{
463+
name: "both wildcards",
464+
in: "*/*",
465+
want: []filter{filter("*/*")},
466+
},
467+
{
468+
name: "mixed wildcards",
469+
in: "linux/amd64,*/arm64",
470+
want: []filter{
471+
filter("linux/amd64"),
472+
filter("*/arm64"),
473+
},
474+
},
475+
}
476+
477+
for _, tt := range tests {
478+
t.Run(tt.name, func(t *testing.T) {
479+
got, err := validateFilterString(tt.in)
480+
if err != nil {
481+
t.Fatalf("unexpected error: %v", err)
482+
}
483+
if len(got) != len(tt.want) {
484+
t.Fatalf("len mismatch: got %d want %d", len(got), len(tt.want))
485+
}
486+
for i := range got {
487+
if got[i] != tt.want[i] {
488+
t.Fatalf("entry %d: got %+v want %+v", i, got[i], tt.want[i])
489+
}
490+
}
491+
})
492+
}
493+
}
494+
495+
func TestValidateFilters_Invalid(t *testing.T) {
496+
tests := []struct {
497+
name string
498+
in string
499+
}{
500+
{"empty", ""},
501+
{"missing slash", "linuxamd64"},
502+
{"missing os", "/amd64"},
503+
{"missing arch", "linux/"},
504+
{"double slash", "linux//amd64"},
505+
{"leading comma", ",linux/amd64"},
506+
{"trailing comma", "linux/amd64,"},
507+
{"double comma", "linux/amd64,,darwin/arm64"},
508+
{"unexpected char", "linux/amd$64"},
509+
{"wildcard partial os", "*nix/amd64"},
510+
{"wildcard partial arch", "linux/amd*"},
511+
{"wildcard mixed os", "l*/amd64"},
512+
{"wildcard mixed arch", "linux/*64"},
513+
}
514+
515+
for _, tt := range tests {
516+
t.Run(tt.name, func(t *testing.T) {
517+
_, err := validateFilterString(tt.in)
518+
if err == nil {
519+
t.Fatalf("expected error for %q", tt.in)
520+
}
521+
})
522+
}
523+
}

0 commit comments

Comments
 (0)