Skip to content

Commit 08232af

Browse files
hackertronclaude
andcommitted
fix: address PR review — nil sessStore, error masking, WS origin check
1. Nil session store: runServe now falls back to an in-memory SQLite DB for session tracking when memory is disabled, preventing nil panics. 2. GetOrCreate hides errors: Added ErrSessionNotFound sentinel. The session store's Get method now returns it for sql.ErrNoRows and surfaces real DB errors. GetOrCreate only creates a new session on ErrSessionNotFound, propagating other failures to the caller. 3. WS origin unrestricted: Replaced the permissive CheckOrigin with a localhost-only check when no API key is configured, preventing cross-site WebSocket hijacking of a local gateway. When an API key is set, all origins are allowed since the hello frame authenticates. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7c5aab1 commit 08232af

5 files changed

Lines changed: 59 additions & 12 deletions

File tree

cmd/yantra/main.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,17 @@ func runServe(cmd *cobra.Command, args []string) error {
357357
defer memDB.Close()
358358
}
359359

360+
// Sessions are required for the gateway even when memory is disabled.
361+
// Fall back to an in-memory SQLite DB for session tracking only.
362+
if sessStore == nil {
363+
sessDB, err := memory.OpenDB(":memory:")
364+
if err != nil {
365+
return fmt.Errorf("opening session DB: %w", err)
366+
}
367+
defer sessDB.Close()
368+
sessStore = memory.NewSessionStore(sessDB)
369+
}
370+
360371
policy := tool.NewWorkspacePolicy(cfg.Tools.Shell)
361372
reg := tool.NewRegistry(policy)
362373
if err := tool.RegisterBuiltins(reg, cfg.Tools, mem); err != nil {

internal/gateway/session.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package gateway
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"sync"
78
"time"
@@ -74,6 +75,9 @@ func (sm *SessionManager) GetOrCreate(ctx context.Context, sessionID string) (*M
7475
// Look up or create the persistent session record.
7576
rec, err := sm.server.sessStore.Get(ctx, sessionID)
7677
if err != nil {
78+
if !errors.Is(err, types.ErrSessionNotFound) {
79+
return nil, fmt.Errorf("looking up session: %w", err)
80+
}
7781
// Not found — create a new record.
7882
rec, err = sm.server.sessStore.Create(ctx, "gateway-session")
7983
if err != nil {

internal/gateway/ws.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,13 @@ import (
55
"encoding/json"
66
"log/slog"
77
"net/http"
8+
"net/url"
89
"sync"
910

1011
"github.com/gorilla/websocket"
1112
"github.com/hackertron/Yantra/internal/types"
1213
)
1314

14-
var upgrader = websocket.Upgrader{
15-
ReadBufferSize: 4096,
16-
WriteBufferSize: 4096,
17-
CheckOrigin: func(r *http.Request) bool { return true }, // permissive for dev
18-
}
19-
2015
// wsConn wraps a single WebSocket connection.
2116
type wsConn struct {
2217
conn *websocket.Conn
@@ -26,8 +21,39 @@ type wsConn struct {
2621
writeMu sync.Mutex // gorilla writes are not concurrent-safe
2722
}
2823

24+
// newUpgrader creates a WebSocket upgrader with origin checking appropriate
25+
// to the server's security configuration. When an API key is set, all
26+
// origins are allowed (the hello frame authenticates). When no API key is
27+
// set (dev/local mode), only localhost origins are permitted to prevent
28+
// cross-site WebSocket hijacking of a local gateway.
29+
func (s *Server) newUpgrader() websocket.Upgrader {
30+
return websocket.Upgrader{
31+
ReadBufferSize: 4096,
32+
WriteBufferSize: 4096,
33+
CheckOrigin: func(r *http.Request) bool {
34+
// If an API key is required, WS auth via the hello frame
35+
// protects the endpoint; allow any origin.
36+
if s.cfg.APIKey != "" {
37+
return true
38+
}
39+
// Dev mode: restrict to localhost origins.
40+
origin := r.Header.Get("Origin")
41+
if origin == "" {
42+
return true // non-browser clients (curl, wscat)
43+
}
44+
u, err := url.Parse(origin)
45+
if err != nil {
46+
return false
47+
}
48+
host := u.Hostname()
49+
return host == "localhost" || host == "127.0.0.1" || host == "::1"
50+
},
51+
}
52+
}
53+
2954
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
30-
conn, err := upgrader.Upgrade(w, r, nil)
55+
up := s.newUpgrader()
56+
conn, err := up.Upgrade(w, r, nil)
3157
if err != nil {
3258
slog.Error("websocket upgrade failed", "error", err)
3359
return

internal/memory/session_store.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package memory
33
import (
44
"context"
55
"crypto/rand"
6+
"database/sql"
67
"encoding/hex"
8+
"errors"
79
"fmt"
810
"time"
911

@@ -51,7 +53,10 @@ func (s *SQLiteSessionStore) Get(ctx context.Context, id string) (*types.Session
5153
`SELECT name, created_at, updated_at, message_count, archived FROM sessions WHERE id = ?`, id).
5254
Scan(&name, &createdAt, &updatedAt, &msgCount, &archived)
5355
if err != nil {
54-
return nil, &types.MemoryError{Op: "session_get", Message: "not found", Err: err}
56+
if errors.Is(err, sql.ErrNoRows) {
57+
return nil, types.ErrSessionNotFound
58+
}
59+
return nil, &types.MemoryError{Op: "session_get", Message: "query failed", Err: err}
5560
}
5661

5762
rec := &types.SessionRecord{

internal/types/errors.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ import "errors"
44

55
// Sentinel errors for the runtime.
66
var (
7-
ErrCancelled = errors.New("turn cancelled")
8-
ErrBudgetExceeded = errors.New("budget exceeded")
9-
ErrMaxTurns = errors.New("max turns reached")
10-
ErrTimeout = errors.New("turn timed out")
7+
ErrCancelled = errors.New("turn cancelled")
8+
ErrBudgetExceeded = errors.New("budget exceeded")
9+
ErrMaxTurns = errors.New("max turns reached")
10+
ErrTimeout = errors.New("turn timed out")
11+
ErrSessionNotFound = errors.New("session not found")
1112
)
1213

1314
// ProviderError represents an error from an LLM provider.

0 commit comments

Comments
 (0)