Skip to content

Commit 17cec66

Browse files
authored
Prepare for using SQLite on AWS (backup on S3) (#33)
1 parent b19a6fc commit 17cec66

23 files changed

Lines changed: 711 additions & 370 deletions

api/core/auth.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Any
2+
3+
from fastapi import Header, HTTPException
4+
from fastapi.security import HTTPAuthorizationCredentials
5+
from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore
6+
from pydantic import Field
7+
8+
9+
# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py
class APIKeyDependency:
    """FastAPI dependency that rejects requests whose ``X-API-Key`` header
    does not match the configured key.

    Args:
        key: The expected API key. May be ``None``; in that case only a
            request whose header value is also ``None`` is accepted
            (matching the original ``!=`` semantics).
    """

    def __init__(self, key: str | None):
        self.key = key

    def __call__(self, x_api_key: str | None = Header(...)) -> str | None:
        """Validate the ``X-API-Key`` request header.

        Returns:
            The validated header value.

        Raises:
            HTTPException: 401 when the supplied key does not match.
        """
        import secrets  # local import: stdlib constant-time comparison

        if self.key is not None and x_api_key is not None:
            # Constant-time comparison: plain ``!=`` short-circuits on the
            # first differing character, leaking key prefixes via timing.
            authorized = secrets.compare_digest(x_api_key, self.key)
        else:
            # Preserve original behavior when either side is None.
            authorized = x_api_key == self.key
        if not authorized:
            raise HTTPException(status_code=401, detail="unauthorized")
        return x_api_key
18+
19+
20+
class GroupClaims(CognitoClaims):  # type: ignore
    """Cognito JWT claims extended with the user's group memberships."""

    # Cognito stores group membership under the "cognito:groups" claim;
    # the alias maps it onto a valid Python attribute name. Tokens for
    # users in no group may omit the claim entirely, hence Optional.
    cognito_groups: list[str] | None = Field(alias="cognito:groups")
22+
23+
24+
class UserGroupCognitoCurrentUser(CognitoCurrentUser):  # type: ignore
    """Cognito current-user dependency that additionally requires the
    token to carry membership in the 'users' group."""

    # Parse claims with the group-aware model instead of the default.
    user_info = GroupClaims

    async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any:
        """Verify the token, then enforce 'users' group membership.

        Returns:
            The verified claims object.

        Raises:
            HTTPException: 403 when the token lacks the 'users' group.
        """
        user_info = await super().call(http_auth)
        # Direct attribute access instead of getattr() with a literal
        # name; the "or []" covers tokens with no groups claim at all.
        if "users" not in (user_info.cognito_groups or []):
            raise HTTPException(
                status_code=403, detail="Not a member of the 'users' group"
            )
        return user_info

api/core/database.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import gzip
2+
import os
3+
import sqlite3
4+
import tempfile
5+
import time
6+
from typing import Any, cast
7+
8+
from botocore.exceptions import ClientError
9+
from mypy_boto3_s3 import S3Client
10+
from sqlalchemy import Engine, create_engine
11+
12+
from api.models import Base
13+
14+
15+
class Database:
    """Database wrapper.

    Lazily creates a SQLAlchemy engine on first access, retrying the
    initial connectivity probe so the service can start before the
    database becomes reachable.
    """

    # Retry policy for the first connection attempt.
    _MAX_RETRIES = 10
    _RETRY_DELAY_SECONDS = 60

    def __init__(
        self,
        db_url: str,
        connect_kwargs: dict[str, Any] | None = None,
    ):
        """
        Args:
            db_url: SQLAlchemy database URL.
            connect_kwargs: Extra keyword arguments forwarded to the
                DBAPI ``connect()`` call.
        """
        self.db_url = db_url
        self.connect_kwargs = connect_kwargs or {}

    @property
    def engine(self) -> Engine:
        """Return the memoized engine, creating and probing it on first use.

        Raises:
            RuntimeError: When the database stays unreachable after all
                retries; the last underlying error is chained as the cause.
        """
        if hasattr(self, "_engine"):
            return cast(Engine, self._engine)  # type: ignore[has-type]
        retries = 0
        while True:
            try:
                engine = create_engine(self.db_url, connect_args=self.connect_kwargs)
                # Probe connectivity; close the probe connection immediately
                # so it is not leaked back as an open checked-out connection.
                engine.connect().close()
                self._engine = engine
                return engine  # Connection successful
            except Exception as e:
                if retries >= self._MAX_RETRIES:
                    # Chain the cause so the real failure stays visible.
                    raise RuntimeError(f"Could not create engine: {str(e)}") from e
                retries += 1
                time.sleep(self._RETRY_DELAY_SECONDS)

    def create(self) -> None:
        """Create database tables."""
        Base.metadata.create_all(bind=self.engine)

    def backup(self) -> bool:
        """Backup the database. To be implemented by subclasses if supported.

        Returns:
            False — the base class has no backup support.
        """
        return False

    def empty(self) -> None:
        """Empty the database by dropping and recreating all tables."""
        Base.metadata.drop_all(bind=self.engine)
        Base.metadata.create_all(bind=self.engine)
58+
class SqliteDatabase(Database):
    """SQLite database wrapper with optional S3 backup support."""

    # Fixed object key under which the gzipped snapshot is stored.
    BACKUP_KEY = "userapi_sqlite_backup/backup.db.gz"

    def __init__(
        self,
        db_url: str,
        s3_client: S3Client | None = None,
        s3_bucket: str | None = None,
    ):
        """
        Args:
            db_url: Must be a ``sqlite:///`` URL.
            s3_client: Optional S3 client for backup/restore.
            s3_bucket: Optional bucket name; must be given iff s3_client is.

        Raises:
            ValueError: On a non-SQLite URL or inconsistent S3 arguments.
        """
        if not db_url.startswith("sqlite:///"):
            # Error message fixed to name this class (was a copy-paste
            # of "SQLiteRDSJobQueue").
            raise ValueError(f"SqliteDatabase requires SQLite DB URL, got: {db_url}")
        if not ((s3_client is None) == (s3_bucket is None)):
            raise ValueError(
                "Both s3_client and s3_bucket must be provided for S3 backup/restore, or both must be None."
            )
        self.s3_client = s3_client
        self.s3_bucket = s3_bucket
        # check_same_thread=False: the engine may be used across threads.
        super().__init__(db_url, connect_kwargs={"check_same_thread": False})

    def create(self) -> None:
        """Restore the latest S3 snapshot (if any), then create tables."""
        self._restore_database()
        super().create()

    @property
    def db_path(self) -> str:
        """Filesystem path of the SQLite file, stripped of the URL scheme."""
        return self.db_url[len("sqlite:///") :]

    def backup(self) -> bool:
        """Backup the SQLite database to S3 as a gzipped snapshot.

        Returns:
            True if a snapshot was uploaded, False when S3 is unconfigured.
        """
        if not self.s3_bucket or not self.s3_client:
            return False

        import shutil
        from contextlib import closing

        with tempfile.TemporaryDirectory() as temp_dir:
            tmp_backup_path = os.path.join(temp_dir, "backup.db")
            tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz")
            # sqlite3's "with conn:" only wraps a transaction — it does NOT
            # close the connection — so wrap in closing() to avoid leaking
            # file handles on every backup.
            with closing(sqlite3.connect(self.db_path)) as source_conn:
                with closing(sqlite3.connect(tmp_backup_path)) as backup_conn:
                    # Online backup API: consistent snapshot even while the
                    # database is in use.
                    source_conn.backup(backup_conn)

            with open(tmp_backup_path, "rb") as f_in:
                with gzip.open(tmp_gzip_path, "wb") as f_out:
                    # Chunked copy; does not load the whole DB into memory.
                    shutil.copyfileobj(f_in, f_out)
            self.s3_client.upload_file(tmp_gzip_path, self.s3_bucket, self.BACKUP_KEY)
        return True

    def _restore_database(self) -> bool:
        """Restore the SQLite database from the S3 snapshot.

        Returns:
            True if a snapshot was restored, False when S3 is unconfigured
            or no snapshot exists yet.

        Raises:
            ClientError: For S3 failures other than a missing object.
        """
        if not self.s3_bucket or not self.s3_client:
            return False

        import shutil

        try:
            self.s3_client.head_object(Bucket=self.s3_bucket, Key=self.BACKUP_KEY)
        except ClientError as e:
            if e.response["Error"]["Code"] == "404":
                return False  # no snapshot yet — start with a fresh DB
            raise

        with tempfile.TemporaryDirectory() as temp_dir:
            tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz")
            tmp_backup_path = os.path.join(temp_dir, "backup.db")
            self.s3_client.download_file(self.s3_bucket, self.BACKUP_KEY, tmp_gzip_path)
            with gzip.open(tmp_gzip_path, "rb") as f_in:
                with open(tmp_backup_path, "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
            db_dir = os.path.dirname(self.db_path)
            if db_dir:
                # dirname is "" for a bare filename; makedirs("") raises.
                os.makedirs(db_dir, exist_ok=True)
            # shutil.move works across filesystems, unlike os.rename which
            # fails with EXDEV when the temp dir is on a different device.
            shutil.move(tmp_backup_path, self.db_path)
        return True

api/core/filesystem.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,16 @@
55
import shutil
66
import zipfile
77
from pathlib import Path, PurePosixPath
8-
from typing import Any, BinaryIO, Generator, cast
8+
from typing import Any, BinaryIO, Callable, Generator
99

10-
import boto3
1110
import humanize
12-
from botocore.client import Config
1311
from botocore.response import StreamingBody
14-
from botocore.utils import fix_s3_host
1512
from fastapi import Request
1613
from fastapi.responses import FileResponse, StreamingResponse
1714
from mypy_boto3_s3 import S3Client
1815
from mypy_boto3_s3.type_defs import ObjectIdentifierTypeDef
1916

20-
from api import models, settings
17+
from api import models
2118
from api.schemas.file import FileHTTPRequest, FileInfo, FileTypes
2219

2320

@@ -387,29 +384,36 @@ def download_url(
387384
)
388385

389386

390-
def get_filesystem_with_root(
    root_path: str,
    filesystem: str,
    s3_bucket: str | None = None,
    s3_client: S3Client | None = None,
) -> FileSystem:
    """Get the filesystem to use.

    Args:
        root_path: Root directory (or key prefix) for all file operations.
        filesystem: Backend selector, either "s3" or "local".
        s3_bucket: Bucket name; required when ``filesystem == "s3"``.
        s3_client: S3 client; required when ``filesystem == "s3"``.

    Raises:
        ValueError: On an unknown ``filesystem`` value, or missing S3
            configuration for the "s3" backend.
    """
    predef_dirs = [e.value for e in models.UploadFileTypes] + [
        e.value for e in models.OutputEndpoints
    ]
    if filesystem == "s3":
        # Raise explicitly instead of assert: assert statements are
        # stripped under ``python -O`` and must not guard configuration.
        if s3_client is None:
            raise ValueError("S3 client must be provided for S3 filesystem")
        if s3_bucket is None:
            raise ValueError("S3 bucket must be provided for S3 filesystem")
        return S3Filesystem(root_path, s3_client, s3_bucket, predef_dirs=predef_dirs)
    elif filesystem == "local":
        return LocalFilesystem(root_path, predef_dirs=predef_dirs)
    else:
        raise ValueError("Invalid filesystem setting")
411405

412406

413-
def get_user_filesystem(user_id: str) -> FileSystem:
407+
def user_filesystem_getter(
408+
user_data_root_path: str,
409+
filesystem: str,
410+
s3_bucket: str | None = None,
411+
s3_client: S3Client | None = None,
412+
) -> Callable[[str], FileSystem]:
414413
"""Get the filesystem to use for a user."""
415-
return get_filesystem_with_root(str(Path(settings.user_data_root_path) / user_id))
414+
return lambda user_id: get_filesystem_with_root(
415+
str(Path(user_data_root_path) / user_id),
416+
filesystem=filesystem,
417+
s3_bucket=s3_bucket,
418+
s3_client=s3_client,
419+
)

api/crud/job.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from sqlalchemy.orm import Session
77

88
from api import models, settings
9-
from api.core.filesystem import FileSystem, get_user_filesystem
9+
from api.core.filesystem import FileSystem
1010
from api.schemas import job as schemas
1111

1212

1313
def enqueue_job(
14-
job: models.Job, enqueueing_func: Callable[[schemas.QueueJob], None]
14+
job: models.Job,
15+
filesystem: FileSystem,
16+
enqueueing_func: Callable[[schemas.QueueJob], None],
1517
) -> None:
16-
user_fs = get_user_filesystem(user_id=job.user_id)
17-
1818
app = job.application
1919
job_config = settings.application_config.config[app["application"]][app["version"]][
2020
app["entrypoint"]
@@ -47,14 +47,14 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
4747
f"artifact/{artifact_id}"
4848
for artifact_id in job.attributes["files_down"]["artifact_ids"]
4949
]
50-
_validate_files(user_fs, [config_path] + data_paths + artifact_paths)
50+
_validate_files(filesystem, [config_path] + data_paths + artifact_paths)
5151
roots_down = handler_config["files_down"]
52-
files_down = prepare_files(config_path, roots_down["config_id"], user_fs)
52+
files_down = prepare_files(config_path, roots_down["config_id"], filesystem)
5353
for data_path in data_paths:
54-
files_down.update(prepare_files(data_path, roots_down["data_ids"], user_fs))
54+
files_down.update(prepare_files(data_path, roots_down["data_ids"], filesystem))
5555
for artifact_path in artifact_paths:
5656
files_down.update(
57-
prepare_files(artifact_path, roots_down["artifact_ids"], user_fs)
57+
prepare_files(artifact_path, roots_down["artifact_ids"], filesystem)
5858
)
5959

6060
app_specs = schemas.AppSpecs(
@@ -76,9 +76,9 @@ def prepare_files(root_in: str, root_out: str, fs: FileSystem) -> dict[str, str]
7676
)
7777

7878
paths_upload = {
79-
"output": user_fs.full_path_uri(job.paths_out["output"]),
80-
"log": user_fs.full_path_uri(job.paths_out["log"]),
81-
"artifact": user_fs.full_path_uri(job.paths_out["artifact"]),
79+
"output": filesystem.full_path_uri(job.paths_out["output"]),
80+
"log": filesystem.full_path_uri(job.paths_out["log"]),
81+
"artifact": filesystem.full_path_uri(job.paths_out["artifact"]),
8282
}
8383

8484
queue_item = schemas.QueueJob(
@@ -117,6 +117,7 @@ def _validate_files(filesystem: FileSystem, paths: list[str]) -> None:
117117

118118
def create_job(
119119
db: Session,
120+
filesystem: FileSystem,
120121
enqueueing_func: Callable[[schemas.QueueJob], None],
121122
job: schemas.JobCreate,
122123
user_id: int,
@@ -146,18 +147,17 @@ def create_job(
146147
status_code=status.HTTP_400_BAD_REQUEST,
147148
detail=ve,
148149
)
149-
enqueue_job(db_job, enqueueing_func)
150+
enqueue_job(db_job, filesystem, enqueueing_func)
150151
db.commit()
151152
db.refresh(db_job)
152153
return db_job
153154

154155

155-
def delete_job(db: Session, db_job: models.Job) -> models.Job:
156+
def delete_job(db: Session, filesystem: FileSystem, db_job: models.Job) -> models.Job:
156157
db.delete(db_job)
157-
user_fs = get_user_filesystem(user_id=db_job.user_id)
158158
for path in db_job.paths_out.values():
159159
if path[-1] != "/":
160160
path += "/"
161-
user_fs.delete(path)
161+
filesystem.delete(path)
162162
db.commit()
163163
return db_job

api/database.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments (0)