Skip to content

Commit 9bcdbab

Browse files
committed
wif: Add support for Fly Machines OIDC
1 parent 6f0ca94 commit 9bcdbab

1 file changed

Lines changed: 66 additions & 4 deletions

File tree

wif/wif.go

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
package wif
77

88
import (
9+
"bytes"
910
"context"
1011
"encoding/json"
1112
"errors"
1213
"fmt"
1314
"io"
15+
"net"
1416
"net/http"
1517
"net/url"
1618
"os"
@@ -29,6 +31,7 @@ type Environment string
2931

3032
const (
3133
EnvGitHub Environment = "github"
34+
EnvFlyio Environment = "flyio"
3235
EnvAWS Environment = "aws"
3336
EnvGCP Environment = "gcp"
3437
EnvNone Environment = "none"
@@ -38,16 +41,19 @@ const (
3841
// and then tries to obtain an ID token for the audience that is passed as an argument
3942
// To detect the environment, we do it in the following intentional order:
4043
// 1. GitHub Actions (strongest env signals; may run atop any cloud)
41-
// 2. AWS via IMDSv2 token endpoint (does not require env vars)
42-
// 3. GCP via metadata header semantics
43-
// 4. AWS ECS via ECS token endpoint and env vars provided by ECS
44-
// 5. Azure via metadata endpoint
44+
// 2. Fly Machines via app name and machine ID env vars
45+
// 3. AWS via IMDSv2 token endpoint (does not require env vars)
46+
// 4. GCP via metadata header semantics
47+
// 5. AWS ECS via ECS token endpoint and env vars provided by ECS
48+
// 6. Azure via metadata endpoint
4549
func ObtainProviderToken(ctx context.Context, audience string) (string, error) {
4650
env := detectEnvironment(ctx)
4751

4852
switch env {
4953
case EnvGitHub:
5054
return acquireGitHubActionsIDToken(ctx, audience)
55+
case EnvFlyio:
56+
return acquireFlyioOIDCToken(ctx, audience)
5157
case EnvAWS:
5258
return acquireAWSWebIdentityToken(ctx, audience)
5359
case EnvGCP:
@@ -62,6 +68,10 @@ func detectEnvironment(ctx context.Context) Environment {
6268
os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") != "" {
6369
return EnvGitHub
6470
}
71+
if os.Getenv("FLY_APP_NAME") != "" &&
72+
os.Getenv("FLY_MACHINE_ID") != "" {
73+
return EnvFlyio
74+
}
6575

6676
client := httpClient()
6777
if detectAWSIMDSv2(ctx, client) {
@@ -166,6 +176,58 @@ func acquireGitHubActionsIDToken(ctx context.Context, audience string) (string,
166176
return tr.Value, nil
167177
}
168178

179+
type flyioOIDCRequest struct {
180+
Audience string `json:"aud"`
181+
}
182+
183+
func acquireFlyioOIDCToken(ctx context.Context, audience string) (string, error) {
184+
body, err := json.Marshal(&flyioOIDCRequest{
185+
Audience: strings.TrimSpace(audience),
186+
})
187+
if err != nil {
188+
return "", fmt.Errorf("build Fly OIDC request body: %w", err)
189+
}
190+
191+
req, err := http.NewRequestWithContext(ctx, httpm.POST, "http://unix/v1/tokens/oidc", bytes.NewReader(body))
192+
if err != nil {
193+
return "", fmt.Errorf("build Fly OIDC request: %w", err)
194+
}
195+
req.Header.Set("Content-Type", "application/json")
196+
197+
tr := &http.Transport{
198+
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
199+
var d net.Dialer
200+
return d.DialContext(ctx, "unix", "/.fly/api")
201+
},
202+
}
203+
defer tr.CloseIdleConnections()
204+
205+
client := httpClient()
206+
client.Transport = tr
207+
208+
resp, err := client.Do(req)
209+
if err != nil {
210+
return "", fmt.Errorf("call Fly OIDC endpoint: %w", err)
211+
}
212+
defer resp.Body.Close()
213+
214+
if resp.StatusCode/100 != 2 {
215+
b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
216+
return "", fmt.Errorf("Fly OIDC endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b)))
217+
}
218+
219+
b, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
220+
if err != nil {
221+
return "", fmt.Errorf("read Fly OIDC response: %w", err)
222+
}
223+
jwt := strings.TrimSpace(string(b))
224+
if jwt == "" {
225+
return "", fmt.Errorf("Fly OIDC endpoint returned empty token")
226+
}
227+
228+
return jwt, nil
229+
}
230+
169231
func acquireAWSWebIdentityToken(ctx context.Context, audience string) (string, error) {
170232
// LoadDefaultConfig wires up the default credential chain (incl. IMDS and ECS metadata).
171233
cfg, err := config.LoadDefaultConfig(ctx)

0 commit comments

Comments
 (0)