Skip to content

Commit deb41ea

Browse files
committed
Backup local DB to S3
1 parent 05290f0 commit deb41ea

17 files changed

Lines changed: 544 additions & 264 deletions

File tree

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"python-envs.defaultEnvManager": "ms-python.python:system"
3+
}

api/core/auth.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Any
2+
3+
from fastapi import Header, HTTPException
4+
from fastapi.security import HTTPAuthorizationCredentials
5+
from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore
6+
from pydantic import Field
7+
8+
9+
# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py
10+
class APIKeyDependency:
    """Callable dependency that checks an API key header against a fixed key.

    Instances are meant to be used as a FastAPI dependency; the incoming key
    is read from the ``x_api_key`` header parameter.
    """

    def __init__(self, key: str | None):
        """Store the expected key.

        Args:
            key: The secret that requests must present. If ``None``, no
                request is ever authorized (fail closed).
        """
        self.key = key

    def __call__(self, x_api_key: str | None = Header(...)) -> str | None:
        """Validate the supplied key.

        Args:
            x_api_key: Key supplied by the client (required header).

        Returns:
            The supplied key when it matches the configured key.

        Raises:
            HTTPException: 401 if the key is missing, does not match, or no
                key is configured.
        """
        # Local import keeps this block self-contained (the module does not
        # import ``secrets`` at the top).
        import secrets

        authorized = (
            self.key is not None
            and x_api_key is not None
            # Constant-time comparison avoids leaking key contents through
            # response timing. Compare bytes: compare_digest rejects
            # non-ASCII ``str`` arguments with a TypeError.
            and secrets.compare_digest(
                x_api_key.encode("utf-8"), self.key.encode("utf-8")
            )
        )
        if not authorized:
            raise HTTPException(status_code=401, detail="unauthorized")
        return x_api_key
18+
19+
20+
class GroupClaims(CognitoClaims):  # type: ignore
    """CognitoClaims with added groups claim."""

    # Read from the "cognito:groups" claim via the field alias; may be None.
    cognito_groups: list[str] | None = Field(alias="cognito:groups")
22+
23+
24+
class UserGroupCognitoCurrentUser(CognitoCurrentUser):  # type: ignore
    """Check membership in the 'users' group and add group membership information."""

    # Claims model used to parse the token, so group membership is available.
    user_info = GroupClaims

    async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any:
        """Resolve the current user and require membership in the 'users' group.

        Args:
            http_auth: Bearer credentials extracted from the request.

        Returns:
            The user info object produced by the parent class.

        Raises:
            HTTPException: 403 if the user is not in the 'users' group.
        """
        user_info = await super().call(http_auth)
        # NOTE(review): the parent may presumably return None (or a plain
        # dict) in some configurations - confirm against fastapi_cloudauth.
        # Defaulting to None turns that case into the 403 below instead of an
        # unhandled AttributeError.
        groups = getattr(user_info, "cognito_groups", None) or []
        if "users" not in groups:
            raise HTTPException(
                status_code=403, detail="Not a member of the 'users' group"
            )
        return user_info

api/core/database.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import gzip
2+
import os
3+
import sqlite3
4+
import tempfile
5+
import time
6+
from typing import Any, cast
7+
8+
from botocore.exceptions import ClientError
9+
from mypy_boto3_s3 import S3Client
10+
from sqlalchemy import Engine, create_engine
11+
12+
from api.models import Base
13+
14+
15+
class Database:
    """Database wrapper.

    Holds the connection URL and lazily creates (and caches) a SQLAlchemy
    engine the first time ``engine`` is accessed.
    """

    # Retries allowed after the first failed connection attempt, and the
    # pause between attempts. Subclasses or tests may override these.
    MAX_CONNECT_RETRIES = 10
    CONNECT_RETRY_DELAY_SECONDS = 60

    def __init__(
        self,
        db_url: str,
        connect_kwargs: dict[str, Any] | None = None,
    ):
        """Store connection settings; no connection is opened here.

        Args:
            db_url: URL passed to ``create_engine``.
            connect_kwargs: Passed to ``create_engine`` as ``connect_args``;
                ``None`` is treated as an empty dict.
        """
        self.db_url = db_url
        self.connect_kwargs = connect_kwargs or {}

    @property
    def engine(self) -> Engine:
        """Return the cached engine, creating and verifying it on first use.

        Creation is retried with a fixed delay while the database is not
        reachable.

        Raises:
            RuntimeError: If the engine still cannot connect after
                ``MAX_CONNECT_RETRIES`` retries; the last error is chained.
        """
        if hasattr(self, "_engine"):
            return cast(Engine, self._engine)  # type: ignore[has-type]
        retries = 0
        while True:
            try:
                engine = create_engine(self.db_url, connect_args=self.connect_kwargs)
                # Open a connection only to verify reachability, and release
                # it right away so it is not left checked out of the pool.
                with engine.connect():
                    pass
                self._engine = engine
                return engine  # Connection successful
            except Exception as e:
                if retries >= self.MAX_CONNECT_RETRIES:
                    raise RuntimeError(f"Could not create engine: {e}") from e
                retries += 1
                time.sleep(self.CONNECT_RETRY_DELAY_SECONDS)

    def create(self) -> None:
        """Create database tables."""
        Base.metadata.create_all(bind=self.engine)

    def backup(self) -> bool:
        """Backup the database. To be implemented by subclasses if supported.

        Returns:
            ``False`` always: the base class performs no backup.
        """
        return False
