Skip to content

Commit 86a2245

Browse files
authored
Merge pull request #11 from hackertron/feat/gateway-server
feat: add HTTP/WebSocket gateway server with REST API and streaming
2 parents 8a8b0d2 + 08232af commit 86a2245

15 files changed

Lines changed: 1657 additions & 15 deletions

File tree

cmd/yantra/main.go

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"path/filepath"
1010
"syscall"
1111

12+
"github.com/hackertron/Yantra/internal/gateway"
1213
"github.com/hackertron/Yantra/internal/memory"
1314
"github.com/hackertron/Yantra/internal/provider"
1415
"github.com/hackertron/Yantra/internal/runtime"
@@ -307,7 +308,74 @@ func runTUI(cmd *cobra.Command, args []string) error {
307308
}
308309

309310
func runServe(cmd *cobra.Command, args []string) error {
310-
fmt.Println("Starting Yantra API server...")
311-
// TODO: implement API server
312-
return fmt.Errorf("not yet implemented")
311+
ctx := cmd.Context()
312+
logger := slog.Default()
313+
314+
cfg, err := types.LoadConfig(configPath)
315+
if err != nil {
316+
return fmt.Errorf("loading config: %w", err)
317+
}
318+
319+
p, err := provider.BuildFromConfig(cfg)
320+
if err != nil {
321+
return fmt.Errorf("building provider: %w", err)
322+
}
323+
p = provider.NewReliable(p, provider.DefaultReliableConfig())
324+
325+
absWorkspace, err := filepath.Abs(".")
326+
if err != nil {
327+
return fmt.Errorf("resolving workspace: %w", err)
328+
}
329+
330+
// Set up memory if enabled.
331+
var mem types.MemoryRetrieval
332+
var memDB *memory.DB
333+
var sessStore types.SessionStore
334+
335+
if cfg.Memory.Enabled {
336+
dbPath := cfg.Memory.DBPath
337+
if dbPath == "" {
338+
dbPath = ".yantra/memory.db"
339+
}
340+
if !filepath.IsAbs(dbPath) {
341+
dbPath = filepath.Join(absWorkspace, dbPath)
342+
}
343+
344+
memDB, err = memory.OpenDB(dbPath)
345+
if err != nil {
346+
slog.Warn("failed to open memory DB, continuing without memory", "error", err)
347+
} else {
348+
embedder, err := memory.NewEmbeddingBackend(cfg.Memory)
349+
if err != nil {
350+
slog.Warn("failed to create embedding backend, continuing without embeddings", "error", err)
351+
}
352+
mem = memory.NewStore(memDB, embedder, cfg.Memory.Retrieval)
353+
sessStore = memory.NewSessionStore(memDB)
354+
}
355+
}
356+
if memDB != nil {
357+
defer memDB.Close()
358+
}
359+
360+
// Sessions are required for the gateway even when memory is disabled.
361+
// Fall back to an in-memory SQLite DB for session tracking only.
362+
if sessStore == nil {
363+
sessDB, err := memory.OpenDB(":memory:")
364+
if err != nil {
365+
return fmt.Errorf("opening session DB: %w", err)
366+
}
367+
defer sessDB.Close()
368+
sessStore = memory.NewSessionStore(sessDB)
369+
}
370+
371+
policy := tool.NewWorkspacePolicy(cfg.Tools.Shell)
372+
reg := tool.NewRegistry(policy)
373+
if err := tool.RegisterBuiltins(reg, cfg.Tools, mem); err != nil {
374+
return fmt.Errorf("registering tools: %w", err)
375+
}
376+
377+
srv := gateway.NewServer(cfg.Gateway, cfg, p, reg, mem, sessStore, absWorkspace, logger)
378+
379+
logger.Info("starting Yantra API server", "listen", cfg.Gateway.Listen)
380+
return srv.ListenAndServe(ctx)
313381
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ go 1.26
44

55
require (
66
github.com/anthropics/anthropic-sdk-go v1.26.0
7+
github.com/gorilla/websocket v1.5.3
78
github.com/knadh/koanf/parsers/toml v0.1.0
89
github.com/knadh/koanf/providers/env v1.1.0
910
github.com/knadh/koanf/providers/file v1.2.1
@@ -28,7 +29,6 @@ require (
2829
github.com/google/s2a-go v0.1.8 // indirect
2930
github.com/google/uuid v1.6.0 // indirect
3031
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
31-
github.com/gorilla/websocket v1.5.3 // indirect
3232
github.com/inconshreveable/mousetrap v1.1.0 // indirect
3333
github.com/knadh/koanf/maps v0.1.2 // indirect
3434
github.com/mattn/go-isatty v0.0.20 // indirect

internal/gateway/auth.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package gateway
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
)
7+
8+
// authMiddleware rejects requests without a valid Bearer token.
9+
// If the server's APIKey is empty (dev mode), all requests pass through.
10+
func (s *Server) authMiddleware(next http.Handler) http.Handler {
11+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
12+
if s.cfg.APIKey == "" {
13+
next.ServeHTTP(w, r)
14+
return
15+
}
16+
17+
auth := r.Header.Get("Authorization")
18+
if auth == "" {
19+
http.Error(w, `{"error":"missing authorization header"}`, http.StatusUnauthorized)
20+
return
21+
}
22+
23+
token := strings.TrimPrefix(auth, "Bearer ")
24+
if token == auth || !s.validateAPIKey(token) {
25+
http.Error(w, `{"error":"invalid api key"}`, http.StatusUnauthorized)
26+
return
27+
}
28+
29+
next.ServeHTTP(w, r)
30+
})
31+
}
32+
33+
// validateAPIKey checks whether key matches the configured API key.
34+
func (s *Server) validateAPIKey(key string) bool {
35+
return s.cfg.APIKey != "" && key == s.cfg.APIKey
36+
}

