Skip to content

Commit 689dcdb

Browse files
committed
tracing: add enclave-specific tracing solution
1 parent a55988a commit 689dcdb

5 files changed

Lines changed: 441 additions & 0 deletions

File tree

tracing/enclave_provider.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package tracing
2+
3+
import (
4+
"context"
5+
6+
"github.com/0xsequence/nitrocontrol/enclave"
7+
"github.com/0xsequence/nsm/request"
8+
"github.com/0xsequence/nsm/response"
9+
)
10+
11+
func WrapEnclaveProvider(provider enclave.Provider) enclave.Provider {
12+
return func() (enclave.Session, error) {
13+
sess, err := provider()
14+
if err != nil {
15+
return nil, err
16+
}
17+
return &wrappedSession{Session: sess}, nil
18+
}
19+
}
20+
21+
type wrappedSession struct {
22+
enclave.Session
23+
}
24+
25+
func (w *wrappedSession) Send(ctx context.Context, req request.Request) (res response.Response, err error) {
26+
ctx, span := Trace(ctx, w.getSpanName(req))
27+
defer func() {
28+
span.RecordError(err)
29+
span.End()
30+
}()
31+
return w.Session.Send(ctx, req)
32+
}
33+
34+
func (*wrappedSession) getSpanName(req request.Request) string {
35+
switch req.(type) {
36+
case *request.DescribePCR:
37+
return "NSM.DescribePCR"
38+
case *request.ExtendPCR:
39+
return "NSM.ExtendPCR"
40+
case *request.LockPCR:
41+
return "NSM.LockPCR"
42+
case *request.LockPCRs:
43+
return "NSM.LockPCRs"
44+
case *request.DescribeNSM:
45+
return "NSM.DescribeNSM"
46+
case *request.Attestation:
47+
return "NSM.Attestation"
48+
case *request.GetRandom:
49+
return "NSM.GetRandom"
50+
}
51+
return "NSM.Send"
52+
}

tracing/http_client.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package tracing
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
type HTTPClient interface {
9+
Do(*http.Request) (*http.Response, error)
10+
Get(string) (*http.Response, error)
11+
}
12+
13+
type wrappedClient struct {
14+
HTTPClient
15+
ctx context.Context
16+
}
17+
18+
func WrapClient(c HTTPClient) HTTPClient {
19+
return &wrappedClient{HTTPClient: c}
20+
}
21+
22+
func WrapClientWithContext(ctx context.Context, c HTTPClient) HTTPClient {
23+
return &wrappedClient{HTTPClient: c, ctx: ctx}
24+
}
25+
26+
func (c *wrappedClient) Do(req *http.Request) (res *http.Response, err error) {
27+
ctx, span := Trace(req.Context(), req.URL.Host, WithSpanKind(SpanKindClient))
28+
defer func() {
29+
if err != nil {
30+
span.RecordError(err)
31+
} else {
32+
span.SetMetadata(map[string]any{
33+
"http.status_code": res.StatusCode,
34+
"http.response_content_length": res.ContentLength,
35+
})
36+
span.SetStatus(res.StatusCode)
37+
}
38+
span.End()
39+
}()
40+
41+
span.SetMetadata(map[string]any{
42+
"http.method": req.Method,
43+
"http.url": req.URL.String(),
44+
"http.scheme": req.URL.Scheme,
45+
"http.query": req.URL.RawQuery,
46+
"http.path": req.URL.Path,
47+
"http.request_content_length": req.ContentLength,
48+
})
49+
50+
return c.HTTPClient.Do(req.WithContext(ctx))
51+
}
52+
53+
func (c *wrappedClient) Get(url string) (*http.Response, error) {
54+
req, err := http.NewRequest(http.MethodGet, url, nil)
55+
if err != nil {
56+
return nil, err
57+
}
58+
59+
if c.ctx != nil {
60+
req = req.WithContext(c.ctx)
61+
}
62+
63+
return c.Do(req)
64+
}

