Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 65 additions & 4 deletions wif/wif.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
package wif

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
Expand All @@ -29,6 +31,7 @@ type Environment string

const (
EnvGitHub Environment = "github"
EnvFlyio Environment = "flyio"
EnvAWS Environment = "aws"
EnvGCP Environment = "gcp"
EnvNone Environment = "none"
Expand All @@ -38,16 +41,19 @@ const (
// and then tries to obtain an ID token for the audience that is passed as an argument
// To detect the environment, we do it in the following intentional order:
// 1. GitHub Actions (strongest env signals; may run atop any cloud)
// 2. AWS via IMDSv2 token endpoint (does not require env vars)
// 3. GCP via metadata header semantics
// 4. AWS ECS via ECS token endpoint and env vars provided by ECS
// 5. Azure via metadata endpoint
// 2. Fly Machines via app name and machine ID env vars
// 3. AWS via IMDSv2 token endpoint (does not require env vars)
// 4. GCP via metadata header semantics
// 5. AWS ECS via ECS token endpoint and env vars provided by ECS
// 6. Azure via metadata endpoint
Comment thread
saleemrashid marked this conversation as resolved.
func ObtainProviderToken(ctx context.Context, audience string) (string, error) {
env := detectEnvironment(ctx)

switch env {
case EnvGitHub:
return acquireGitHubActionsIDToken(ctx, audience)
case EnvFlyio:
return acquireFlyioOIDCToken(ctx, audience)
case EnvAWS:
return acquireAWSWebIdentityToken(ctx, audience)
case EnvGCP:
Expand All @@ -62,6 +68,10 @@ func detectEnvironment(ctx context.Context) Environment {
os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN") != "" {
return EnvGitHub
}
if os.Getenv("FLY_APP_NAME") != "" &&
os.Getenv("FLY_MACHINE_ID") != "" {
return EnvFlyio
}

client := httpClient()
if detectAWSIMDSv2(ctx, client) {
Expand Down Expand Up @@ -166,6 +176,57 @@ func acquireGitHubActionsIDToken(ctx context.Context, audience string) (string,
return tr.Value, nil
}

type flyioOIDCRequest struct {
Audience string `json:"aud"`
}

func acquireFlyioOIDCToken(ctx context.Context, audience string) (string, error) {
body, err := json.Marshal(&flyioOIDCRequest{
Audience: strings.TrimSpace(audience),
})
if err != nil {
return "", fmt.Errorf("build flyio oidc request body: %w", err)
}

req, err := http.NewRequestWithContext(ctx, httpm.POST, "http://unix/v1/tokens/oidc", bytes.NewReader(body))
if err != nil {
return "", fmt.Errorf("build flyio oidc request: %w", err)
}
Comment thread
saleemrashid marked this conversation as resolved.
req.Header.Set("Content-Type", "application/json")

tr := &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", "/.fly/api")
},
}
defer tr.CloseIdleConnections()

client := httpClient()
client.Transport = tr
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("call flyio oidc endpoint: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode/100 != 2 {
b, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return "", fmt.Errorf("flyio oidc endpoint returned %s: %s", resp.Status, strings.TrimSpace(string(b)))
}

b, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return "", fmt.Errorf("read flyio oidc response: %w", err)
}
jwt := strings.TrimSpace(string(b))
if jwt == "" {
return "", fmt.Errorf("flyio oidc endpoint returned empty token")
}

return jwt, nil
}

func acquireAWSWebIdentityToken(ctx context.Context, audience string) (string, error) {
// LoadDefaultConfig wires up the default credential chain (incl. IMDS and ECS metadata).
cfg, err := config.LoadDefaultConfig(ctx)
Expand Down