Skip to content

Commit 29b6757

Browse files
committed
attestation: add attestation & userdata generation middleware
1 parent 689dcdb commit 29b6757

4 files changed

Lines changed: 308 additions & 0 deletions

File tree

attestation/context.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package attestation
2+
3+
import (
4+
"context"
5+
6+
"github.com/0xsequence/nitrocontrol/enclave"
7+
)
8+
9+
type contextKeyType string
10+
11+
var contextKey = contextKeyType("attestation")
12+
13+
func FromContext(ctx context.Context) *enclave.Attestation {
14+
v, _ := ctx.Value(contextKey).(*enclave.Attestation)
15+
return v
16+
}

attestation/middleware.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package attestation
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/base64"
7+
"fmt"
8+
"io"
9+
"log/slog"
10+
"net/http"
11+
12+
"github.com/0xsequence/nitrocontrol/enclave"
13+
"github.com/0xsequence/nitrocontrol/tracing"
14+
"github.com/go-chi/chi/v5/middleware"
15+
)
16+
17+
// Middleware is an HTTP middleware that issues an attestation document request to the enclave's NSM.
18+
// The result wrapped in the Attestation type is then set in the context available to subsequent handlers.
19+
// It also sets the X-Attestation-Document HTTP header to the Base64-encoded representation of the document.
20+
//
21+
// If the HTTP request includes an X-Attestation-Nonce header, its value is sent to the NSM and included in
22+
// the final attestation document.
23+
func Middleware(enc *enclave.Enclave, errorFn func(http.ResponseWriter, error), loggerFromContextFn func(context.Context) *slog.Logger) func(http.Handler) http.Handler {
24+
runPreMiddleware := func(r *http.Request) (ctx context.Context, cancelFunc func(), err error) {
25+
ctx, span := tracing.Trace(r.Context(), "attestation.Middleware")
26+
defer func() {
27+
span.RecordError(err)
28+
span.End()
29+
}()
30+
31+
log := loggerFromContextFn(ctx)
32+
att, err := enc.GetAttestation(ctx, nil, nil)
33+
if err != nil {
34+
return nil, nil, err
35+
}
36+
37+
cancelFunc = func() {
38+
if err := att.Close(); err != nil {
39+
log.Error("failed to close attestation", "error", err)
40+
return
41+
}
42+
}
43+
44+
return context.WithValue(r.Context(), contextKey, att), cancelFunc, nil
45+
}
46+
47+
runPostMiddleware := func(w http.ResponseWriter, r *http.Request, body []byte, nonce []byte) (err error) {
48+
log := loggerFromContextFn(r.Context())
49+
ctx, span := tracing.Trace(r.Context(), "attestation.Middleware")
50+
defer func() {
51+
span.RecordError(err)
52+
span.End()
53+
}()
54+
55+
userData, err := generateUserData(r, body)
56+
if err != nil {
57+
return err
58+
}
59+
60+
att, err := enc.GetAttestation(ctx, nonce, userData)
61+
if err != nil {
62+
return err
63+
}
64+
defer func() {
65+
if err := att.Close(); err != nil {
66+
log.Error("failed to close attestation", "error", err)
67+
return
68+
}
69+
}()
70+
71+
w.Header().Set("X-Attestation-Document", base64.StdEncoding.EncodeToString(att.Document()))
72+
return nil
73+
}
74+
75+
return func(next http.Handler) http.Handler {
76+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77+
reqBody, err := io.ReadAll(r.Body)
78+
if err != nil {
79+
errorFn(w, fmt.Errorf("failed to read request body: %w", err))
80+
return
81+
}
82+
r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
83+
84+
var nonce []byte
85+
if nonceVal := r.Header.Get("X-Attestation-Nonce"); nonceVal != "" {
86+
if len(nonceVal) > 32 {
87+
errorFn(w, fmt.Errorf("X-Attestation-Nonce value cannot be longer than 32"))
88+
return
89+
}
90+
if !isNonceValid(nonceVal) {
91+
errorFn(w, fmt.Errorf("X-Attestation-Nonce value contains invalid characters"))
92+
return
93+
}
94+
95+
nonce = []byte(nonceVal)
96+
}
97+
98+
ctx, cancel, err := runPreMiddleware(r)
99+
if err != nil {
100+
errorFn(w, err)
101+
return
102+
}
103+
defer cancel()
104+
105+
var body bytes.Buffer
106+
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
107+
ww.Tee(&body)
108+
ww.Discard()
109+
110+
next.ServeHTTP(ww, r.WithContext(ctx))
111+
112+
r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
113+
if err := runPostMiddleware(ww, r, body.Bytes(), nonce); err != nil {
114+
errorFn(w, err)
115+
return
116+
}
117+
118+
w.WriteHeader(ww.Status())
119+
if _, err := body.WriteTo(w); err != nil {
120+
errorFn(w, err)
121+
}
122+
})
123+
}
124+
}
125+
126+
func isNonceValid(s string) bool {
127+
for i := 0; i < len(s); i++ {
128+
c := s[i]
129+
if (c >= 'a' && c <= 'z') ||
130+
(c >= 'A' && c <= 'Z') ||
131+
(c >= '0' && c <= '9') ||
132+
c == '.' || c == '_' || c == '-' || c == '/' || c == '+' || c == '=' {
133+
continue
134+
}
135+
return false
136+
}
137+
return true
138+
}

