66package wif
77
88import (
9+ "bytes"
910 "context"
1011 "encoding/json"
1112 "errors"
1213 "fmt"
1314 "io"
15+ "net"
1416 "net/http"
1517 "net/url"
1618 "os"
@@ -29,6 +31,7 @@ type Environment string
2931
3032const (
3133 EnvGitHub Environment = "github"
34+ EnvFlyio Environment = "flyio"
3235 EnvAWS Environment = "aws"
3336 EnvGCP Environment = "gcp"
3437 EnvNone Environment = "none"
@@ -38,16 +41,19 @@ const (
3841// and then tries to obtain an ID token for the audience that is passed as an argument
3942// To detect the environment, we do it in the following intentional order:
4043// 1. GitHub Actions (strongest env signals; may run atop any cloud)
41- // 2. AWS via IMDSv2 token endpoint (does not require env vars)
42- // 3. GCP via metadata header semantics
43- // 4. AWS ECS via ECS token endpoint and env vars provided by ECS
44- // 5. Azure via metadata endpoint
44+ // 2. Fly Machines via app name and machine ID env vars
45+ // 3. AWS via IMDSv2 token endpoint (does not require env vars)
46+ // 4. GCP via metadata header semantics
47+ // 5. AWS ECS via ECS token endpoint and env vars provided by ECS
48+ // 6. Azure via metadata endpoint
4549func ObtainProviderToken (ctx context.Context , audience string ) (string , error ) {
4650 env := detectEnvironment (ctx )
4751
4852 switch env {
4953 case EnvGitHub :
5054 return acquireGitHubActionsIDToken (ctx , audience )
55+ case EnvFlyio :
56+ return acquireFlyioOIDCToken (ctx , audience )
5157 case EnvAWS :
5258 return acquireAWSWebIdentityToken (ctx , audience )
5359 case EnvGCP :
@@ -62,6 +68,10 @@ func detectEnvironment(ctx context.Context) Environment {
6268 os .Getenv ("ACTIONS_ID_TOKEN_REQUEST_TOKEN" ) != "" {
6369 return EnvGitHub
6470 }
71+ if os .Getenv ("FLY_APP_NAME" ) != "" &&
72+ os .Getenv ("FLY_MACHINE_ID" ) != "" {
73+ return EnvFlyio
74+ }
6575
6676 client := httpClient ()
6777 if detectAWSIMDSv2 (ctx , client ) {
@@ -166,6 +176,58 @@ func acquireGitHubActionsIDToken(ctx context.Context, audience string) (string,
166176 return tr .Value , nil
167177}
168178
179+ type flyioOIDCRequest struct {
180+ Audience string `json:"aud"`
181+ }
182+
183+ func acquireFlyioOIDCToken (ctx context.Context , audience string ) (string , error ) {
184+ body , err := json .Marshal (& flyioOIDCRequest {
185+ Audience : strings .TrimSpace (audience ),
186+ })
187+ if err != nil {
188+ return "" , fmt .Errorf ("build Fly OIDC request body: %w" , err )
189+ }
190+
191+ req , err := http .NewRequestWithContext (ctx , httpm .POST , "http://unix/v1/tokens/oidc" , bytes .NewReader (body ))
192+ if err != nil {
193+ return "" , fmt .Errorf ("build Fly OIDC request: %w" , err )
194+ }
195+ req .Header .Set ("Content-Type" , "application/json" )
196+
197+ tr := & http.Transport {
198+ DialContext : func (ctx context.Context , _ , _ string ) (net.Conn , error ) {
199+ var d net.Dialer
200+ return d .DialContext (ctx , "unix" , "/.fly/api" )
201+ },
202+ }
203+ defer tr .CloseIdleConnections ()
204+
205+ client := httpClient ()
206+ client .Transport = tr
207+
208+ resp , err := client .Do (req )
209+ if err != nil {
210+ return "" , fmt .Errorf ("call Fly OIDC endpoint: %w" , err )
211+ }
212+ defer resp .Body .Close ()
213+
214+ if resp .StatusCode / 100 != 2 {
215+ b , _ := io .ReadAll (io .LimitReader (resp .Body , 2048 ))
216+ return "" , fmt .Errorf ("Fly OIDC endpoint returned %s: %s" , resp .Status , strings .TrimSpace (string (b )))
217+ }
218+
219+ b , err := io .ReadAll (io .LimitReader (resp .Body , 1024 * 1024 ))
220+ if err != nil {
221+ return "" , fmt .Errorf ("read Fly OIDC response: %w" , err )
222+ }
223+ jwt := strings .TrimSpace (string (b ))
224+ if jwt == "" {
225+ return "" , fmt .Errorf ("Fly OIDC endpoint returned empty token" )
226+ }
227+
228+ return jwt , nil
229+ }
230+
169231func acquireAWSWebIdentityToken (ctx context.Context , audience string ) (string , error ) {
170232 // LoadDefaultConfig wires up the default credential chain (incl. IMDS and ECS metadata).
171233 cfg , err := config .LoadDefaultConfig (ctx )
0 commit comments