Skip to content

Commit d411081

Browse files
xaionaro@dx.centerxaionaro@dx.center
authored andcommitted
feat(mcp): add auto-enrollment auth
1 parent adcf126 commit d411081

2 files changed

Lines changed: 110 additions & 3 deletions

File tree

cmd/jnimcp/main.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"os"
1111
"os/signal"
12+
"path/filepath"
1213
"time"
1314

1415
mcpserver "github.com/AndroidGoLab/jni-proxy/mcp"
@@ -62,6 +63,28 @@ func main() {
6263
func run(cmd *cobra.Command, _ []string) error {
6364
log := slog.New(slog.NewTextHandler(os.Stderr, nil))
6465

66+
ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt)
67+
defer cancel()
68+
69+
if flagCert == "" && flagKey == "" {
70+
configDir := flagConfigDir
71+
if configDir == "" {
72+
home, _ := os.UserHomeDir()
73+
configDir = filepath.Join(home, ".config", "jnimcp")
74+
}
75+
certPath, keyPath, caPath, err := mcpserver.AutoEnroll(ctx, flagAddr, configDir)
76+
if err != nil {
77+
log.Warn("auto-enrollment failed, continuing without mTLS", "error", err)
78+
} else {
79+
flagCert = certPath
80+
flagKey = keyPath
81+
if flagCA == "" {
82+
flagCA = caPath
83+
}
84+
log.Info("auto-enrolled with jniservice", "config_dir", configDir)
85+
}
86+
}
87+
6588
conn, err := dialGRPC()
6689
if err != nil {
6790
return fmt.Errorf("connecting to jniservice: %w", err)
@@ -70,9 +93,6 @@ func run(cmd *cobra.Command, _ []string) error {
7093

7194
srv := mcpserver.NewServer(conn, log)
7295

73-
ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt)
74-
defer cancel()
75-
7696
switch flagTransport {
7797
case "stdio":
7898
log.Info("starting MCP server", "transport", "stdio", "grpc_addr", flagAddr)

mcp/auth.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package mcp
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"crypto/tls"
7+
"encoding/hex"
8+
"fmt"
9+
"os"
10+
"path/filepath"
11+
12+
"github.com/AndroidGoLab/jni-proxy/grpc/server/certauth"
13+
pb "github.com/AndroidGoLab/jni-proxy/proto/auth"
14+
"google.golang.org/grpc"
15+
"google.golang.org/grpc/credentials"
16+
)
17+
18+
// AutoEnroll registers with jniservice if no certs exist in configDir.
19+
// It connects using TLS with InsecureSkipVerify (no client cert needed for Register).
20+
// Returns paths to the cert, key, and CA files.
21+
func AutoEnroll(ctx context.Context, addr, configDir string) (certPath, keyPath, caPath string, err error) {
22+
certPath = filepath.Join(configDir, "client.crt")
23+
keyPath = filepath.Join(configDir, "client.key")
24+
caPath = filepath.Join(configDir, "ca.crt")
25+
26+
// If all three files already exist, return their paths.
27+
if fileExists(certPath) && fileExists(keyPath) && fileExists(caPath) {
28+
return certPath, keyPath, caPath, nil
29+
}
30+
31+
// Create the config directory if it doesn't exist.
32+
if err := os.MkdirAll(configDir, 0700); err != nil {
33+
return "", "", "", fmt.Errorf("creating config dir: %w", err)
34+
}
35+
36+
// Generate a random suffix for the CN.
37+
randBytes := make([]byte, 8)
38+
if _, err := rand.Read(randBytes); err != nil {
39+
return "", "", "", fmt.Errorf("generating random bytes: %w", err)
40+
}
41+
cn := "jnimcp-" + hex.EncodeToString(randBytes)
42+
43+
// Generate EC P-256 keypair and CSR.
44+
csrPEM, keyPEM, err := certauth.GenerateCSR(cn)
45+
if err != nil {
46+
return "", "", "", fmt.Errorf("generating CSR: %w", err)
47+
}
48+
49+
// Connect to jniservice without client cert but with TLS (InsecureSkipVerify
50+
// because jniservice uses a self-signed CA).
51+
conn, err := grpc.NewClient(addr,
52+
grpc.WithTransportCredentials(credentials.NewTLS(
53+
&tls.Config{InsecureSkipVerify: true},
54+
)),
55+
)
56+
if err != nil {
57+
return "", "", "", fmt.Errorf("connecting to jniservice for enrollment: %w", err)
58+
}
59+
defer conn.Close()
60+
61+
// Call Register RPC.
62+
client := pb.NewAuthServiceClient(conn)
63+
resp, err := client.Register(ctx, &pb.RegisterRequest{
64+
CsrPem: string(csrPEM),
65+
})
66+
if err != nil {
67+
return "", "", "", fmt.Errorf("Register RPC: %w", err)
68+
}
69+
70+
// Save the returned client cert, CA cert, and the generated private key.
71+
if err := os.WriteFile(certPath, []byte(resp.GetClientCertPem()), 0600); err != nil {
72+
return "", "", "", fmt.Errorf("writing client cert: %w", err)
73+
}
74+
if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
75+
return "", "", "", fmt.Errorf("writing client key: %w", err)
76+
}
77+
if err := os.WriteFile(caPath, []byte(resp.GetCaCertPem()), 0644); err != nil {
78+
return "", "", "", fmt.Errorf("writing CA cert: %w", err)
79+
}
80+
81+
return certPath, keyPath, caPath, nil
82+
}
83+
84+
func fileExists(path string) bool {
85+
_, err := os.Stat(path)
86+
return err == nil
87+
}

0 commit comments

Comments
 (0)