diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c36c9a30..e34cb9fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,8 +46,7 @@ jobs: "3.11", "3.12", "3.13", - "pypy-3.9", - "pypy-3.10", + "pypy-3.11", ] trino: [ "latest", @@ -76,7 +75,12 @@ jobs: sudo apt-get update sudo apt-get install libkrb5-dev pip install wheel - pip install .[tests,gssapi] sqlalchemy${{ matrix.sqlalchemy }} + # gssapi is CPython-only; skip it on PyPy + if [[ "${{ matrix.python }}" == pypy-* ]]; then + pip install .[tests] sqlalchemy${{ matrix.sqlalchemy }} + else + pip install .[tests,gssapi] sqlalchemy${{ matrix.sqlalchemy }} + fi - name: Run tests run: | pytest -s tests/ diff --git a/setup.py b/setup.py index 0bd7102a..76d67490 100755 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ # We don't add localstorage_require to all_require as users must explicitly opt in to use keyring. all_require = kerberos_require + sqlalchemy_require -tests_require = all_require + gssapi_require + [ +tests_require = all_require + [ # httpretty >= 1.1 duplicates requests in `httpretty.latest_requests` # https://github.com/gabrielfalcao/HTTPretty/issues/425 "httpretty < 1.1", diff --git a/tests/unit/test_auth_gssapi.py b/tests/unit/test_auth_gssapi.py new file mode 100644 index 00000000..f4015f50 --- /dev/null +++ b/tests/unit/test_auth_gssapi.py @@ -0,0 +1,87 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import nullcontext as does_not_raise +from typing import Any + +import pytest +import requests + +gssapi = pytest.importorskip("gssapi", exc_type=ImportError) + +from trino.auth import GSSAPIAuthentication # noqa: E402 + + +class MockGssapiCredentials: + def __init__(self, name: gssapi.Name, usage: str): + self.name = name + self.usage = usage + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, MockGssapiCredentials): + return False + return ( + self.name == other.name, + self.usage == other.usage, + ) + + +@pytest.fixture +def mock_gssapi_creds(monkeypatch): + monkeypatch.setattr("gssapi.Credentials", MockGssapiCredentials) + + +def _gssapi_uname(spn: str): + return gssapi.Name(spn, gssapi.NameType.user) + + +def _gssapi_sname(principal: str): + return gssapi.Name(principal, gssapi.NameType.hostbased_service) + + +@pytest.mark.parametrize( + "options, expected_credentials, expected_hostname, expected_exception", + [ + ( + {}, None, None, does_not_raise(), + ), + ( + {"hostname_override": "foo"}, None, "foo", does_not_raise(), + ), + ( + {"service_name": "bar"}, None, None, + pytest.raises(ValueError, match=r"must be used together with hostname_override"), + ), + ( + {"hostname_override": "foo", "service_name": "bar"}, None, _gssapi_sname("bar@foo"), does_not_raise(), + ), + ( + {"principal": "foo"}, MockGssapiCredentials(_gssapi_uname("foo"), "initial"), None, does_not_raise(), + ), + ] +) +def test_authentication_gssapi_init_arguments( + options, + expected_credentials, + expected_hostname, + expected_exception, + mock_gssapi_creds, + monkeypatch, +): + auth = GSSAPIAuthentication(**options) + + session = requests.Session() + + with expected_exception: + auth.set_http_session(session) + + assert session.auth.target_name == expected_hostname + assert session.auth.creds == expected_credentials diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index f011a54d..29f3f388 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -13,8 +13,6 @@ import time import urllib import uuid -from contextlib import nullcontext as does_not_raise -from typing import Any from typing import Dict from typing import Optional from unittest import mock @@ -22,7 +20,6 @@ from urllib.parse import urlparse from zoneinfo import ZoneInfoNotFoundError -import gssapi import httpretty import keyring try: @@ -32,8 +29,6 @@ import pytest import requests from httpretty import httprettified -from requests_gssapi.exceptions import SPNEGOExchangeError -from requests_kerberos.exceptions import KerberosExchangeError from tzlocal import get_localzone_name # type: ignore import trino.exceptions @@ -51,8 +46,6 @@ from trino import constants from trino.auth import _OAuth2KeyRingTokenCache from trino.auth import _OAuth2TokenBearer -from trino.auth import GSSAPIAuthentication -from trino.auth import KerberosAuthentication from trino.client import _DelayExponential from trino.client import _retry_with from trino.client import _RetryWithExponentialBackoff @@ -62,6 +55,29 @@ from trino.client import TrinoRequest from trino.client import TrinoResult +try: + from requests_kerberos.exceptions import KerberosExchangeError + from trino.auth import KerberosAuthentication +except ImportError: + KerberosAuthentication = None + KerberosExchangeError = None + +try: + from requests_gssapi.exceptions import SPNEGOExchangeError + from trino.auth import GSSAPIAuthentication +except ImportError: + GSSAPIAuthentication = None + SPNEGOExchangeError = None + +requires_kerberos = pytest.mark.skipif( + KerberosAuthentication is None, + reason="requests_kerberos is not installed", +) +requires_gssapi = pytest.mark.skipif( + GSSAPIAuthentication is None, + reason="gssapi is not available (CPython-only)", +) + @mock.patch("trino.client.TrinoRequest.http") def test_trino_initial_request(mock_requests, sample_post_response_data): @@ -914,73 +930,6 @@ def __str__(self): assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=changed" -class MockGssapiCredentials: - def __init__(self, name: gssapi.Name, usage: str): - self.name = name - self.usage = usage - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, MockGssapiCredentials): - return False - return ( - self.name == other.name, - self.usage == other.usage, - ) - - -@pytest.fixture -def mock_gssapi_creds(monkeypatch): - monkeypatch.setattr("gssapi.Credentials", MockGssapiCredentials) - - -def _gssapi_uname(spn: str): - return gssapi.Name(spn, gssapi.NameType.user) - - -def _gssapi_sname(principal: str): - return gssapi.Name(principal, gssapi.NameType.hostbased_service) - - -@pytest.mark.parametrize( - "options, expected_credentials, expected_hostname, expected_exception", - [ - ( - {}, None, None, does_not_raise(), - ), - ( - {"hostname_override": "foo"}, None, "foo", does_not_raise(), - ), - ( - {"service_name": "bar"}, None, None, - pytest.raises(ValueError, match=r"must be used together with hostname_override"), - ), - ( - {"hostname_override": "foo", "service_name": "bar"}, None, _gssapi_sname("bar@foo"), does_not_raise(), - ), - ( - {"principal": "foo"}, MockGssapiCredentials(_gssapi_uname("foo"), "initial"), None, does_not_raise(), - ), - ] -) -def test_authentication_gssapi_init_arguments( - options, - expected_credentials, - expected_hostname, - expected_exception, - mock_gssapi_creds, - monkeypatch, -): - auth = GSSAPIAuthentication(**options) - - session = requests.Session() - - with expected_exception: - auth.set_http_session(session) - - assert session.auth.target_name == expected_hostname - assert session.auth.creds == expected_credentials - - class RetryRecorder: def __init__(self, error=None, result=None): self.__name__ = "RetryRecorder" @@ -1003,8 +952,8 @@ def retry_count(self): @pytest.mark.parametrize( "auth_class, retry_exception_class", [ - (KerberosAuthentication, KerberosExchangeError), - (GSSAPIAuthentication, SPNEGOExchangeError), + pytest.param(KerberosAuthentication, KerberosExchangeError, marks=requires_kerberos), + pytest.param(GSSAPIAuthentication, SPNEGOExchangeError, marks=requires_gssapi), ] ) def test_authentication_fail_retry(auth_class, retry_exception_class, monkeypatch):