51+
52+
53+
class SqliteDatabase(Database):
    """SQLite database wrapper with optional S3 backup support.

    When both an S3 client and a bucket are given, ``create`` first restores
    the database file from the backup object (if one exists) and ``backup``
    uploads a gzip-compressed snapshot to that same object.
    """

    # Object key of the single backup snapshot inside the bucket.
    BACKUP_KEY = "userapi_sqlite_backup/backup.db.gz"

    def __init__(
        self,
        db_url: str,
        s3_client: S3Client | None = None,
        s3_bucket: str | None = None,
    ):
        """Validate the arguments and store the S3 settings.

        Args:
            db_url: SQLite URL; must start with ``sqlite:///``.
            s3_client: S3 client used for backup/restore, or ``None``.
            s3_bucket: Bucket holding the backup, or ``None``.

        Raises:
            ValueError: If ``db_url`` is not a ``sqlite:///`` URL, or if only
                one of ``s3_client`` / ``s3_bucket`` is provided.
        """
        if not db_url.startswith("sqlite:///"):
            raise ValueError(f"SqliteDatabase requires SQLite DB URL, got: {db_url}")
        if (s3_client is None) != (s3_bucket is None):
            raise ValueError(
                "Both s3_client and s3_bucket must be provided for S3 backup/restore, or both must be None."
            )
        self.s3_client = s3_client
        self.s3_bucket = s3_bucket
        super().__init__(db_url, connect_kwargs={"check_same_thread": False})

    def create(self) -> None:
        """Restore the database from S3 if possible, then create the tables."""
        self._restore_database()
        super().create()

    @property
    def db_path(self) -> str:
        """Filesystem path of the database: ``db_url`` without the ``sqlite:///`` prefix."""
        return self.db_url[len("sqlite:///") :]

    def backup(self) -> bool:
        """Backup the SQLite database to S3.

        Returns:
            ``True`` after a successful upload, ``False`` if S3 is not
            configured.
        """
        if not self.s3_bucket or not self.s3_client:
            return False

        # Local imports keep this block self-contained (the module does not
        # import these names at the top).
        import shutil
        from contextlib import closing

        with tempfile.TemporaryDirectory() as temp_dir:
            tmp_backup_path = os.path.join(temp_dir, "backup.db")
            tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz")
            # ``with sqlite3.connect(...)`` only manages the transaction and
            # never closes the connection; ``closing`` makes sure both are
            # closed (and the snapshot fully written) before it is read back.
            with closing(sqlite3.connect(self.db_path)) as source_conn:
                with closing(sqlite3.connect(tmp_backup_path)) as backup_conn:
                    source_conn.backup(backup_conn)

            with open(tmp_backup_path, "rb") as f_in:
                with gzip.open(tmp_gzip_path, "wb") as f_out:
                    # Chunked copy; binary data has no meaningful "lines".
                    shutil.copyfileobj(f_in, f_out)
            self.s3_client.upload_file(tmp_gzip_path, self.s3_bucket, self.BACKUP_KEY)
        return True

    def _restore_database(self) -> bool:
        """Restore the SQLite database from S3.

        Any existing file at ``db_path`` is replaced by the backup.

        Returns:
            ``True`` if a backup was downloaded and put in place, ``False`` if
            S3 is not configured or no backup object exists.

        Raises:
            ClientError: For any S3 error other than a 404 on the backup key.
        """
        if not self.s3_bucket or not self.s3_client:
            return False

        import shutil

        try:
            self.s3_client.head_object(Bucket=self.s3_bucket, Key=self.BACKUP_KEY)
        except ClientError as e:
            # A missing backup is expected (e.g. first start); anything else
            # is a real error.
            if e.response["Error"]["Code"] == "404":
                return False
            raise

        with tempfile.TemporaryDirectory() as temp_dir:
            tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz")
            tmp_backup_path = os.path.join(temp_dir, "backup.db")
            self.s3_client.download_file(self.s3_bucket, self.BACKUP_KEY, tmp_gzip_path)
            with gzip.open(tmp_gzip_path, "rb") as f_in:
                with open(tmp_backup_path, "wb") as f_out:
                    # Stream instead of loading the whole database in memory.
                    shutil.copyfileobj(f_in, f_out)
            # A bare filename has an empty dirname, and os.makedirs("") raises.
            db_dir = os.path.dirname(self.db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)
            # The temp dir and the database may be on different filesystems,
            # where os.rename fails; shutil.move falls back to copy + delete.
            shutil.move(tmp_backup_path, self.db_path)
        return True

