|
1 | 1 | package handler |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bufio" |
| 5 | + "encoding/json" |
4 | 6 | "fmt" |
5 | 7 | "io" |
6 | 8 | "net/http" |
7 | 9 | "strings" |
| 10 | + "time" |
| 11 | + |
| 12 | + "github.com/git-pkgs/purl" |
8 | 13 | ) |
9 | 14 |
|
10 | 15 | const ( |
@@ -41,7 +46,7 @@ func (h *GemHandler) Routes() http.Handler { |
41 | 46 |
|
42 | 47 | // Compact index (bundler 2.x+) |
43 | 48 | mux.HandleFunc("GET /versions", h.proxyUpstream) |
44 | | - mux.HandleFunc("GET /info/{name}", h.proxyUpstream) |
| 49 | + mux.HandleFunc("GET /info/{name}", h.handleCompactIndex) |
45 | 50 |
|
46 | 51 | // Quick index |
47 | 52 | mux.HandleFunc("GET /quick/Marshal.4.8/{filename}", h.proxyUpstream) |
@@ -98,6 +103,191 @@ func (h *GemHandler) parseGemFilename(filename string) (name, version string) { |
98 | 103 | return "", "" |
99 | 104 | } |
100 | 105 |
|
| 106 | +// handleCompactIndex serves the compact index for a gem, filtering versions |
| 107 | +// based on cooldown when enabled. |
| 108 | +func (h *GemHandler) handleCompactIndex(w http.ResponseWriter, r *http.Request) { |
| 109 | + if h.proxy.Cooldown == nil || !h.proxy.Cooldown.Enabled() { |
| 110 | + h.proxyUpstream(w, r) |
| 111 | + return |
| 112 | + } |
| 113 | + |
| 114 | + name := r.PathValue("name") |
| 115 | + if name == "" { |
| 116 | + http.Error(w, "invalid gem name", http.StatusBadRequest) |
| 117 | + return |
| 118 | + } |
| 119 | + |
| 120 | + h.proxy.Logger.Info("gem compact index request with cooldown", "name", name) |
| 121 | + |
| 122 | + indexResp, filteredVersions, err := h.fetchIndexAndVersions(r, name) |
| 123 | + if err != nil { |
| 124 | + h.proxy.Logger.Error("upstream compact index request failed", "error", err) |
| 125 | + http.Error(w, "upstream request failed", http.StatusBadGateway) |
| 126 | + return |
| 127 | + } |
| 128 | + defer func() { _ = indexResp.Body.Close() }() |
| 129 | + |
| 130 | + if indexResp.StatusCode != http.StatusOK { |
| 131 | + copyResponseHeaders(w, indexResp.Header) |
| 132 | + w.WriteHeader(indexResp.StatusCode) |
| 133 | + _, _ = io.Copy(w, indexResp.Body) |
| 134 | + return |
| 135 | + } |
| 136 | + |
| 137 | + if filteredVersions == nil { |
| 138 | + h.proxy.Logger.Warn("failed to fetch version timestamps, proxying unfiltered", "name", name) |
| 139 | + copyResponseHeaders(w, indexResp.Header) |
| 140 | + w.WriteHeader(http.StatusOK) |
| 141 | + _, _ = io.Copy(w, indexResp.Body) |
| 142 | + return |
| 143 | + } |
| 144 | + |
| 145 | + h.writeFilteredIndex(w, indexResp, name, filteredVersions) |
| 146 | +} |
| 147 | + |
| 148 | +// fetchIndexAndVersions fetches the compact index and versions API concurrently. |
| 149 | +// Returns the index response, a set of versions to filter (nil if versions API failed), |
| 150 | +// and an error if the index fetch itself failed. |
| 151 | +func (h *GemHandler) fetchIndexAndVersions(r *http.Request, name string) (*http.Response, map[string]bool, error) { |
| 152 | + type versionsResult struct { |
| 153 | + filtered map[string]bool |
| 154 | + err error |
| 155 | + } |
| 156 | + |
| 157 | + versionsCh := make(chan versionsResult, 1) |
| 158 | + go func() { |
| 159 | + filtered, err := h.fetchFilteredVersions(r, name) |
| 160 | + versionsCh <- versionsResult{filtered: filtered, err: err} |
| 161 | + }() |
| 162 | + |
| 163 | + indexResp, err := h.fetchCompactIndex(r, name) |
| 164 | + |
| 165 | + versionsRes := <-versionsCh |
| 166 | + |
| 167 | + if err != nil { |
| 168 | + return nil, nil, err |
| 169 | + } |
| 170 | + |
| 171 | + if versionsRes.err != nil { |
| 172 | + return indexResp, nil, nil |
| 173 | + } |
| 174 | + |
| 175 | + return indexResp, versionsRes.filtered, nil |
| 176 | +} |
| 177 | + |
| 178 | +// fetchCompactIndex fetches the compact index from upstream. |
| 179 | +func (h *GemHandler) fetchCompactIndex(r *http.Request, name string) (*http.Response, error) { |
| 180 | + indexURL := h.upstreamURL + "/info/" + name |
| 181 | + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, indexURL, nil) |
| 182 | + if err != nil { |
| 183 | + return nil, err |
| 184 | + } |
| 185 | + for _, hdr := range []string{"Accept", "Accept-Encoding", "If-None-Match", "If-Modified-Since"} { |
| 186 | + if v := r.Header.Get(hdr); v != "" { |
| 187 | + req.Header.Set(hdr, v) |
| 188 | + } |
| 189 | + } |
| 190 | + return h.proxy.HTTPClient.Do(req) |
| 191 | +} |
| 192 | + |
| 193 | +// writeFilteredIndex writes the compact index response with cooldown-filtered versions removed. |
| 194 | +func (h *GemHandler) writeFilteredIndex(w http.ResponseWriter, resp *http.Response, name string, filtered map[string]bool) { |
| 195 | + for k, vv := range resp.Header { |
| 196 | + if strings.EqualFold(k, "Content-Length") { |
| 197 | + continue // length will change after filtering |
| 198 | + } |
| 199 | + for _, v := range vv { |
| 200 | + w.Header().Add(k, v) |
| 201 | + } |
| 202 | + } |
| 203 | + w.WriteHeader(http.StatusOK) |
| 204 | + |
| 205 | + scanner := bufio.NewScanner(resp.Body) |
| 206 | + for scanner.Scan() { |
| 207 | + line := scanner.Text() |
| 208 | + |
| 209 | + if line == "---" { |
| 210 | + _, _ = fmt.Fprintln(w, line) |
| 211 | + continue |
| 212 | + } |
| 213 | + |
| 214 | + version := line |
| 215 | + if spaceIdx := strings.IndexByte(line, ' '); spaceIdx > 0 { |
| 216 | + version = line[:spaceIdx] |
| 217 | + } |
| 218 | + |
| 219 | + if filtered[version] { |
| 220 | + h.proxy.Logger.Info("cooldown: filtering gem version", |
| 221 | + "gem", name, "version", version) |
| 222 | + continue |
| 223 | + } |
| 224 | + |
| 225 | + _, _ = fmt.Fprintln(w, line) |
| 226 | + } |
| 227 | +} |
| 228 | + |
| 229 | +// copyResponseHeaders copies HTTP headers from a response to a writer. |
| 230 | +func copyResponseHeaders(w http.ResponseWriter, headers http.Header) { |
| 231 | + for k, vv := range headers { |
| 232 | + for _, v := range vv { |
| 233 | + w.Header().Add(k, v) |
| 234 | + } |
| 235 | + } |
| 236 | +} |
| 237 | + |
| 238 | +// gemVersion represents a version entry from the RubyGems versions API. |
| 239 | +type gemVersion struct { |
| 240 | + Number string `json:"number"` |
| 241 | + Platform string `json:"platform"` |
| 242 | + CreatedAt string `json:"created_at"` |
| 243 | +} |
| 244 | + |
| 245 | +// fetchFilteredVersions fetches the versions API and returns a set of version |
| 246 | +// strings that should be filtered out by cooldown. |
| 247 | +func (h *GemHandler) fetchFilteredVersions(r *http.Request, name string) (map[string]bool, error) { |
| 248 | + versionsURL := fmt.Sprintf("%s/api/v1/versions/%s.json", h.upstreamURL, name) |
| 249 | + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, versionsURL, nil) |
| 250 | + if err != nil { |
| 251 | + return nil, err |
| 252 | + } |
| 253 | + |
| 254 | + resp, err := h.proxy.HTTPClient.Do(req) |
| 255 | + if err != nil { |
| 256 | + return nil, err |
| 257 | + } |
| 258 | + defer func() { _ = resp.Body.Close() }() |
| 259 | + |
| 260 | + if resp.StatusCode != http.StatusOK { |
| 261 | + return nil, fmt.Errorf("versions API returned %d", resp.StatusCode) |
| 262 | + } |
| 263 | + |
| 264 | + var versions []gemVersion |
| 265 | + if err := json.NewDecoder(resp.Body).Decode(&versions); err != nil { |
| 266 | + return nil, err |
| 267 | + } |
| 268 | + |
| 269 | + packagePURL := purl.MakePURLString("gem", name, "") |
| 270 | + filtered := make(map[string]bool) |
| 271 | + |
| 272 | + for _, v := range versions { |
| 273 | + createdAt, err := time.Parse(time.RFC3339, v.CreatedAt) |
| 274 | + if err != nil { |
| 275 | + continue |
| 276 | + } |
| 277 | + |
| 278 | + if !h.proxy.Cooldown.IsAllowed("gem", packagePURL, createdAt) { |
| 279 | + // Build version string matching compact index format |
| 280 | + versionStr := v.Number |
| 281 | + if v.Platform != "" && v.Platform != "ruby" { |
| 282 | + versionStr = v.Number + "-" + v.Platform |
| 283 | + } |
| 284 | + filtered[versionStr] = true |
| 285 | + } |
| 286 | + } |
| 287 | + |
| 288 | + return filtered, nil |
| 289 | +} |
| 290 | + |
101 | 291 | // proxyUpstream forwards a request to rubygems.org without caching. |
102 | 292 | func (h *GemHandler) proxyUpstream(w http.ResponseWriter, r *http.Request) { |
103 | 293 | upstreamURL := h.upstreamURL + r.URL.Path |
|
0 commit comments