Skip to content

Commit 88fc27f

Browse files
committed
start migrating to Beanie ORM
1 parent 8383162 commit 88fc27f

4 files changed

Lines changed: 50 additions & 15 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ requires-python = ">=3.12"
1313
dependencies = [
1414
"fastapi",
1515
"motor",
16+
"beanie",
1617
"python-dotenv",
1718
"uvicorn",
1819
"pydantic",
1920
"mangum",
20-
"Authlib",
2121
"requests",
2222
"Starlette",
2323
"itsdangerous",

src/api/data/instance.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
1+
import logging
12
import os
23

3-
import motor.motor_asyncio
4+
from beanie import init_beanie
5+
from motor.motor_asyncio import AsyncIOMotorClient
46

5-
client = motor.motor_asyncio.AsyncIOMotorClient(
6-
os.environ.get("SOLESEARCH_DB_CONNECTION_STRING")
7-
)
7+
logger = logging.getLogger(__name__)
8+
9+
CONNECTION_STRING = os.environ.get("SOLESEARCH_DB_CONNECTION_STRING")
10+
if not CONNECTION_STRING:
11+
raise EnvironmentError(
12+
"Please set the SOLESEARCH_DB_CONNECTION_STRING environment variable."
13+
)
14+
client = AsyncIOMotorClient(CONNECTION_STRING)
15+
DATABASE_NAME = os.environ.get("SOLESEARCH_DB_NAME")
16+
if not DATABASE_NAME:
17+
DATABASE_NAME = "Sneakers"
18+
logger.warning(
19+
f"SOLESEARCH_DB_NAME environment variable not set, defaulting to {DATABASE_NAME}."
20+
)
821
db = client[os.environ.get("SOLESEARCH_DB_NAME")]
922
sneakers = db[os.environ.get("SOLESEARCH_DB_PRIMARY_COLLECTION")]
1023
DEFAULT_LIMIT = int(os.environ.get("SOLESEARCH_DEFAULT_LIMIT", 10))

src/api/main.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import os
22

33
from dotenv import load_dotenv
4-
from fastapi import FastAPI
5-
from mangum import Mangum
6-
from starlette.middleware.sessions import SessionMiddleware
74

85
if not os.environ.get("AWS_EXECUTION_ENV"):
96
load_dotenv(os.path.join(os.getcwd(), ".env"))
107

8+
from beanie import init_beanie
9+
from fastapi import FastAPI
10+
from mangum import Mangum
11+
from starlette.middleware.sessions import SessionMiddleware
12+
13+
from api.data.instance import DATABASE_NAME, client
1114
from api.routes import auth, sneakers
15+
from core.models.shoes import Sneaker
1216

1317
app = FastAPI(
1418
redoc_url=None,
@@ -17,8 +21,18 @@
1721

1822
app.add_middleware(SessionMiddleware, secret_key="some secret key here")
1923

20-
app.include_router(sneakers.router)
21-
app.include_router(auth.router)
24+
25+
@app.on_event("startup")
26+
async def startup_event():
27+
await init_beanie(
28+
database=client.get_database(DATABASE_NAME),
29+
document_models=[Sneaker],
30+
)
31+
32+
app.include_router(sneakers.router)
33+
app.include_router(auth.router)
34+
35+
2236
handler = Mangum(app)
2337

2438
if __name__ == "__main__":

src/api/routes/sneakers.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Annotated
22

3-
from fastapi import APIRouter, Query
3+
from fastapi import APIRouter, HTTPException, Query
44

55
from api.data.instance import DEFAULT_LIMIT, DEFAULT_OFFSET
66
from api.data.models import SortKey, SortOrder
7-
from api.data.queries import find_sneaker_by_id, find_sneaker_by_sku, find_sneakers
7+
from api.data.queries import find_sneakers
88
from core.models.details import Audience
9+
from core.models.shoes import Sneaker
910

1011
router = APIRouter(
1112
prefix="/sneakers",
@@ -43,9 +44,16 @@ async def get_sneakers(
4344

4445
@router.get("/{product_id}")
4546
async def get_sneaker_by_id(product_id: str):
46-
return await find_sneaker_by_id(product_id)
47+
if not product_id:
48+
raise HTTPException(status_code=400, detail="Invalid product_id")
49+
return await Sneaker.get(product_id)
4750

4851

49-
@router.get("/sku/{product_id}")
52+
@router.get("/sku/{sku}")
5053
async def get_sneaker_by_sku(sku: str, brand: str | None = None):
51-
return await find_sneaker_by_sku(sku, brand)
54+
if not sku:
55+
raise HTTPException(status_code=400, detail="Invalid sku")
56+
if brand:
57+
return await Sneaker.find_one(Sneaker.sku == sku, Sneaker.brand == brand)
58+
else:
59+
return await Sneaker.find_one(Sneaker.sku == sku)

0 commit comments

Comments
 (0)