Skip to content

Commit b508541

Browse files
authored
enhance backend validation for gateway to prevent infinite loops (#97)
* refactor: enhance backend validation to prevent infinite loops * refactor: handle errors when loading public config
1 parent 2c3b580 commit b508541

3 files changed

Lines changed: 60 additions & 10 deletions

File tree

pkg/gateway/gateway.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ func (g *gatewayModule) SetNamedProxy(wlID string, config zos.GatewayNameProxy)
588588
},
589589
}
590590

591-
if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase); err != nil {
591+
if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase, cfg.IPv4.IP, cfg.IPv6.IP); err != nil {
592592
return "", err
593593
}
594594

@@ -622,14 +622,14 @@ func (g *gatewayModule) SetFQDNProxy(wlID string, config zos.GatewayFQDNProxy) e
622622
},
623623
}
624624

625-
return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase)
625+
return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase, cfg.IPv4.IP, cfg.IPv6.IP)
626626
}
627627

628-
func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase) error {
628+
func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase, nodeIPs ...net.IP) error {
629629
g.domainLock.Lock()
630630
defer g.domainLock.Unlock()
631631

632-
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough); err != nil {
632+
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough, nodeIPs...); err != nil {
633633
return err
634634
}
635635

pkg/gateway_light/gateway.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,13 @@ func (g *gatewayModule) SetNamedProxy(wlID string, config zos.GatewayNameProxy)
584584
return "", errors.New("node doesn't support name proxy (doesn't have a domain)")
585585
}
586586

587+
// Get public config for node IP validation
588+
netStub := stubs.NewNetworkerLightStub(g.cl)
589+
pubConfig, err := netStub.LoadPublicConfig(ctx)
590+
if err != nil {
591+
return "", errors.Wrap(err, "failed to load public config")
592+
}
593+
587594
if err := g.validateNameContract(config.Name, twinID); err != nil {
588595
return "", errors.Wrap(err, "failed to verify name contract")
589596
}
@@ -599,7 +606,7 @@ func (g *gatewayModule) SetNamedProxy(wlID string, config zos.GatewayNameProxy)
599606
},
600607
}
601608

602-
if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase); err != nil {
609+
if err := g.setupRouting(ctx, wlID, fqdn, gatewayTLSConfig, config.GatewayBase, pubConfig.IPv4.IP, pubConfig.IPv6.IP); err != nil {
603610
return "", err
604611
}
605612

@@ -618,6 +625,13 @@ func (g *gatewayModule) SetFQDNProxy(wlID string, config zos.GatewayFQDNProxy) e
618625
return err
619626
}
620627

628+
// Get public config for node IP validation
629+
netStub := stubs.NewNetworkerLightStub(g.cl)
630+
pubConfig, err := netStub.LoadPublicConfig(ctx)
631+
if err != nil {
632+
return errors.Wrap(err, "failed to load public config")
633+
}
634+
621635
if domain != "" && strings.HasSuffix(config.FQDN, domain) {
622636
return errors.New("can't create a fqdn workload with a subdomain of the gateway's managed domain")
623637
}
@@ -633,14 +647,14 @@ func (g *gatewayModule) SetFQDNProxy(wlID string, config zos.GatewayFQDNProxy) e
633647
},
634648
}
635649

636-
return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase)
650+
return g.setupRouting(ctx, wlID, config.FQDN, gatewayTLSConfig, config.GatewayBase, pubConfig.IPv4.IP, pubConfig.IPv6.IP)
637651
}
638652

639-
func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase) error {
653+
func (g *gatewayModule) setupRouting(ctx context.Context, wlID string, fqdn string, tlsConfig TlsConfig, config zos.GatewayBase, nodeIPs ...net.IP) error {
640654
g.domainLock.Lock()
641655
defer g.domainLock.Unlock()
642656

643-
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough); err != nil {
657+
if err := zos.ValidateBackends(config.Backends, config.TLSPassthrough, nodeIPs...); err != nil {
644658
return err
645659
}
646660

pkg/gridtypes/zos/gw.go

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"math"
77
"net"
88
"net/url"
9+
"slices"
910
"strconv"
1011

1112
"github.com/hashicorp/go-multierror"
@@ -45,14 +46,49 @@ func (b Backend) Valid(tlsPassthrough bool) error {
4546
return nil
4647
}
4748

48-
func ValidateBackends(backends []Backend, tlsPassthrough bool) error {
49+
func ValidateBackends(backends []Backend, tlsPassthrough bool, nodeIPs ...net.IP) error {
4950
var errs error
5051
for _, backend := range backends {
5152
if err := backend.Valid(tlsPassthrough); err != nil {
5253
errs = multierror.Append(errs, errors.Wrapf(err, "failed to validate backend '%s'", backend))
5354
}
5455
}
55-
return errs
56+
if errs != nil {
57+
return errs
58+
}
59+
60+
// Check that backends don't point to the node's own public IPs (prevents infinite loops)
61+
for _, backend := range backends {
62+
backendIP, err := backend.ExtractIP()
63+
if err != nil {
64+
return errors.Wrapf(err, "failed to extract IP from backend '%s'", backend)
65+
}
66+
if slices.ContainsFunc(nodeIPs, backendIP.Equal) {
67+
return fmt.Errorf("backend %s points to the node's own public IP address", backend)
68+
}
69+
}
70+
return nil
71+
}
72+
73+
// ExtractIP extracts the IP address from a backend string.
74+
func (b Backend) ExtractIP() (net.IP, error) {
75+
// Try ip:port format first
76+
if ip, _, err := asIpPort(string(b)); err == nil {
77+
return ip, nil
78+
}
79+
80+
// Try URL format
81+
u, err := url.Parse(string(b))
82+
if err != nil {
83+
return nil, fmt.Errorf("failed to parse backend: %w", err)
84+
}
85+
86+
ip := net.ParseIP(u.Hostname())
87+
if ip == nil {
88+
return nil, fmt.Errorf("invalid ip address in backend: %s", u.Hostname())
89+
}
90+
91+
return ip, nil
5692
}
5793

5894
func asIpPort(a string) (ip net.IP, port uint16, err error) {

0 commit comments

Comments
 (0)