Skip to content

Commit 81be506

Browse files
committed
more unit tests
1 parent ef9018e commit 81be506

4 files changed

Lines changed: 305 additions & 6 deletions

File tree

internal/helper/google.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,19 @@ func FetchGoogleCrawlerIPs(log *slog.Logger, httpClient *http.Client, urls []str
137137
return ReduceCIDRs(allCIDRs, log), nil
138138
}
139139

140+
// RefreshGoogleCrawlerIPs fetches crawler IPs from all configured URLs and updates
141+
// the provided GooglebotIPs set. Returns the number of CIDRs loaded.
142+
func RefreshGoogleCrawlerIPs(log *slog.Logger, httpClient *http.Client, target *GooglebotIPs, urls []string) (int, error) {
143+
cidrs, err := FetchGoogleCrawlerIPs(log, httpClient, urls)
144+
if err != nil {
145+
return 0, err
146+
}
147+
148+
target.Update(cidrs, log)
149+
150+
return len(cidrs), nil
151+
}
152+
140153
// ReduceCIDRs canonicalizes CIDRs, removes exact duplicates, and removes narrower
141154
// ranges when they are fully covered by broader ranges.
142155
func ReduceCIDRs(cidrs []string, log *slog.Logger) []string {

internal/helper/google_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,71 @@ func TestFetchGoogleCrawlerIPs(t *testing.T) {
124124
t.Fatalf("unexpected CIDRs: got %v want %v", got, want)
125125
}
126126
}
127+
128+
func TestFetchGoogleCrawlerIPsError(t *testing.T) {
129+
log := slog.New(slog.NewTextHandler(os.Stdout, nil))
130+
131+
okServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
132+
w.Header().Set("Content-Type", "application/json")
133+
_, _ = w.Write([]byte(`{"prefixes":[{"ipv4Prefix":"8.8.8.0/24"}]}`))
134+
}))
135+
defer okServer.Close()
136+
137+
errServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
138+
http.Error(w, "boom", http.StatusInternalServerError)
139+
}))
140+
defer errServer.Close()
141+
142+
_, err := FetchGoogleCrawlerIPs(log, okServer.Client(), []string{okServer.URL, errServer.URL})
143+
if err == nil {
144+
t.Fatal("expected error when one endpoint returns non-200")
145+
}
146+
}
147+
148+
func TestRefreshGoogleCrawlerIPs(t *testing.T) {
149+
log := slog.New(slog.NewTextHandler(os.Stdout, nil))
150+
g := NewGooglebotIPs()
151+
152+
serverA := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
153+
w.Header().Set("Content-Type", "application/json")
154+
_, _ = w.Write([]byte(`{"prefixes":[{"ipv4Prefix":"203.0.113.0/24"}]}`))
155+
}))
156+
defer serverA.Close()
157+
158+
serverB := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
159+
w.Header().Set("Content-Type", "application/json")
160+
_, _ = w.Write([]byte(`{"prefixes":[{"ipv4Prefix":"203.0.113.0/25"},{"ipv6Prefix":"2001:db8::/32"}]}`))
161+
}))
162+
defer serverB.Close()
163+
164+
count, err := RefreshGoogleCrawlerIPs(log, serverA.Client(), g, []string{serverA.URL, serverB.URL})
165+
if err != nil {
166+
t.Fatalf("RefreshGoogleCrawlerIPs failed: %v", err)
167+
}
168+
169+
if count != 2 {
170+
t.Fatalf("expected reduced count 2, got %d", count)
171+
}
172+
173+
if !g.Contains(net.ParseIP("203.0.113.9")) {
174+
t.Fatal("expected refreshed set to contain 203.0.113.9")
175+
}
176+
if !g.Contains(net.ParseIP("2001:db8::1")) {
177+
t.Fatal("expected refreshed set to contain 2001:db8::1")
178+
}
179+
}
180+
181+
func TestRefreshGoogleCrawlerIPsError(t *testing.T) {
182+
log := slog.New(slog.NewTextHandler(os.Stdout, nil))
183+
g := NewGooglebotIPs()
184+
185+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
186+
http.Error(w, "boom", http.StatusInternalServerError)
187+
}))
188+
defer server.Close()
189+
190+
_, err := RefreshGoogleCrawlerIPs(log, server.Client(), g, []string{server.URL})
191+
if err == nil {
192+
t.Fatal("expected refresh error")
193+
}
194+
}

