Skip to content

Commit 3e97fc1

Browse files
authored
Merge pull request #6684 from snyk/fix/CLI-1402-use-gaf-networking-for-license-downloads
fix: Use GAF network stack for license downloads
2 parents d748341 + 835194f commit 3e97fc1

2 files changed

Lines changed: 91 additions & 44 deletions

File tree

cliv2/scripts/prepare_licenses.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,36 @@ import (
99
"path/filepath"
1010
"strings"
1111
"time"
12+
13+
"github.com/snyk/go-application-framework/pkg/configuration"
14+
"github.com/snyk/go-application-framework/pkg/networking"
15+
"github.com/snyk/go-application-framework/pkg/networking/middleware"
1216
)
1317

1418
// licensesEmbeddedDir is the cliv2-relative tree where go-licenses and manual downloads write.
1519
var licensesEmbeddedDir = filepath.Join(".", "internal", "embedded", "_data", "licenses")
1620

21+
const (
22+
maxDownloadAttempts = 5
23+
perAttemptTimeout = 30 * time.Second
24+
)
25+
1726
func log(msg string) {
1827
fmt.Fprintln(os.Stderr, msg)
1928
}
2029

30+
func newHTTPClient() *http.Client {
31+
cfg := configuration.NewWithOpts()
32+
cfg.Set(middleware.ConfigurationKeyRequestAttempts, maxDownloadAttempts)
33+
34+
na := networking.NewNetworkAccess(cfg)
35+
na.AddHeaderField("User-Agent", "Snyk-CLI-Build/1.0")
36+
37+
client := na.GetUnauthorizedHttpClient()
38+
client.Timeout = time.Duration(maxDownloadAttempts) * perAttemptTimeout
39+
return client
40+
}
41+
2142
func main() {
2243
log("Preparing 3rd party licenses...")
2344

@@ -50,6 +71,8 @@ func main() {
5071
os.Exit(1)
5172
}
5273

74+
client := newHTTPClient()
75+
5376
log("Downloading manual licenses...")
5477
manualLicenses := []struct{ url, pkg string }{
5578
{"https://raw.githubusercontent.com/davecgh/go-spew/master/LICENSE", "github.com/davecgh/go-spew"},
@@ -58,7 +81,7 @@ func main() {
5881
{"https://go.dev/LICENSE?m=text", "go.dev"},
5982
}
6083
for _, lic := range manualLicenses {
61-
if err := manualLicenseDownload(lic.url, lic.pkg); err != nil {
84+
if err := manualLicenseDownload(client, lic.url, lic.pkg); err != nil {
6285
log(fmt.Sprintf("Error downloading license: %v", err))
6386
os.Exit(1)
6487
}
@@ -72,7 +95,7 @@ func main() {
7295
log("Done preparing 3rd party licenses.")
7396
}
7497

75-
func manualLicenseDownload(url, packageName string) error {
98+
func manualLicenseDownload(client *http.Client, url, packageName string) error {
7699
folderPath := filepath.Join(licensesEmbeddedDir, packageName)
77100
licenseFile := filepath.Join(folderPath, "LICENSE")
78101

@@ -86,14 +109,7 @@ func manualLicenseDownload(url, packageName string) error {
86109
return fmt.Errorf("creating directory for %s: %w", packageName, err)
87110
}
88111

89-
req, err := http.NewRequest(http.MethodGet, url, nil)
90-
if err != nil {
91-
return fmt.Errorf("creating request for %s: %w", packageName, err)
92-
}
93-
req.Header.Set("User-Agent", "Snyk-CLI-Build/1.0")
94-
95-
client := &http.Client{Timeout: 30 * time.Second}
96-
resp, err := client.Do(req)
112+
resp, err := client.Get(url)
97113
if err != nil {
98114
return fmt.Errorf("downloading license for %s: %w", packageName, err)
99115
}

cliv2/scripts/prepare_licenses_test.go

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66
"os"
77
"path/filepath"
88
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
912
)
1013

1114
func TestIsEmbeddedLicenseFileName(t *testing.T) {
@@ -44,14 +47,10 @@ func TestIsEmbeddedLicenseFileName(t *testing.T) {
4447
}
4548

4649
for _, name := range keep {
47-
if !isEmbeddedLicenseFileName(name) {
48-
t.Errorf("expected %q to be kept, but it would be removed", name)
49-
}
50+
assert.True(t, isEmbeddedLicenseFileName(name), "expected %q to be kept, but it would be removed", name)
5051
}
5152
for _, name := range remove {
52-
if isEmbeddedLicenseFileName(name) {
53-
t.Errorf("expected %q to be removed, but it would be kept", name)
54-
}
53+
assert.False(t, isEmbeddedLicenseFileName(name), "expected %q to be removed, but it would be kept", name)
5554
}
5655
}
5756

@@ -62,9 +61,7 @@ func TestCleanupNonLicenseFiles(t *testing.T) {
6261
t.Cleanup(func() { licensesEmbeddedDir = origDir })
6362

6463
pkgDir := filepath.Join(tmpDir, "example.com", "foo")
65-
if err := os.MkdirAll(pkgDir, 0o755); err != nil {
66-
t.Fatal(err)
67-
}
64+
require.NoError(t, os.MkdirAll(pkgDir, 0o755))
6865

6966
files := map[string]bool{
7067
"LICENSE": true,
@@ -75,21 +72,15 @@ func TestCleanupNonLicenseFiles(t *testing.T) {
7572
"main.go": false,
7673
}
7774
for name := range files {
78-
if err := os.WriteFile(filepath.Join(pkgDir, name), []byte("test"), 0o644); err != nil {
79-
t.Fatal(err)
80-
}
75+
require.NoError(t, os.WriteFile(filepath.Join(pkgDir, name), []byte("test"), 0o644))
8176
}
8277

83-
if err := cleanupNonLicenseFiles(); err != nil {
84-
t.Fatalf("cleanupNonLicenseFiles() error: %v", err)
85-
}
78+
require.NoError(t, cleanupNonLicenseFiles())
8679

8780
for name, shouldExist := range files {
8881
_, err := os.Stat(filepath.Join(pkgDir, name))
8982
exists := err == nil
90-
if exists != shouldExist {
91-
t.Errorf("file %q: exists=%v, want exists=%v", name, exists, shouldExist)
92-
}
83+
assert.Equal(t, shouldExist, exists, "file %q", name)
9384
}
9485
}
9586

@@ -107,22 +98,15 @@ func TestManualLicenseDownload(t *testing.T) {
10798
t.Cleanup(func() { licensesEmbeddedDir = origDir })
10899

109100
pkg := "example.com/testpkg"
110-
if err := manualLicenseDownload(server.URL+"/LICENSE", pkg); err != nil {
111-
t.Fatalf("manualLicenseDownload() error: %v", err)
112-
}
101+
client := &http.Client{}
102+
require.NoError(t, manualLicenseDownload(client, server.URL+"/LICENSE", pkg))
113103

114104
content, err := os.ReadFile(filepath.Join(tmpDir, pkg, "LICENSE"))
115-
if err != nil {
116-
t.Fatalf("reading downloaded license: %v", err)
117-
}
118-
if string(content) != "MIT License\n" {
119-
t.Errorf("license content = %q, want %q", content, "MIT License\n")
120-
}
105+
require.NoError(t, err)
106+
assert.Equal(t, "MIT License\n", string(content))
121107

122108
// Second call should skip (file already exists).
123-
if err := manualLicenseDownload(server.URL+"/LICENSE", pkg); err != nil {
124-
t.Fatalf("second manualLicenseDownload() error: %v", err)
125-
}
109+
require.NoError(t, manualLicenseDownload(client, server.URL+"/LICENSE", pkg))
126110
}
127111

128112
func TestManualLicenseDownloadHTTPError(t *testing.T) {
@@ -136,8 +120,55 @@ func TestManualLicenseDownloadHTTPError(t *testing.T) {
136120
licensesEmbeddedDir = tmpDir
137121
t.Cleanup(func() { licensesEmbeddedDir = origDir })
138122

139-
err := manualLicenseDownload(server.URL+"/LICENSE", "example.com/missing")
140-
if err == nil {
141-
t.Fatal("expected error for 404 response, got nil")
142-
}
123+
client := &http.Client{}
124+
err := manualLicenseDownload(client, server.URL+"/LICENSE", "example.com/missing")
125+
assert.Error(t, err, "expected error for 404 response, got nil")
126+
}
127+
128+
func TestNewHTTPClient_SetsUserAgent(t *testing.T) {
129+
var gotUA string
130+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
131+
gotUA = r.Header.Get("User-Agent")
132+
_, err := w.Write([]byte("OK"))
133+
if err != nil {
134+
http.Error(w, err.Error(), http.StatusInternalServerError)
135+
return
136+
}
137+
}))
138+
defer server.Close()
139+
140+
client := newHTTPClient()
141+
resp, err := client.Get(server.URL)
142+
require.NoError(t, err)
143+
require.Equal(t, http.StatusOK, resp.StatusCode)
144+
assert.NoError(t, resp.Body.Close())
145+
146+
assert.Equal(t, "Snyk-CLI-Build/1.0", gotUA)
147+
}
148+
149+
func TestNewHTTPClient_RetriesOn429(t *testing.T) {
150+
var calls int
151+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152+
calls++
153+
n := calls
154+
if n <= 2 {
155+
w.Header().Set("Retry-After", "0")
156+
w.WriteHeader(http.StatusTooManyRequests)
157+
return
158+
}
159+
_, err := w.Write([]byte("OK"))
160+
if err != nil {
161+
http.Error(w, err.Error(), http.StatusInternalServerError)
162+
return
163+
}
164+
}))
165+
defer server.Close()
166+
167+
client := newHTTPClient()
168+
resp, err := client.Get(server.URL)
169+
require.NoError(t, err)
170+
require.Equal(t, http.StatusOK, resp.StatusCode)
171+
assert.NoError(t, resp.Body.Close())
172+
173+
assert.GreaterOrEqual(t, calls, 3, "expected at least 3 server calls (2 x 429 + 1 x 200)")
143174
}

0 commit comments

Comments
 (0)