attestation/middleware_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package attestation_test
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"io"
7+
"log/slog"
8+
"net/http"
9+
"net/http/httptest"
10+
"testing"
11+
12+
"github.com/0xsequence/nitrocontrol/attestation"
13+
"github.com/0xsequence/nitrocontrol/enclave"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestMiddleware(t *testing.T) {
18+
enc, err := enclave.New(context.Background(), enclave.DummyProvider(nil), nil)
19+
require.NoError(t, err)
20+
21+
errorFn := func(w http.ResponseWriter, err error) {
22+
w.WriteHeader(http.StatusBadRequest)
23+
w.Write([]byte(err.Error()))
24+
}
25+
loggerFromContextFn := func(ctx context.Context) *slog.Logger {
26+
return slog.New(slog.DiscardHandler)
27+
}
28+
29+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30+
w.WriteHeader(http.StatusOK)
31+
})
32+
srv := httptest.NewServer(attestation.Middleware(enc, errorFn, loggerFromContextFn)(handler))
33+
defer srv.Close()
34+
35+
tests := map[string]struct {
36+
request func() (http.Header, []byte)
37+
wantStatus int
38+
wantErrText string
39+
}{
40+
"NoNonce": {
41+
wantStatus: http.StatusOK,
42+
request: func() (http.Header, []byte) {
43+
return http.Header{}, []byte("test")
44+
},
45+
},
46+
"ValidNonce": {
47+
wantStatus: http.StatusOK,
48+
request: func() (http.Header, []byte) {
49+
return http.Header{
50+
"X-Attestation-Nonce": []string{"test"},
51+
}, []byte("test")
52+
},
53+
},
54+
"InvalidNonce": {
55+
wantStatus: http.StatusBadRequest,
56+
wantErrText: "X-Attestation-Nonce value contains invalid characters",
57+
request: func() (http.Header, []byte) {
58+
return http.Header{
59+
"X-Attestation-Nonce": []string{"!@#$%^&*()"},
60+
}, []byte("test")
61+
},
62+
},
63+
"LongNonce": {
64+
wantStatus: http.StatusBadRequest,
65+
wantErrText: "X-Attestation-Nonce value cannot be longer than 32",
66+
request: func() (http.Header, []byte) {
67+
return http.Header{
68+
"X-Attestation-Nonce": []string{"test-123456789012345678901234567890123"},
69+
}, []byte("test")
70+
},
71+
},
72+
"WhitespaceNonce": {
73+
wantStatus: http.StatusBadRequest,
74+
wantErrText: "X-Attestation-Nonce value contains invalid characters",
75+
request: func() (http.Header, []byte) {
76+
return http.Header{
77+
"X-Attestation-Nonce": []string{" \t a a \t "},
78+
}, []byte("test")
79+
},
80+
},
81+
}
82+
83+
for name, test := range tests {
84+
t.Run(name, func(t *testing.T) {
85+
header, reqBody := test.request()
86+
req, err := http.NewRequest("POST", srv.URL, bytes.NewBuffer(reqBody))
87+
require.NoError(t, err)
88+
req.Header = header
89+
90+
resp, err := srv.Client().Do(req)
91+
require.NoError(t, err)
92+
defer func() {
93+
require.NoError(t, resp.Body.Close())
94+
}()
95+
body, _ := io.ReadAll(resp.Body)
96+
97+
require.Equal(t, test.wantStatus, resp.StatusCode, string(body))
98+
if test.wantErrText != "" {
99+
require.Contains(t, string(body), test.wantErrText)
100+
}
101+
})
102+
}
103+
}

attestation/userdata.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package attestation
2+
3+
import (
4+
"bytes"
5+
"crypto/sha256"
6+
"encoding/base64"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
)
11+
12+
const (
13+
_UserDataPrefix = "Sequence"
14+
_UserDataVersion = 1
15+
)
16+
17+
type userData struct {
18+
Prefix string
19+
Version int
20+
Hash []byte
21+
}
22+
23+
func (u *userData) String() string {
24+
return fmt.Sprintf("%s/%d:%s", u.Prefix, u.Version, base64.StdEncoding.EncodeToString(u.Hash))
25+
}
26+
27+
func generateUserData(r *http.Request, resBody []byte) ([]byte, error) {
28+
hasher := sha256.New()
29+
hasher.Write([]byte(r.Method + " " + r.URL.Path + "\n"))
30+
31+
var reqBody []byte
32+
var err error
33+
if r.Body != nil {
34+
reqBody, err = io.ReadAll(r.Body)
35+
if err != nil {
36+
return nil, fmt.Errorf("failed to read request body: %w", err)
37+
}
38+
r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
39+
hasher.Write(reqBody)
40+
}
41+
hasher.Write([]byte("\n"))
42+
hasher.Write(resBody)
43+
hash := hasher.Sum(nil)
44+
45+
userData := &userData{
46+
Prefix: _UserDataPrefix,
47+
Version: _UserDataVersion,
48+
Hash: hash,
49+
}
50+
return []byte(userData.String()), nil
51+
}

0 commit comments

Comments
 (0)