main.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,24 +351,22 @@ func (bc *CaptchaProtect) googlebotIPCheckLoop(ctx context.Context) {
351351
defer ticker.Stop()
352352

353353
// Initial fetch
354-
cidrs, err := helper.FetchGoogleCrawlerIPs(bc.log, bc.httpClient, helper.GoogleCrawlerIPRangeURLs)
354+
count, err := helper.RefreshGoogleCrawlerIPs(bc.log, bc.httpClient, bc.googlebotIPs, helper.GoogleCrawlerIPRangeURLs)
355355
if err != nil {
356356
bc.log.Error("failed to fetch googlebot ips", "err", err)
357357
} else {
358-
bc.googlebotIPs.Update(cidrs, bc.log)
359-
bc.log.Info("Updated Googlebot IPs", "count", len(cidrs))
358+
bc.log.Info("Updated Googlebot IPs", "count", count)
360359
}
361360

362361
for {
363362
select {
364363
case <-ticker.C:
365-
cidrs, err := helper.FetchGoogleCrawlerIPs(bc.log, bc.httpClient, helper.GoogleCrawlerIPRangeURLs)
364+
count, err := helper.RefreshGoogleCrawlerIPs(bc.log, bc.httpClient, bc.googlebotIPs, helper.GoogleCrawlerIPRangeURLs)
366365
if err != nil {
367366
bc.log.Error("failed to fetch googlebot ips", "err", err)
368367
continue
369368
}
370-
bc.googlebotIPs.Update(cidrs, bc.log)
371-
bc.log.Info("Updated Googlebot IPs", "count", len(cidrs))
369+
bc.log.Info("Updated Googlebot IPs", "count", count)
372370
case <-ctx.Done():
373371
return
374372
}

main_test.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"strings"
1515
"testing"
1616
"time"
17+
18+
"github.com/libops/captcha-protect/internal/helper"
1719
)
1820

1921
func TestParseIp(t *testing.T) {
@@ -1629,3 +1631,221 @@ func TestPojChallengeGeneration(t *testing.T) {
16291631
t.Errorf("Expected PoJ JS URL in challenge page")
16301632
}
16311633
}
1634+
1635+
func TestPerformHealthCheckSuccessResetsFailures(t *testing.T) {
1636+
config := CreateConfig()
1637+
config.SiteKey = "test"
1638+
config.SecretKey = "test"
1639+
config.ProtectRoutes = []string{"/"}
1640+
config.CaptchaProvider = "turnstile"
1641+
config.PeriodSeconds = 3600
1642+
config.FailureThreshold = 2
1643+
1644+
ctx, cancel := context.WithCancel(context.Background())
1645+
defer cancel()
1646+
1647+
bc, err := NewCaptchaProtect(ctx, nil, config, "test")
1648+
if err != nil {
1649+
t.Fatalf("Failed to create CaptchaProtect: %v", err)
1650+
}
1651+
1652+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1653+
w.WriteHeader(http.StatusOK)
1654+
}))
1655+
defer server.Close()
1656+
1657+
bc.captchaConfig.js = server.URL
1658+
bc.recordHealthCheckFailure()
1659+
1660+
bc.performHealthCheck()
1661+
1662+
bc.mu.RLock()
1663+
defer bc.mu.RUnlock()
1664+
if bc.healthCheckFailureCount != 0 {
1665+
t.Fatalf("expected failure count reset to 0, got %d", bc.healthCheckFailureCount)
1666+
}
1667+
if bc.circuitState != circuitClosed {
1668+
t.Fatalf("expected circuit to be closed, got %v", bc.circuitState)
1669+
}
1670+
}
1671+
1672+
func TestPerformHealthCheckFailurePaths(t *testing.T) {
1673+
tests := []struct {
1674+
name string
1675+
jsURL string
1676+
status int
1677+
expectErr bool
1678+
}{
1679+
{
1680+
name: "404 considered failure",
1681+
status: http.StatusNotFound,
1682+
},
1683+
{
1684+
name: "503 considered failure",
1685+
status: http.StatusServiceUnavailable,
1686+
},
1687+
{
1688+
name: "invalid URL request creation failure",
1689+
jsURL: "://invalid-url",
1690+
expectErr: true,
1691+
},
1692+
}
1693+
1694+
for _, tt := range tests {
1695+
t.Run(tt.name, func(t *testing.T) {
1696+
config := CreateConfig()
1697+
config.SiteKey = "test"
1698+
config.SecretKey = "test"
1699+
config.ProtectRoutes = []string{"/"}
1700+
config.CaptchaProvider = "turnstile"
1701+
config.PeriodSeconds = 3600
1702+
config.FailureThreshold = 1
1703+
1704+
ctx, cancel := context.WithCancel(context.Background())
1705+
defer cancel()
1706+
1707+
bc, err := NewCaptchaProtect(ctx, nil, config, "test")
1708+
if err != nil {
1709+
t.Fatalf("Failed to create CaptchaProtect: %v", err)
1710+
}
1711+
1712+
if tt.expectErr {
1713+
bc.captchaConfig.js = tt.jsURL
1714+
} else {
1715+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1716+
w.WriteHeader(tt.status)
1717+
}))
1718+
defer server.Close()
1719+
bc.captchaConfig.js = server.URL
1720+
}
1721+
1722+
bc.performHealthCheck()
1723+
1724+
bc.mu.RLock()
1725+
defer bc.mu.RUnlock()
1726+
if bc.healthCheckFailureCount != 1 {
1727+
t.Fatalf("expected failure count 1, got %d", bc.healthCheckFailureCount)
1728+
}
1729+
if bc.circuitState != circuitOpen {
1730+
t.Fatalf("expected circuit to open, got %v", bc.circuitState)
1731+
}
1732+
})
1733+
}
1734+
}
1735+
1736+
func TestVerifyChallengePagePojFallbackUsesOneHourTTL(t *testing.T) {
1737+
config := CreateConfig()
1738+
config.SiteKey = "test-key"
1739+
config.SecretKey = "test-secret"
1740+
config.ProtectRoutes = []string{"/"}
1741+
config.CaptchaProvider = "turnstile"
1742+
config.PeriodSeconds = 3600
1743+
config.FailureThreshold = 1
1744+
1745+
ctx, cancel := context.WithCancel(context.Background())
1746+
defer cancel()
1747+
1748+
bc, err := NewCaptchaProtect(ctx, nil, config, "test")
1749+
if err != nil {
1750+
t.Fatalf("Failed to create CaptchaProtect: %v", err)
1751+
}
1752+
1753+
// Open the circuit so PoJ becomes active fallback provider.
1754+
bc.recordHealthCheckFailure()
1755+
1756+
form := url.Values{}
1757+
form.Add("poj-captcha-response", "ok")
1758+
form.Add("destination", "%2F")
1759+
req := httptest.NewRequest(http.MethodPost, "/challenge", strings.NewReader(form.Encode()))
1760+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
1761+
rr := httptest.NewRecorder()
1762+
clientIP := "203.0.113.10"
1763+
1764+
status := bc.verifyChallengePage(rr, req, clientIP)
1765+
if status != http.StatusFound {
1766+
t.Fatalf("expected status %d, got %d", http.StatusFound, status)
1767+
}
1768+
1769+
item, found := bc.verifiedCache.Items()[clientIP]
1770+
if !found {
1771+
t.Fatalf("expected %s to be in verified cache", clientIP)
1772+
}
1773+
1774+
remaining := time.Until(time.Unix(0, item.Expiration))
1775+
if remaining < 50*time.Minute || remaining > 70*time.Minute {
1776+
t.Fatalf("expected PoJ fallback TTL around 1h, got %s", remaining)
1777+
}
1778+
}
1779+
1780+
func TestGooglebotIPCheckLoopInitialFetchSuccess(t *testing.T) {
1781+
originalURLs := helper.GoogleCrawlerIPRangeURLs
1782+
defer func() { helper.GoogleCrawlerIPRangeURLs = originalURLs }()
1783+
1784+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1785+
w.Header().Set("Content-Type", "application/json")
1786+
_, _ = w.Write([]byte(`{"prefixes":[{"ipv4Prefix":"203.0.113.0/24"}]}`))
1787+
}))
1788+
defer server.Close()
1789+
1790+
helper.GoogleCrawlerIPRangeURLs = []string{server.URL}
1791+
1792+
bc := &CaptchaProtect{
1793+
log: slog.New(slog.NewTextHandler(os.Stdout, nil)),
1794+
httpClient: server.Client(),
1795+
googlebotIPs: helper.NewGooglebotIPs(),
1796+
}
1797+
1798+
ctx, cancel := context.WithCancel(context.Background())
1799+
done := make(chan struct{})
1800+
go func() {
1801+
bc.googlebotIPCheckLoop(ctx)
1802+
close(done)
1803+
}()
1804+
1805+
deadline := time.Now().Add(2 * time.Second)
1806+
for time.Now().Before(deadline) {
1807+
if bc.googlebotIPs.Contains(net.ParseIP("203.0.113.10")) {
1808+
cancel()
1809+
<-done
1810+
return
1811+
}
1812+
time.Sleep(20 * time.Millisecond)
1813+
}
1814+
1815+
cancel()
1816+
<-done
1817+
t.Fatal("expected googlebot IPs to be updated from initial crawler fetch")
1818+
}
1819+
1820+
func TestGooglebotIPCheckLoopInitialFetchError(t *testing.T) {
1821+
originalURLs := helper.GoogleCrawlerIPRangeURLs
1822+
defer func() { helper.GoogleCrawlerIPRangeURLs = originalURLs }()
1823+
1824+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1825+
http.Error(w, "boom", http.StatusInternalServerError)
1826+
}))
1827+
defer server.Close()
1828+
1829+
helper.GoogleCrawlerIPRangeURLs = []string{server.URL}
1830+
1831+
bc := &CaptchaProtect{
1832+
log: slog.New(slog.NewTextHandler(os.Stdout, nil)),
1833+
httpClient: server.Client(),
1834+
googlebotIPs: helper.NewGooglebotIPs(),
1835+
}
1836+
1837+
ctx, cancel := context.WithCancel(context.Background())
1838+
done := make(chan struct{})
1839+
go func() {
1840+
bc.googlebotIPCheckLoop(ctx)
1841+
close(done)
1842+
}()
1843+
1844+
time.Sleep(100 * time.Millisecond)
1845+
cancel()
1846+
<-done
1847+
1848+
if bc.googlebotIPs.Contains(net.ParseIP("203.0.113.10")) {
1849+
t.Fatal("did not expect googlebot IPs to update when initial fetch fails")
1850+
}
1851+
}

0 commit comments

Comments
 (0)