Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
python-version: ['3.10', '3.11', '3.12', '3.13']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down
19 changes: 14 additions & 5 deletions agave/chalice/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from pydantic import BaseModel, ValidationError

from ..core.blueprints.decorators import copy_attributes
from ..core.query_params import (
EmptyQueryMapping,
query_params_for_url,
validate_query_params,
)


class RestApiBlueprint(Blueprint):
Expand Down Expand Up @@ -238,9 +243,13 @@ def query():
next_page = <url_for_next_items>
}
"""
params = self.current_request.query_params or dict()
query_mapping = (
self.current_request.query_params or EmptyQueryMapping()
)
try:
query_params = cls.query_validator(**params)
query_params = validate_query_params(
query_mapping, cls.query_validator
)
except ValidationError as e:
return Response(e.json(), status_code=400)

Expand Down Expand Up @@ -296,11 +305,11 @@ def _all(query: QueryParams, filters: Q):
if wants_more and has_more:
query.created_before = item_dicts[-1]['created_at']
path = self.current_request.context['resourcePath']
params = query.model_dump()
params = query_params_for_url(query)
if self.user_id_filter_required():
params.pop('user_id')
params.pop('user_id', None)
if self.platform_id_filter_required():
params.pop('platform_id')
params.pop('platform_id', None)
next_page_uri = f'{path}?{urlencode(params)}'
return dict(items=item_dicts, next_page_uri=next_page_uri)

Expand Down
4 changes: 4 additions & 0 deletions agave/core/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ def generic_query(query: QueryParams, excluded: list[str] = []) -> Q:
filters &= Q(created_at__lt=query.created_before)
if query.created_after:
filters &= Q(created_at__gt=query.created_after)
ids = getattr(query, 'ids', None)
if ids:
filters &= Q(id__in=[x.strip() for x in ids.split(',') if x.strip()])
exclude_fields = {
'created_before',
'created_after',
'active',
'limit',
'page_size',
'key',
'ids',
*excluded,
}
fields = query.model_dump(exclude=exclude_fields)
Expand Down
79 changes: 79 additions & 0 deletions agave/core/query_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

from types import UnionType
from typing import Any, Iterator, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel

ModelT = TypeVar('ModelT', bound=BaseModel)


def comma_separated_list(value: str | None) -> list[str]:
if not value:
return []
return [part.strip() for part in value.split(',') if part.strip()]


def _is_list_annotation(annotation: Any) -> bool:
origin = get_origin(annotation)
if origin is list:
return True
if origin in (Union, UnionType):
return any(
_is_list_annotation(arg)
for arg in get_args(annotation)
if arg is not type(None)
)
return False


def build_query_dict(
query_mapping: Any, model_cls: type[BaseModel]
) -> dict[str, Any]:
params: dict[str, Any] = {}
for name in query_mapping:
raw = query_mapping.get(name)
if name in model_cls.model_fields:
field = model_cls.model_fields[name]
if name == 'ids' and not _is_list_annotation(field.annotation):
if raw is not None:
params[name] = raw
continue
if _is_list_annotation(field.annotation):
if raw is None:
continue
if isinstance(raw, str):
params[name] = comma_separated_list(raw)
else:
params[name] = list(raw)
else:
params[name] = raw
else:
params[name] = raw
return params


def validate_query_params(
query_mapping: Any, model_cls: type[ModelT]
) -> ModelT:
return model_cls(**build_query_dict(query_mapping, model_cls))


def query_params_for_url(query: BaseModel) -> dict[str, Any]:
params = query.model_dump()
for name, field in type(query).model_fields.items():
value = params.get(name)
if _is_list_annotation(field.annotation) and isinstance(value, list):
params[name] = ','.join(value)
return params
Comment thread
coderabbitai[bot] marked this conversation as resolved.


class EmptyQueryMapping:
def __contains__(self, key: str) -> bool:
return False

def __iter__(self) -> Iterator[str]:
return iter(())

def get(self, key: str, default: Any = None) -> Any:
return default
11 changes: 7 additions & 4 deletions agave/fastapi/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ..core.blueprints.decorators import copy_attributes
from ..core.exc import NotFoundError, UnprocessableEntity
from ..core.query_params import query_params_for_url, validate_query_params

SAMPLE_404 = {
"summary": "Not found item",
Expand Down Expand Up @@ -358,7 +359,9 @@ class QueryResponse(BaseModel):

def validate_params(request: Request):
try:
return cls.query_validator(**request.query_params)
return validate_query_params(
request.query_params, cls.query_validator
)
except ValidationError as e:
raise UnprocessableEntity(e.json())

Expand Down Expand Up @@ -430,11 +433,11 @@ async def _all(query: QueryParams, filters: Q, resource_path: str):
next_page_uri: Optional[str] = None
if wants_more and has_more:
query.created_before = item_dicts[-1]['created_at']
params = query.model_dump()
params = query_params_for_url(query)
if self.user_id_filter_required():
params.pop('user_id')
params.pop('user_id', None)
if self.platform_id_filter_required():
params.pop('platform_id')
params.pop('platform_id', None)
next_page_uri = f'{resource_path}?{urlencode(params)}'
return dict(items=item_dicts, next_page_uri=next_page_uri)

Expand Down
2 changes: 1 addition & 1 deletion agave/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.4'
__version__ = '1.5.5'
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
boto3==1.35.74
types-boto3[sqs]==1.35.74
cuenca-validations==2.1.3
cuenca-validations==2.1.36
chalice==1.31.3
mongoengine==0.29.1
fastapi==0.115.11
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
packages=find_packages(),
include_package_data=True,
package_data=dict(agave=['py.typed']),
python_requires='>=3.9',
python_requires='>=3.10',
install_requires=[
'cuenca-validations>=2.1.0,<3.0.0',
'mongoengine>=0.29.0,<0.30.0',
Expand Down Expand Up @@ -54,7 +54,6 @@
],
},
classifiers=[
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
Expand Down
37 changes: 37 additions & 0 deletions tests/blueprint/test_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,43 @@ def test_query_count_resource(
assert json_body['count'] == 1


@pytest.mark.parametrize(
"client_fixture", ["fastapi_client", "chalice_client"]
)
@pytest.mark.usefixtures('accounts')
def test_query_with_comma_separated_ids_param(
client_fixture: str,
request: pytest.FixtureRequest,
) -> None:
client = request.getfixturevalue(client_fixture)
resp = client.get('/accounts?ids=US1,US2')
assert resp.status_code == 200
assert 'items' in resp.json()


@pytest.mark.parametrize(
"client_fixture", ["fastapi_client", "chalice_client"]
)
@pytest.mark.usefixtures('accounts')
def test_query_pagination_preserves_comma_separated_ids(
client_fixture: str,
request: pytest.FixtureRequest,
accounts: list[Account],
) -> None:
client = request.getfixturevalue(client_fixture)
account_ids = [accounts[0].id, accounts[1].id]
ids_param = ','.join(account_ids)
resp = client.get(f'/accounts?ids={ids_param}&page_size=1&limit=10')
assert resp.status_code == 200
json_body = resp.json()
next_page_uri = json_body['next_page_uri']
assert next_page_uri is not None
assert f'ids={ids_param}' in next_page_uri.replace('%2C', ',')

resp = client.get(next_page_uri)
assert resp.status_code == 200


@pytest.mark.parametrize(
"client_fixture", ["fastapi_client", "chalice_client"]
)
Expand Down
20 changes: 20 additions & 0 deletions tests/core/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,23 @@ def test_generic_query_after():
query = generic_query(params)
assert "created_at__gt" in repr(query)
assert "user" not in repr(query)


def test_generic_query_filters_by_ids_comma_separated_string() -> None:
params = QueryParams.model_construct(ids='US1,US2')
query = generic_query(params)
assert 'id__in' in repr(query)
assert 'US1' in repr(query)
assert 'US2' in repr(query)


def test_generic_query_empty_ids_string() -> None:
params = QueryParams.model_construct(ids='')
query = generic_query(params)
assert 'id__in' not in repr(query)


def test_generic_query_excludes_count_field() -> None:
params = QueryParams.model_construct(count=True)
query = generic_query(params)
assert 'count' not in repr(query)
Loading
Loading