Skip to content

Commit 04b66ce

Browse files
committed
fix tests
1 parent deb41ea commit 04b66ce

6 files changed

Lines changed: 44 additions & 12 deletions

File tree

api/core/database.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ def backup(self) -> bool:
4949
"""Backup the database. To be implemented by subclasses if supported."""
5050
return False
5151

52+
def empty(self) -> None:
53+
"""Empty the database by dropping and recreating all tables."""
54+
Base.metadata.drop_all(bind=self.engine)
55+
Base.metadata.create_all(bind=self.engine)
56+
5257

5358
class SqliteDatabase(Database):
5459
"""SQLite database wrapper with optional S3 backup support."""

api/endpoints/auth_get.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from fastapi import APIRouter, Depends
44

5-
from api.dependencies import GroupClaims, current_user_dep
5+
from api.core.auth import GroupClaims
6+
from api.dependencies import current_user_dep
67
from api.schemas.user import User
78
from api.settings import cognito_client_id, cognito_region, cognito_user_pool_id
89

api/endpoints/jobs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def list_jobs(
3232
db: Session = Depends(session_dep),
3333
) -> list[Job]:
3434
db_jobs = crud.get_jobs(db, request.state.current_user.username, offset, limit)
35-
jobs = [Job.model_validate(db_job) for db_job in db_jobs] # models -> schemas
35+
jobs = [
36+
Job.model_validate(db_job, from_attributes=True) for db_job in db_jobs
37+
] # models -> schemas
3638
return sorted(jobs, key=lambda x: x.date_created, reverse=True)
3739

3840

@@ -45,7 +47,7 @@ def describe_job(
4547
raise HTTPException(
4648
status_code=status.HTTP_404_NOT_FOUND, detail="Job not found"
4749
)
48-
return Job.model_validate(db_job)
50+
return Job.model_validate(db_job, from_attributes=True)
4951

5052

5153
@router.post(
@@ -70,7 +72,7 @@ def start_job(
7072
user_id=request.state.current_user.username,
7173
user_email=request.state.current_user.email,
7274
)
73-
return Job.model_validate(db_job)
75+
return Job.model_validate(db_job, from_attributes=True)
7476
except FileNotFoundError as e:
7577
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
7678

tests/integration/conftest.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,12 @@ def filesystem_getter(
111111
base_user_dir: str,
112112
s3_testing_bucket: S3TestingBucket,
113113
) -> Callable[[str], FileSystem]:
114+
s3_fs = isinstance(base_filesystem, S3Filesystem)
114115
return user_filesystem_getter(
115116
user_data_root_path=base_user_dir,
116-
filesystem="s3" if isinstance(base_filesystem, S3Filesystem) else "local",
117-
s3_bucket=s3_testing_bucket.bucket_name,
118-
s3_client=s3_testing_bucket.s3_client,
117+
filesystem="s3" if s3_fs else "local",
118+
s3_bucket=s3_testing_bucket.bucket_name if s3_fs else None,
119+
s3_client=s3_testing_bucket.s3_client if s3_fs else None,
119120
)
120121

121122

@@ -127,12 +128,22 @@ def user_filesystem(
127128

128129

129130
@pytest.fixture(autouse=True)
130-
def override_db_dep(db: Database, monkeypatch: pytest.MonkeyPatch) -> None:
131+
def override_db_dep(
132+
db: Database,
133+
rds_testing_instance: RDSTestingInstance,
134+
monkeypatch: pytest.MonkeyPatch,
135+
) -> Generator[None, None, None]:
131136
monkeypatch.setitem(
132137
app.dependency_overrides, # type: ignore
133138
db_dep,
134139
lambda: db,
135140
)
141+
yield
142+
# Cleanup after every test
143+
if isinstance(db, SqliteDatabase):
144+
db.empty()
145+
else:
146+
rds_testing_instance.cleanup()
136147

137148

138149
@pytest.fixture(autouse=True)
@@ -269,6 +280,12 @@ def job_attrs() -> dict[str, Any]:
269280
}
270281

271282

283+
@pytest.fixture
284+
def db_session(db: Database) -> Generator[Session, None, None]:
285+
with Session(db.engine) as session:
286+
yield session
287+
288+
272289
@pytest.fixture
273290
def jobs(
274291
test_username: str,

tests/integration/endpoints/test_files.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import requests
55
from fastapi.testclient import TestClient
66

7+
from api.core.filesystem import FileSystem, LocalFilesystem
8+
79
ENDPOINT = "/files"
810

911

@@ -167,21 +169,23 @@ def test_download_file_happy(client: TestClient, data_files: dict[str, str]) ->
167169

168170

169171
def test_get_url_file_happy(
170-
env: str, client: TestClient, data_files: dict[str, str]
172+
base_filesystem: FileSystem, client: TestClient, data_files: dict[str, str]
171173
) -> None:
172174
data_file1_name, data_file1_contents = list(data_files.items())[0]
173175
response = client.get(f"{ENDPOINT}/{data_file1_name}/url")
174176
assert response.status_code == 200
175177
request_params = response.json()
176178
if "authorization" in request_params["headers"]:
177179
del request_params["headers"]["authorization"]
178-
request_client = client if env == "local" else requests
180+
request_client = (
181+
client if isinstance(base_filesystem, LocalFilesystem) else requests
182+
)
179183
response = request_client.request(**request_params)
180184
assert response.status_code == 200, response.text
181185
assert response.content.decode("utf-8") == data_file1_contents
182186

183187

184-
def test_post_url_file_happy(env: str, client: TestClient) -> None:
188+
def test_post_url_file_happy(base_filesystem: FileSystem, client: TestClient) -> None:
185189
data_file1_name = "data/test/data_file1.txt"
186190
data_file1_contents = "data file1 contents"
187191
response = client.post(f"{ENDPOINT}/{os.path.dirname(data_file1_name)}//url")
@@ -196,7 +200,9 @@ def test_post_url_file_happy(env: str, client: TestClient) -> None:
196200
"text/plain",
197201
)
198202
}
199-
request_client = client if env == "local" else requests
203+
request_client = (
204+
client if isinstance(base_filesystem, LocalFilesystem) else requests
205+
)
200206
request_client.request(**request_params, files=files)
201207
response = client.get(
202208
f"{ENDPOINT}//", params={"recursive": True, "show_dirs": False}

tests/integration/endpoints/test_job_update.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_job_status_update(
3838
headers={"x-api-key": internal_api_key_secret},
3939
)
4040
assert response.status_code == 200
41+
db_session.refresh(jobs[0])
4142
job = db_session.query(Job).filter(Job.id == jobs[0].id).first()
4243
assert job is not None
4344
assert job.status == "running"

0 commit comments

Comments
 (0)