internal/gateway/auth_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package gateway
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/hackertron/Yantra/internal/types"
9+
)
10+
11+
func TestAuthMiddleware_EmptyKey_PassesAll(t *testing.T) {
12+
s := &Server{cfg: types.GatewayConfig{APIKey: ""}}
13+
14+
handler := s.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
15+
w.WriteHeader(http.StatusOK)
16+
}))
17+
18+
req := httptest.NewRequest("GET", "/api/v1/sessions", nil)
19+
rec := httptest.NewRecorder()
20+
handler.ServeHTTP(rec, req)
21+
22+
if rec.Code != http.StatusOK {
23+
t.Errorf("expected 200, got %d", rec.Code)
24+
}
25+
}
26+
27+
func TestAuthMiddleware_CorrectToken(t *testing.T) {
28+
s := &Server{cfg: types.GatewayConfig{APIKey: "secret-key"}}
29+
30+
handler := s.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31+
w.WriteHeader(http.StatusOK)
32+
}))
33+
34+
req := httptest.NewRequest("GET", "/api/v1/sessions", nil)
35+
req.Header.Set("Authorization", "Bearer secret-key")
36+
rec := httptest.NewRecorder()
37+
handler.ServeHTTP(rec, req)
38+
39+
if rec.Code != http.StatusOK {
40+
t.Errorf("expected 200, got %d", rec.Code)
41+
}
42+
}
43+
44+
func TestAuthMiddleware_WrongToken(t *testing.T) {
45+
s := &Server{cfg: types.GatewayConfig{APIKey: "secret-key"}}
46+
47+
handler := s.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
48+
w.WriteHeader(http.StatusOK)
49+
}))
50+
51+
req := httptest.NewRequest("GET", "/api/v1/sessions", nil)
52+
req.Header.Set("Authorization", "Bearer wrong-key")
53+
rec := httptest.NewRecorder()
54+
handler.ServeHTTP(rec, req)
55+
56+
if rec.Code != http.StatusUnauthorized {
57+
t.Errorf("expected 401, got %d", rec.Code)
58+
}
59+
}
60+
61+
func TestAuthMiddleware_MissingHeader(t *testing.T) {
62+
s := &Server{cfg: types.GatewayConfig{APIKey: "secret-key"}}
63+
64+
handler := s.authMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65+
w.WriteHeader(http.StatusOK)
66+
}))
67+
68+
req := httptest.NewRequest("GET", "/api/v1/sessions", nil)
69+
rec := httptest.NewRecorder()
70+
handler.ServeHTTP(rec, req)
71+
72+
if rec.Code != http.StatusUnauthorized {
73+
t.Errorf("expected 401, got %d", rec.Code)
74+
}
75+
}

internal/gateway/health.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package gateway
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
)
7+
8+
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
9+
w.Header().Set("Content-Type", "application/json")
10+
json.NewEncoder(w).Encode(map[string]string{
11+
"status": "ok",
12+
"version": "0.1.0",
13+
})
14+
}
15+
16+
func (s *Server) handleReady(w http.ResponseWriter, r *http.Request) {
17+
w.Header().Set("Content-Type", "application/json")
18+
json.NewEncoder(w).Encode(map[string]bool{
19+
"ready": true,
20+
})
21+
}

internal/gateway/health_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package gateway
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/hackertron/Yantra/internal/types"
10+
)
11+
12+
func TestHandleHealth(t *testing.T) {
13+
s := &Server{cfg: types.GatewayConfig{}}
14+
15+
req := httptest.NewRequest("GET", "/health", nil)
16+
rec := httptest.NewRecorder()
17+
s.handleHealth(rec, req)
18+
19+
if rec.Code != http.StatusOK {
20+
t.Fatalf("expected 200, got %d", rec.Code)
21+
}
22+
23+
var body map[string]string
24+
if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
25+
t.Fatalf("decode error: %v", err)
26+
}
27+
if body["status"] != "ok" {
28+
t.Errorf("expected status=ok, got %q", body["status"])
29+
}
30+
if body["version"] != "0.1.0" {
31+
t.Errorf("expected version=0.1.0, got %q", body["version"])
32+
}
33+
}
34+
35+
func TestHandleReady(t *testing.T) {
36+
s := &Server{cfg: types.GatewayConfig{}}
37+
38+
req := httptest.NewRequest("GET", "/ready", nil)
39+
rec := httptest.NewRecorder()
40+
s.handleReady(rec, req)
41+
42+
if rec.Code != http.StatusOK {
43+
t.Fatalf("expected 200, got %d", rec.Code)
44+
}
45+
46+
var body map[string]bool
47+
if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
48+
t.Fatalf("decode error: %v", err)
49+
}
50+
if !body["ready"] {
51+
t.Error("expected ready=true")
52+
}
53+
}

0 commit comments

Comments
 (0)