tracing/log.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package tracing
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"log/slog"
8+
"sort"
9+
"time"
10+
)
11+
12+
type Log struct {
13+
Time time.Time `json:"time"`
14+
Level slog.Level `json:"level"`
15+
Message string `json:"msg"`
16+
Attributes map[string]any `json:"attributes,omitempty"`
17+
}
18+
19+
func (l *Log) UnmarshalJSON(data []byte) error {
20+
var rawMap map[string]json.RawMessage
21+
if err := json.Unmarshal(data, &rawMap); err != nil {
22+
return err
23+
}
24+
25+
l.Attributes = make(map[string]any)
26+
27+
for k, v := range rawMap {
28+
switch k {
29+
case "time":
30+
if err := json.Unmarshal(v, &l.Time); err != nil {
31+
return fmt.Errorf("unmarshal %q: %w", k, err)
32+
}
33+
case "msg":
34+
if err := json.Unmarshal(v, &l.Message); err != nil {
35+
return fmt.Errorf("unmarshal %q: %w", k, err)
36+
}
37+
case "level":
38+
if err := json.Unmarshal(v, &l.Level); err != nil {
39+
return fmt.Errorf("unmarshal %q: %w", k, err)
40+
}
41+
default:
42+
var value any
43+
if err := json.Unmarshal(v, &value); err != nil {
44+
return fmt.Errorf("unmarshal %q: %w", k, err)
45+
}
46+
l.Attributes[k] = value
47+
}
48+
}
49+
50+
return nil
51+
}
52+
53+
func ExtractLogs(ctx context.Context, logger *slog.Logger, span *Span) error {
54+
logs, err := span.getLogs()
55+
if err != nil {
56+
return err
57+
}
58+
59+
sort.Slice(logs, func(i, j int) bool {
60+
return logs[i].Time.Before(logs[j].Time)
61+
})
62+
63+
for _, log := range logs {
64+
attrs := make([]slog.Attr, 0, len(log.Attributes))
65+
for k, v := range log.Attributes {
66+
attrs = append(attrs, slog.Any(k, v))
67+
}
68+
69+
record := slog.Record{
70+
Level: log.Level,
71+
Message: log.Message,
72+
Time: log.Time,
73+
}
74+
record.AddAttrs(attrs...)
75+
logger.Handler().Handle(ctx, record)
76+
}
77+
78+
return nil
79+
}
80+
81+
func (s *Span) getLogs() ([]*Log, error) {
82+
logs := make([]*Log, len(s.Logs))
83+
for i, log := range s.Logs {
84+
if err := json.Unmarshal(log, &logs[i]); err != nil {
85+
return nil, err
86+
}
87+
}
88+
89+
for _, child := range s.Children {
90+
childLogs, err := child.getLogs()
91+
if err != nil {
92+
return nil, err
93+
}
94+
logs = append(logs, childLogs...)
95+
}
96+
97+
if s.Metadata != nil {
98+
if msg, ok := s.Metadata["exception.message"].(string); ok {
99+
attrs := make(map[string]any)
100+
for k, v := range s.Annotations {
101+
attrs[k] = v
102+
}
103+
logs = append(logs, &Log{
104+
Time: s.EndTime,
105+
Level: slog.LevelError,
106+
Message: msg,
107+
Attributes: attrs,
108+
})
109+
}
110+
}
111+
112+
return logs, nil
113+
}

tracing/middleware.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package tracing
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"net/http"
7+
8+
"github.com/go-chi/chi/v5/middleware"
9+
"github.com/go-chi/traceid"
10+
)
11+
12+
func Middleware(errorFn func(http.ResponseWriter, error)) func(http.Handler) http.Handler {
13+
return func(next http.Handler) http.Handler {
14+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15+
var body bytes.Buffer
16+
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
17+
ww.Tee(&body)
18+
ww.Discard()
19+
20+
tid := traceid.FromContext(r.Context())
21+
ctx, span := Trace(
22+
r.Context(),
23+
r.URL.Path,
24+
WithSpanKind(SpanKindServer),
25+
WithMetadata(map[string]any{
26+
"sequence.traceid": tid,
27+
"net.host.name": r.Host,
28+
"server.address": r.Host,
29+
"http.method": r.Method,
30+
"http.url": r.URL.String(),
31+
"url.path": r.URL.Path,
32+
"url.query": r.URL.RawQuery,
33+
}),
34+
)
35+
36+
next.ServeHTTP(ww, r.WithContext(ctx))
37+
38+
span.SetStatus(ww.Status())
39+
span.End()
40+
spanJSON, err := json.Marshal(span)
41+
if err != nil {
42+
errorFn(w, err)
43+
return
44+
}
45+
46+
w.Header().Set("X-Sequence-Span", string(spanJSON))
47+
48+
w.WriteHeader(ww.Status())
49+
if _, err := body.WriteTo(w); err != nil {
50+
errorFn(w, err)
51+
return
52+
}
53+
})
54+
}
55+
}

0 commit comments

Comments
 (0)