Skip to content

Commit 669e488

Browse files
committed
feat: add backend patch validation function with tests
Signed-off-by: nabil salah <nabil.salah203@gmail.com>
1 parent 5e480f2 commit 669e488

4 files changed

Lines changed: 125 additions & 14 deletions

File tree

pkg/gateway/gateway.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"sync"
1313
"time"
1414

15-
"github.com/hashicorp/go-multierror"
1615
"github.com/pkg/errors"
1716
"github.com/rs/zerolog/log"
1817
"github.com/threefoldtech/zbus"
@@ -630,15 +629,8 @@ func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn stri
630629
g.domainLock.Lock()
631630
defer g.domainLock.Unlock()
632631

633-
var errs error
634-
for _, backend := range config.Backends {
635-
if err := backend.Valid(config.TLSPassthrough); err != nil {
636-
errs = multierror.Append(errs, errors.Wrapf(err, "failed to validate backend '%s'", backend))
637-
}
638-
}
639-
640-
if errs != nil {
641-
return errs
632+
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough); err != nil {
633+
return err
642634
}
643635

644636
if _, ok := g.getReservedDomain(fqdn); ok {

pkg/gateway_light/gateway.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -640,10 +640,8 @@ func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn stri
640640
g.domainLock.Lock()
641641
defer g.domainLock.Unlock()
642642

643-
for _, backend := range config.Backends {
644-
if err := backend.Valid(config.TLSPassthrough); err != nil {
645-
return errors.Wrapf(err, "failed to validate backend '%s'", backend)
646-
}
643+
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough); err != nil {
644+
return err
647645
}
648646

649647
if _, ok := g.getReservedDomain(fqdn); ok {

pkg/gridtypes/zos/gw.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/url"
99
"strconv"
1010

11+
"github.com/hashicorp/go-multierror"
1112
"github.com/pkg/errors"
1213
"github.com/threefoldtech/zosbase/pkg/gridtypes"
1314
)
@@ -44,6 +45,16 @@ func (b Backend) Valid(tlsPassthrough bool) error {
4445
return nil
4546
}
4647

48+
func ValidateBackends(backends []Backend, tlsPassthrough bool) error {
49+
var errs error
50+
for _, backend := range backends {
51+
if err := backend.Valid(tlsPassthrough); err != nil {
52+
errs = multierror.Append(errs, errors.Wrapf(err, "failed to validate backend '%s'", backend))
53+
}
54+
}
55+
return errs
56+
}
57+
4758
func asIpPort(a string) (ip net.IP, port uint16, err error) {
4859
h, p, err := net.SplitHostPort(a)
4960
if err != nil {

pkg/gridtypes/zos/gw_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package zos
33
import (
44
"testing"
55

6+
"github.com/hashicorp/go-multierror"
67
"github.com/stretchr/testify/require"
78
)
89

@@ -159,3 +160,112 @@ func TestValidBackendIP6(t *testing.T) {
159160
require.Error(err)
160161
})
161162
}
163+
164+
func TestValidateBackends(t *testing.T) {
165+
require := require.New(t)
166+
167+
t.Run("empty backends", func(t *testing.T) {
168+
backends := []Backend{
169+
"",
170+
}
171+
err := ValidateBackends(backends, true)
172+
require.Error(err)
173+
174+
err = ValidateBackends(backends, false)
175+
require.Error(err)
176+
})
177+
178+
t.Run("all valid backends with tlsPassthrough=true", func(t *testing.T) {
179+
backends := []Backend{
180+
"1.1.1.1:80",
181+
"2.2.2.2:443",
182+
"[2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]:8080",
183+
}
184+
err := ValidateBackends(backends, true)
185+
require.NoError(err)
186+
})
187+
188+
t.Run("all valid backends with tlsPassthrough=false", func(t *testing.T) {
189+
backends := []Backend{
190+
"http://1.1.1.1",
191+
"http://2.2.2.2:443",
192+
"http://[2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF]",
193+
}
194+
err := ValidateBackends(backends, false)
195+
require.NoError(err)
196+
})
197+
198+
t.Run("mixed valid and invalid backends with tlsPassthrough=true", func(t *testing.T) {
199+
backends := []Backend{
200+
"1.1.1.1:80",
201+
"http://2.2.2.2:443", // invalid (should be IP:port without http://)
202+
"2.2.2.2", // invalid (missing port)
203+
"3.3.3.3:port", // invalid (non-numeric port)
204+
"127.0.0.1:8080",
205+
"[::1]:8080",
206+
"[2001:db8::1]:8080",
207+
"2001:db8::1:8080", // invalid (wrong IPv6 format)
208+
}
209+
err := ValidateBackends(backends, true)
210+
require.Error(err)
211+
merr, ok := err.(*multierror.Error)
212+
require.True(ok)
213+
require.Equal(4, len(merr.Errors))
214+
})
215+
216+
t.Run("mixed valid and invalid backends with tlsPassthrough=false", func(t *testing.T) {
217+
backends := []Backend{
218+
"http://1.1.1.1",
219+
"1.1.1.1:80", // invalid (needs http://)
220+
"http://2.2.2.2:443",
221+
"https://3.3.3.3", // invalid (wrong scheme)
222+
"http://localhost", // invalid (loopback)
223+
"http://127.0.0.1", // invalid (loopback)
224+
"http://[::1]", // invalid (loopback)
225+
"http://[2001:db8::1]:8080",
226+
}
227+
err := ValidateBackends(backends, false)
228+
require.Error(err)
229+
// Check that we have the expected number of errors
230+
merr, ok := err.(*multierror.Error)
231+
require.True(ok)
232+
require.Equal(5, len(merr.Errors))
233+
})
234+
235+
t.Run("scheme mismatch using https when not permitted", func(t *testing.T) {
236+
backends := []Backend{
237+
"https://1.1.1.1",
238+
}
239+
err := ValidateBackends(backends, false)
240+
require.Error(err)
241+
})
242+
243+
t.Run("scheme mismatch using http when tlsPassthrough=true", func(t *testing.T) {
244+
backends := []Backend{
245+
"http://1.1.1.1:80",
246+
}
247+
err := ValidateBackends(backends, true)
248+
require.Error(err)
249+
})
250+
251+
t.Run("all invalid backends", func(t *testing.T) {
252+
backends := []Backend{
253+
"invalid",
254+
"1.1.1.1:port",
255+
"http://invalid",
256+
"ftp://1.1.1.1",
257+
}
258+
259+
err := ValidateBackends(backends, true)
260+
require.Error(err)
261+
merr, ok := err.(*multierror.Error)
262+
require.True(ok)
263+
require.Equal(4, len(merr.Errors))
264+
265+
err = ValidateBackends(backends, false)
266+
require.Error(err)
267+
merr, ok = err.(*multierror.Error)
268+
require.True(ok)
269+
require.Equal(4, len(merr.Errors))
270+
})
271+
}

0 commit comments

Comments
 (0)