From a2e106dbf0c4592e1e75fd0c5728f230968e3518 Mon Sep 17 00:00:00 2001 From: Shantanu Mane Date: Tue, 16 Jun 2026 21:27:37 +0530 Subject: [PATCH] feat(auth): register AuthMiddleware on protected routes - Apply AuthMiddleware to /storage and /assets route groups - Run auth before presign rate limiter to reject anonymous traffic early - Keep info and health endpoints public - Add unit tests for reject paths and userID context injection --- internal/middleware/authorization_test.go | 95 +++++++++++++++++++++++ internal/router/router.go | 2 + 2 files changed, 97 insertions(+) create mode 100644 internal/middleware/authorization_test.go diff --git a/internal/middleware/authorization_test.go b/internal/middleware/authorization_test.go new file mode 100644 index 0000000..de6ae35 --- /dev/null +++ b/internal/middleware/authorization_test.go @@ -0,0 +1,95 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/rndmcodeguy20/mpiper/internal/config" + "github.com/rndmcodeguy20/mpiper/pkg/utils" + "go.uber.org/zap" +) + +// 32-byte AES-256 key for the test singleton. +const testEncryptionKey = "0123456789abcdef0123456789abcdef" + +func TestMain(m *testing.M) { + config.Init(config.EnvConfig{EncryptionKey: testEncryptionKey}) + m.Run() +} + +// newGate wraps a handler that records whether it ran with AuthMiddleware. +func newGate(t *testing.T) (http.Handler, *bool) { + t.Helper() + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + return AuthMiddleware(zap.NewNop())(next), &called +} + +func TestAuthMiddleware_RejectsUnauthenticated(t *testing.T) { + tests := []struct { + name string + header string + }{ + {"missing header", ""}, + {"non-bearer scheme", "Basic abc123"}, + {"bearer without token", "Bearer "}, + {"malformed token", "Bearer not-a-valid-token"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gate, called := newGate(t) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/assets/x/complete", nil) + if tc.header != "" { + req.Header.Set("Authorization", tc.header) + } + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } + if *called { + t.Error("next handler ran for unauthenticated request — gate leaked") + } + }) + } +} + +func TestAuthMiddleware_AllowsValidTokenAndPopulatesUserID(t *testing.T) { + const wantUserID = "user-42" + token, err := utils.GenerateToken(wantUserID, testEncryptionKey) + if err != nil { + t.Fatalf("GenerateToken: %v", err) + } + + var gotUserID string + var gotOK bool + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUserID, gotOK = GetUserID(r.Context()) + w.WriteHeader(http.StatusOK) + }) + gate := AuthMiddleware(zap.NewNop())(next) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/assets/x/complete", nil) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + + gate.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + if !gotOK { + t.Fatal("GetUserID returned ok=false — userID not injected into context") + } + if gotUserID != wantUserID { + t.Errorf("userID = %q, want %q", gotUserID, wantUserID) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 578fec0..486f54d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -144,10 +144,12 @@ func NewRouter(cfg config.EnvConfig, db *sqlx.DB, m *metrics.Metrics) *chi.Mux { }) r.Route("/storage", func(r chi.Router) { + r.Use(appMiddleware.AuthMiddleware(logger)) r.With(presignRateLimiter()).Post("/presign", assetHandler.CreateAsset) }) r.Route("/assets", func(r chi.Router) { + r.Use(appMiddleware.AuthMiddleware(logger)) r.Get("/{assetID}/complete", assetHandler.MarkAssetUploaded) }) })