Skip to content

Commit b0436e6

Browse files
feat: Implement secure dialer to prevent SSRF (#92)
This commit introduces a custom `net.Dialer` for the application's `httpClient` to mitigate Server-Side Request Forgery (SSRF) vulnerabilities. The key changes are: - A new `newSafeDialer` function creates a dialer with a `Control` function that inspects the IP address at connection time. - The dialer blocks connections to private, loopback, and link-local IP addresses. - The global `httpClient` is now configured to use this secure dialer. - The previous, less secure IP validation in `normalizeAndValidateURL` has been removed to eliminate Time-of-Check-to-Time-of-Use (TOCTOU) vulnerabilities. - A new test, `TestSSRFProtection`, has been added to verify that the dialer correctly blocks connections to local addresses. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
1 parent 11e9d33 commit b0436e6

2 files changed

Lines changed: 55 additions & 12 deletions

File tree

api/index.go

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"net/http"
1313
"net/url"
1414
"strings"
15+
"syscall"
1516
"time"
1617

1718
"codeberg.org/readeck/go-readability/v2"
@@ -41,6 +42,9 @@ var (
4142
ReadabilityParser = readability.NewParser()
4243
// httpClient used for fetching remote articles with timeouts and redirect policy
4344
httpClient = &http.Client{
45+
Transport: &http.Transport{
46+
DialContext: newSafeDialer().DialContext,
47+
},
4448
Timeout: 10 * time.Second,
4549
CheckRedirect: func(req *http.Request, via []*http.Request) error {
4650
if len(via) >= 5 {
@@ -53,6 +57,30 @@ var (
5357
maxContentBytes = int64(2 * 1024 * 1024)
5458
)
5559

60+
func newSafeDialer() *net.Dialer {
61+
dialer := &net.Dialer{
62+
Timeout: 30 * time.Second,
63+
KeepAlive: 30 * time.Second,
64+
Control: func(network, address string, c syscall.RawConn) error {
65+
host, _, err := net.SplitHostPort(address)
66+
if err != nil {
67+
return err
68+
}
69+
ips, err := net.LookupIP(host)
70+
if err != nil {
71+
return err
72+
}
73+
for _, ip := range ips {
74+
if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
75+
return errors.New("refusing to connect to private network address")
76+
}
77+
}
78+
return nil
79+
},
80+
}
81+
return dialer
82+
}
83+
5684
const defaultUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36 Edg/134.0.0.0"
5785

5886
func fetchAndParse(ctx context.Context, link *url.URL, userAgent string) (readability.Article, error) {
@@ -99,16 +127,6 @@ func normalizeAndValidateURL(rawLink string) (*url.URL, error) {
99127
if link.Scheme != "http" && link.Scheme != "https" {
100128
return nil, errors.New("unsupported URL scheme")
101129
}
102-
host := link.Hostname()
103-
// resolve and block private IPs
104-
ips, err := net.LookupIP(host)
105-
if err == nil {
106-
for _, ip := range ips {
107-
if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
108-
return nil, errors.New("refusing private network address")
109-
}
110-
}
111-
}
112130
return link, nil
113131
}
114132

api/index_test.go

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ func TestNormalizeAndValidateURL(t *testing.T) {
1919
{"example.com", "https://example.com", false},
2020
{"http://foo.bar", "http://foo.bar", false},
2121
{"ftp://foo.bar", "", true},
22-
{"127.0.0.1", "", true},
23-
{"192.168.0.5/path", "", true},
2422
}
2523
for _, tt := range tests {
2624
u, err := normalizeAndValidateURL(tt.raw)
@@ -77,3 +75,30 @@ func TestFetchAndParse(t *testing.T) {
7775
t.Errorf("Article.Content missing expected paragraph, got: %q", content.String())
7876
}
7977
}
78+
79+
func TestSSRFProtection(t *testing.T) {
80+
// a dummy server that should never be reached
81+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82+
t.Fatal("dialer did not block private IP, connection was made")
83+
}))
84+
defer srv.Close()
85+
86+
// get loopback address of the server
87+
// srv.URL will be something like http://127.0.0.1:54321
88+
// we want to test if the dialer blocks the connection to 127.0.0.1
89+
// so, we don't use the server's client, we use our own httpClient
90+
req, err := http.NewRequest("GET", srv.URL, nil)
91+
if err != nil {
92+
t.Fatalf("failed to create request: %v", err)
93+
}
94+
95+
_, err = httpClient.Do(req)
96+
if err == nil {
97+
t.Fatal("expected an error when dialing a private IP, but got none")
98+
}
99+
// check if the error is the one we expect from our dialer
100+
// the error is wrapped, so we need to check for the substring
101+
if !strings.Contains(err.Error(), "refusing to connect to private network address") {
102+
t.Errorf("expected error to contain 'refusing to connect to private network address', but got: %v", err)
103+
}
104+
}

0 commit comments

Comments
 (0)