66 "os"
77 "path/filepath"
88 "testing"
9+
10+ "github.com/stretchr/testify/assert"
11+ "github.com/stretchr/testify/require"
912)
1013
1114func TestIsEmbeddedLicenseFileName (t * testing.T ) {
@@ -44,14 +47,10 @@ func TestIsEmbeddedLicenseFileName(t *testing.T) {
4447 }
4548
4649 for _ , name := range keep {
47- if ! isEmbeddedLicenseFileName (name ) {
48- t .Errorf ("expected %q to be kept, but it would be removed" , name )
49- }
50+ assert .True (t , isEmbeddedLicenseFileName (name ), "expected %q to be kept, but it would be removed" , name )
5051 }
5152 for _ , name := range remove {
52- if isEmbeddedLicenseFileName (name ) {
53- t .Errorf ("expected %q to be removed, but it would be kept" , name )
54- }
53+ assert .False (t , isEmbeddedLicenseFileName (name ), "expected %q to be removed, but it would be kept" , name )
5554 }
5655}
5756
@@ -62,9 +61,7 @@ func TestCleanupNonLicenseFiles(t *testing.T) {
6261 t .Cleanup (func () { licensesEmbeddedDir = origDir })
6362
6463 pkgDir := filepath .Join (tmpDir , "example.com" , "foo" )
65- if err := os .MkdirAll (pkgDir , 0o755 ); err != nil {
66- t .Fatal (err )
67- }
64+ require .NoError (t , os .MkdirAll (pkgDir , 0o755 ))
6865
6966 files := map [string ]bool {
7067 "LICENSE" : true ,
@@ -75,21 +72,15 @@ func TestCleanupNonLicenseFiles(t *testing.T) {
7572 "main.go" : false ,
7673 }
7774 for name := range files {
78- if err := os .WriteFile (filepath .Join (pkgDir , name ), []byte ("test" ), 0o644 ); err != nil {
79- t .Fatal (err )
80- }
75+ require .NoError (t , os .WriteFile (filepath .Join (pkgDir , name ), []byte ("test" ), 0o644 ))
8176 }
8277
83- if err := cleanupNonLicenseFiles (); err != nil {
84- t .Fatalf ("cleanupNonLicenseFiles() error: %v" , err )
85- }
78+ require .NoError (t , cleanupNonLicenseFiles ())
8679
8780 for name , shouldExist := range files {
8881 _ , err := os .Stat (filepath .Join (pkgDir , name ))
8982 exists := err == nil
90- if exists != shouldExist {
91- t .Errorf ("file %q: exists=%v, want exists=%v" , name , exists , shouldExist )
92- }
83+ assert .Equal (t , shouldExist , exists , "file %q" , name )
9384 }
9485}
9586
@@ -107,22 +98,15 @@ func TestManualLicenseDownload(t *testing.T) {
10798 t .Cleanup (func () { licensesEmbeddedDir = origDir })
10899
109100 pkg := "example.com/testpkg"
110- if err := manualLicenseDownload (server .URL + "/LICENSE" , pkg ); err != nil {
111- t .Fatalf ("manualLicenseDownload() error: %v" , err )
112- }
101+ client := & http.Client {}
102+ require .NoError (t , manualLicenseDownload (client , server .URL + "/LICENSE" , pkg ))
113103
114104 content , err := os .ReadFile (filepath .Join (tmpDir , pkg , "LICENSE" ))
115- if err != nil {
116- t .Fatalf ("reading downloaded license: %v" , err )
117- }
118- if string (content ) != "MIT License\n " {
119- t .Errorf ("license content = %q, want %q" , content , "MIT License\n " )
120- }
105+ require .NoError (t , err )
106+ assert .Equal (t , "MIT License\n " , string (content ))
121107
122108 // Second call should skip (file already exists).
123- if err := manualLicenseDownload (server .URL + "/LICENSE" , pkg ); err != nil {
124- t .Fatalf ("second manualLicenseDownload() error: %v" , err )
125- }
109+ require .NoError (t , manualLicenseDownload (client , server .URL + "/LICENSE" , pkg ))
126110}
127111
128112func TestManualLicenseDownloadHTTPError (t * testing.T ) {
@@ -136,8 +120,55 @@ func TestManualLicenseDownloadHTTPError(t *testing.T) {
136120 licensesEmbeddedDir = tmpDir
137121 t .Cleanup (func () { licensesEmbeddedDir = origDir })
138122
139- err := manualLicenseDownload (server .URL + "/LICENSE" , "example.com/missing" )
140- if err == nil {
141- t .Fatal ("expected error for 404 response, got nil" )
142- }
123+ client := & http.Client {}
124+ err := manualLicenseDownload (client , server .URL + "/LICENSE" , "example.com/missing" )
125+ assert .Error (t , err , "expected error for 404 response, got nil" )
126+ }
127+
128+ func TestNewHTTPClient_SetsUserAgent (t * testing.T ) {
129+ var gotUA string
130+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
131+ gotUA = r .Header .Get ("User-Agent" )
132+ _ , err := w .Write ([]byte ("OK" ))
133+ if err != nil {
134+ http .Error (w , err .Error (), http .StatusInternalServerError )
135+ return
136+ }
137+ }))
138+ defer server .Close ()
139+
140+ client := newHTTPClient ()
141+ resp , err := client .Get (server .URL )
142+ require .NoError (t , err )
143+ require .Equal (t , http .StatusOK , resp .StatusCode )
144+ assert .NoError (t , resp .Body .Close ())
145+
146+ assert .Equal (t , "Snyk-CLI-Build/1.0" , gotUA )
147+ }
148+
149+ func TestNewHTTPClient_RetriesOn429 (t * testing.T ) {
150+ var calls int
151+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
152+ calls ++
153+ n := calls
154+ if n <= 2 {
155+ w .Header ().Set ("Retry-After" , "0" )
156+ w .WriteHeader (http .StatusTooManyRequests )
157+ return
158+ }
159+ _ , err := w .Write ([]byte ("OK" ))
160+ if err != nil {
161+ http .Error (w , err .Error (), http .StatusInternalServerError )
162+ return
163+ }
164+ }))
165+ defer server .Close ()
166+
167+ client := newHTTPClient ()
168+ resp , err := client .Get (server .URL )
169+ require .NoError (t , err )
170+ require .Equal (t , http .StatusOK , resp .StatusCode )
171+ assert .NoError (t , resp .Body .Close ())
172+
173+ assert .GreaterOrEqual (t , calls , 3 , "expected at least 3 server calls (2 x 429 + 1 x 200)" )
143174}
0 commit comments