api/core/filesystem.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@
77
from pathlib import Path, PurePosixPath
88
from typing import Any, BinaryIO, Callable, Generator
99

10-
import boto3
1110
import humanize
12-
from botocore.client import Config
1311
from botocore.response import StreamingBody
14-
from botocore.utils import fix_s3_host
1512
from fastapi import Request
1613
from fastapi.responses import FileResponse, StreamingResponse
1714
from mypy_boto3_s3 import S3Client
@@ -390,23 +387,16 @@ def download_url(
390387
def get_filesystem_with_root(
391388
root_path: str,
392389
filesystem: str,
393-
s3_region: str,
394-
s3_bucket: str | None,
390+
s3_bucket: str | None = None,
391+
s3_client: S3Client | None = None,
395392
) -> FileSystem:
396393
"""Get the filesystem to use."""
397394
predef_dirs = [e.value for e in models.UploadFileTypes] + [
398395
e.value for e in models.OutputEndpoints
399396
]
400397
if filesystem == "s3":
398+
assert s3_client is not None, "S3 client must be provided for S3 filesystem"
401399
assert s3_bucket is not None, "S3 bucket must be provided for S3 filesystem"
402-
s3_client = boto3.client(
403-
"s3",
404-
region_name=s3_region,
405-
endpoint_url=f"https://s3.{s3_region}.amazonaws.com",
406-
config=Config(signature_version="v4", s3={"addressing_style": "path"}),
407-
)
408-
# this and config=... required to avoid DNS problems with new buckets
409-
s3_client.meta.events.unregister("before-sign.s3", fix_s3_host)
410400
return S3Filesystem(root_path, s3_client, s3_bucket, predef_dirs=predef_dirs)
411401
elif filesystem == "local":
412402
return LocalFilesystem(root_path, predef_dirs=predef_dirs)
@@ -417,13 +407,13 @@ def get_filesystem_with_root(
417407
def user_filesystem_getter(
418408
user_data_root_path: str,
419409
filesystem: str,
420-
s3_region: str,
421-
s3_bucket: str | None,
410+
s3_bucket: str | None = None,
411+
s3_client: S3Client | None = None,
422412
) -> Callable[[str], FileSystem]:
423413
"""Get the filesystem to use for a user."""
424414
return lambda user_id: get_filesystem_with_root(
425415
str(Path(user_data_root_path) / user_id),
426416
filesystem=filesystem,
427-
s3_region=s3_region,
428417
s3_bucket=s3_bucket,
418+
s3_client=s3_client,
429419
)

api/database.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

api/dependencies.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,60 @@
1-
from typing import Any, Callable
1+
from typing import Any, Callable, Generator
22

3+
import boto3
34
import requests
4-
from fastapi import Depends, Header, HTTPException, Request
5+
from botocore.config import Config
6+
from botocore.utils import fix_s3_host
7+
from fastapi import Depends, HTTPException, Request
58
from fastapi.encoders import jsonable_encoder
6-
from fastapi.security import HTTPAuthorizationCredentials
7-
from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore
8-
from pydantic import BaseModel, Field
9+
from fastapi_cloudauth.cognito import CognitoClaims # type: ignore
10+
from sqlalchemy.orm import Session, sessionmaker
911

1012
from api import settings
1113
from api.core import notifications
14+
from api.core.auth import APIKeyDependency, UserGroupCognitoCurrentUser
15+
from api.core.database import Database, SqliteDatabase
1216
from api.core.filesystem import FileSystem, user_filesystem_getter
1317
from api.schemas.job import QueueJob
1418

19+
# S3 client setup
20+
s3_client = None
21+
if settings.s3_bucket:
22+
s3_client = boto3.client(
23+
"s3",
24+
region_name=settings.s3_region,
25+
endpoint_url=f"https://s3.{settings.s3_region}.amazonaws.com",
26+
config=Config(signature_version="v4", s3={"addressing_style": "path"}),
27+
)
28+
# this and config=... required to avoid DNS problems with new buckets
29+
s3_client.meta.events.unregister("before-sign.s3", fix_s3_host)
1530

16-
class GroupClaims(CognitoClaims): # type: ignore
17-
"""CognitoClaims with added groups claim."""
1831

19-
cognito_groups: list[str] | None = Field(alias="cognito:groups")
32+
# Database
33+
if settings.database_url.startswith("sqlite"):
34+
db: Database = SqliteDatabase(
35+
db_url=settings.database_url,
36+
s3_client=s3_client,
37+
s3_bucket=settings.s3_bucket,
38+
)
39+
else:
40+
db = Database(db_url=settings.database_url)
2041

2142

22-
class UserGroupCognitoCurrentUser(CognitoCurrentUser): # type: ignore
23-
"""
24-
Check membership in the 'users' group and add group membership information.
25-
"""
43+
async def db_dep() -> Database:
44+
return db
2645

27-
user_info = GroupClaims
2846

29-
async def call(
30-
self, http_auth: HTTPAuthorizationCredentials
31-
) -> BaseModel | dict[str, Any] | None:
32-
user_info = await super().call(http_auth)
33-
if "users" not in (getattr(user_info, "cognito_groups") or []):
34-
raise HTTPException(
35-
status_code=403, detail="Not a member of the 'users' group"
36-
)
37-
return user_info # type: ignore
47+
def session_dep(db_dep: Database = Depends(db_dep)) -> Generator[Session, Any, None]:
48+
"""Get database session."""
49+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=db_dep.engine)
50+
db = SessionLocal()
51+
try:
52+
yield db
53+
finally:
54+
db.close()
3855

3956

57+
# User authentication
4058
current_user_dep = UserGroupCognitoCurrentUser(
4159
region=settings.cognito_region,
4260
userPoolId=settings.cognito_user_pool_id,
@@ -52,12 +70,13 @@ async def current_user_global_dep(
5270
return current_user
5371

5472

73+
# Filesystem
5574
async def filesystem_getter_dep() -> Callable[[str], FileSystem]:
5675
"""Get the user's filesystem getter."""
5776
return user_filesystem_getter(
5877
user_data_root_path=settings.user_data_root_path,
5978
filesystem=settings.filesystem,
60-
s3_region=settings.s3_region,
79+
s3_client=s3_client,
6180
s3_bucket=settings.s3_bucket,
6281
)
6382

@@ -70,20 +89,11 @@ async def user_filesystem_dep(
7089
return filesystem_getter(current_user.username)
7190

7291

73-
class APIKeyDependency:
74-
def __init__(self, key: str | None):
75-
"""Check API-internal key."""
76-
self.key = key
77-
78-
def __call__(self, x_api_key: str | None = Header(...)) -> str | None:
79-
if x_api_key != self.key:
80-
raise HTTPException(status_code=401, detail="unauthorized")
81-
return x_api_key
82-
83-
92+
# App-internal authentication (i.e. user-facing API <-> worker-facing API)
8493
workerfacing_api_auth_dep = APIKeyDependency(settings.internal_api_key_secret)
8594

8695

96+
# Notifications
8797
async def email_sender_dep() -> notifications.EmailSender:
8898
"""Get the email sender."""
8999
service = settings.email_sender_service
@@ -110,6 +120,7 @@ async def email_sender_dep() -> notifications.EmailSender:
110120
)
111121

112122

123+
# Job enqueueing to worker-facing API
113124
async def enqueueing_function_dep() -> Callable[[QueueJob], None]:
114125
def enqueue(queue_item: QueueJob) -> None:
115126
resp = requests.post(

api/endpoints/job_update.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
import api.core.notifications as notifications
77
import api.crud.job as job_crud
8-
from api.database import get_db
9-
from api.dependencies import email_sender_dep, workerfacing_api_auth_dep
8+
from api.dependencies import email_sender_dep, session_dep, workerfacing_api_auth_dep
109
from api.models import JobStates
1110
from api.schemas.job_update import JobUpdate
1211

@@ -20,7 +19,7 @@
2019
)
2120
def update_job_status(
2221
update: JobUpdate,
23-
db: Session = Depends(get_db),
22+
db: Session = Depends(session_dep),
2423
email_sender: notifications.EmailSender = Depends(email_sender_dep),
2524
) -> JobStates:
2625
db_job = job_crud.get_job(db, update.job_id)

0 commit comments

Comments
 (0)