Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from datetime import datetime, timezone
from contextlib import asynccontextmanager
from typing import Optional

from auth import get_current_user, get_google_oauth_url, exchange_code_for_session
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from rate_limiter import limiter

# Load .env file if present (python-dotenv)
try:
Expand All @@ -16,9 +19,9 @@
except ImportError:
pass

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Query
from fastapi import FastAPI, File, Request, UploadFile, Form, HTTPException, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.responses import JSONResponse, RedirectResponse
from supabase import create_client, Client
from PIL import Image

Expand All @@ -32,7 +35,6 @@
_torch_available = False
print("WARNING: PyTorch not installed. Scan endpoints will return 503.")

from auth import get_current_user, get_google_oauth_url, exchange_code_for_session

# ── Configuration ─────────────────────────────────────────────────────────────
# All secrets MUST come from environment variables — no hardcoded fallbacks.
Expand Down Expand Up @@ -120,7 +122,20 @@ async def lifespan(app: FastAPI):
allow_methods=["*"],
allow_headers=["*"],
)

app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)

@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
return JSONResponse(
status_code=429,
content={
"error": "rate_limit_exceeded",
"detail": "Too many requests. Please slow down.",
"retry_after": exc.headers.get("Retry-After", "60"),
},
headers={"Retry-After": exc.headers.get("Retry-After", "60")},
)

# ── Health check ──────────────────────────────────────────────────────────────
# HF Spaces polls GET /?logs=container — without this route, FastAPI returns
Expand Down Expand Up @@ -350,7 +365,10 @@ async def _upload_image(image_bytes: bytes, user_id: str, scan_id: str) -> Optio


# ── AUTH ──────────────────────────────────────────────────────────────────────

@app.get("/api/v1/health")
async def api_health_check():
"""Health check endpoint — no auth or DB required."""
return {"status": "ok"}

@app.get("/api/v1/auth/login/google")
async def login_google():
Expand Down Expand Up @@ -412,7 +430,9 @@ async def get_public_report(scan_id: str):


@app.post("/api/v1/scan")
@limiter.limit("20/minute")
async def process_scan(
request: Request,
body_image: UploadFile = File(...),
eye_image: UploadFile = File(...),
gill_image: UploadFile = File(...),
Expand Down Expand Up @@ -462,7 +482,9 @@ async def process_scan(


@app.post("/api/v1/scan-auto")
@limiter.limit("20/minute")
async def scan_auto(
request: Request,
image: UploadFile = File(...),
current_user=Depends(get_current_user),
):
Expand Down
27 changes: 27 additions & 0 deletions backend/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from fastapi import Request
from slowapi import Limiter


def get_user_id(request: Request) -> str:
"""
Key rate limits by authenticated Supabase user ID (request.state.user).
Falls back to client IP for unauthenticated / dev-bypass requests.
"""
user = getattr(request.state, "user", None)
if user and isinstance(user, dict):
uid = user.get("sub") or user.get("id")
if uid:
return uid
return _get_ip(request)


def _get_ip(request: Request) -> str:
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
return request.client.host if request.client else "unknown"


# Global default: 100 requests/hour per user
# Scan endpoints override this with stricter 20/minute limit
limiter = Limiter(key_func=get_user_id, default_limits=["100/hour"])
1 change: 1 addition & 0 deletions backend/requirements-base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ numpy<2.0.0
python-dotenv>=1.0.0
python-multipart>=0.0.29
httpx>=0.27.0
slowapi==0.1.9 # rate limiting for FastAPI; added for per-user scan endpoint throttling
1 change: 1 addition & 0 deletions backend/requirements-ci.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r requirements-base.txt
pytest>=8.0.0
ruff>=0.4.0
slowapi==0.1.9 # rate limiting for FastAPI; added for per-user scan endpoint throttling
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ pytest>=8.0.0
# Comment these out if you don't have GPU/model files and just want demo mode.
torch>=2.2.0
torchvision>=0.27.0
slowapi==0.1.9 # rate limiting for FastAPI; added for per-user scan endpoint throttling
45 changes: 45 additions & 0 deletions backend/tests/test_rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from fastapi.testclient import TestClient
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from main import app

client = TestClient(app, raise_server_exceptions=False)


def test_scan_endpoint_not_rate_limited_initially():
"""First request should not be rate limited (422 = missing body, not 429)."""
response = client.post("/api/v1/scan")
assert response.status_code != 429


def test_scan_auto_endpoint_not_rate_limited_initially():
"""First request to scan-auto should not be rate limited."""
response = client.post("/api/v1/scan-auto")
assert response.status_code != 429


def test_rate_limit_returns_429():
"""Exceeding limit should return 429."""
responses = [client.get("/api/v1/health") for _ in range(110)]
status_codes = [r.status_code for r in responses]
assert 429 in status_codes, f"Expected 429 in responses, got: {set(status_codes)}"


def test_rate_limit_response_shape():
"""429 response must have correct JSON fields."""
responses = [client.get("/api/v1/health") for _ in range(110)]
rate_limited = [r for r in responses if r.status_code == 429]
assert len(rate_limited) > 0
body = rate_limited[0].json()
assert "error" in body


def test_rate_limit_retry_after_header():
"""429 response must include correct status code."""
responses = [client.get("/api/v1/health") for _ in range(110)]
rate_limited = [r for r in responses if r.status_code == 429]
assert len(rate_limited) > 0
body = rate_limited[0].json()
assert body is not None
Loading