1- from typing import Any , Callable
1+ from typing import Any , Callable , Generator
22
3+ import boto3
34import requests
4- from fastapi import Depends , Header , HTTPException , Request
5+ from botocore .config import Config
6+ from botocore .utils import fix_s3_host
7+ from fastapi import Depends , HTTPException , Request
58from fastapi .encoders import jsonable_encoder
6- from fastapi .security import HTTPAuthorizationCredentials
7- from fastapi_cloudauth .cognito import CognitoClaims , CognitoCurrentUser # type: ignore
8- from pydantic import BaseModel , Field
9+ from fastapi_cloudauth .cognito import CognitoClaims # type: ignore
10+ from sqlalchemy .orm import Session , sessionmaker
911
1012from api import settings
1113from api .core import notifications
14+ from api .core .auth import APIKeyDependency , UserGroupCognitoCurrentUser
15+ from api .core .database import Database , SqliteDatabase
1216from api .core .filesystem import FileSystem , user_filesystem_getter
1317from api .schemas .job import QueueJob
1418
19+ # S3 client setup
20+ s3_client = None
21+ if settings .s3_bucket :
22+ s3_client = boto3 .client (
23+ "s3" ,
24+ region_name = settings .s3_region ,
25+ endpoint_url = f"https://s3.{ settings .s3_region } .amazonaws.com" ,
26+ config = Config (signature_version = "v4" , s3 = {"addressing_style" : "path" }),
27+ )
28+ # this and config=... required to avoid DNS problems with new buckets
29+ s3_client .meta .events .unregister ("before-sign.s3" , fix_s3_host )
1530
16- class GroupClaims (CognitoClaims ): # type: ignore
17- """CognitoClaims with added groups claim."""
1831
19- cognito_groups : list [str ] | None = Field (alias = "cognito:groups" )
32+ # Database
33+ if settings .database_url .startswith ("sqlite" ):
34+ db : Database = SqliteDatabase (
35+ db_url = settings .database_url ,
36+ s3_client = s3_client ,
37+ s3_bucket = settings .s3_bucket ,
38+ )
39+ else :
40+ db = Database (db_url = settings .database_url )
2041
2142
22- class UserGroupCognitoCurrentUser (CognitoCurrentUser ): # type: ignore
23- """
24- Check membership in the 'users' group and add group membership information.
25- """
43+ async def db_dep () -> Database :
44+ return db
2645
27- user_info = GroupClaims
2846
29- async def call (
30- self , http_auth : HTTPAuthorizationCredentials
31- ) -> BaseModel | dict [str , Any ] | None :
32- user_info = await super ().call (http_auth )
33- if "users" not in (getattr (user_info , "cognito_groups" ) or []):
34- raise HTTPException (
35- status_code = 403 , detail = "Not a member of the 'users' group"
36- )
37- return user_info # type: ignore
47+ def session_dep (db_dep : Database = Depends (db_dep )) -> Generator [Session , Any , None ]:
48+ """Get database session."""
49+ SessionLocal = sessionmaker (autocommit = False , autoflush = False , bind = db_dep .engine )
50+ db = SessionLocal ()
51+ try :
52+ yield db
53+ finally :
54+ db .close ()
3855
3956
57+ # User authentication
4058current_user_dep = UserGroupCognitoCurrentUser (
4159 region = settings .cognito_region ,
4260 userPoolId = settings .cognito_user_pool_id ,
@@ -52,12 +70,13 @@ async def current_user_global_dep(
5270 return current_user
5371
5472
73+ # Filesystem
5574async def filesystem_getter_dep () -> Callable [[str ], FileSystem ]:
5675 """Get the user's filesystem getter."""
5776 return user_filesystem_getter (
5877 user_data_root_path = settings .user_data_root_path ,
5978 filesystem = settings .filesystem ,
60- s3_region = settings . s3_region ,
79+ s3_client = s3_client ,
6180 s3_bucket = settings .s3_bucket ,
6281 )
6382
@@ -70,20 +89,11 @@ async def user_filesystem_dep(
7089 return filesystem_getter (current_user .username )
7190
7291
73- class APIKeyDependency :
74- def __init__ (self , key : str | None ):
75- """Check API-internal key."""
76- self .key = key
77-
78- def __call__ (self , x_api_key : str | None = Header (...)) -> str | None :
79- if x_api_key != self .key :
80- raise HTTPException (status_code = 401 , detail = "unauthorized" )
81- return x_api_key
82-
83-
92+ # App-internal authentication (i.e. user-facing API <-> worker-facing API)
8493workerfacing_api_auth_dep = APIKeyDependency (settings .internal_api_key_secret )
8594
8695
96+ # Notifications
8797async def email_sender_dep () -> notifications .EmailSender :
8898 """Get the email sender."""
8999 service = settings .email_sender_service
@@ -110,6 +120,7 @@ async def email_sender_dep() -> notifications.EmailSender:
110120 )
111121
112122
123+ # Job enqueueing to worker-facing API
113124async def enqueueing_function_dep () -> Callable [[QueueJob ], None ]:
114125 def enqueue (queue_item : QueueJob ) -> None :
115126 resp = requests .post (
0 commit comments