Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ jobs:
"3.11",
"3.12",
"3.13",
"pypy-3.9",
"pypy-3.10",
"pypy-3.11",
]
trino: [
"latest",
Expand Down Expand Up @@ -76,7 +75,12 @@ jobs:
sudo apt-get update
sudo apt-get install libkrb5-dev
pip install wheel
pip install .[tests,gssapi] sqlalchemy${{ matrix.sqlalchemy }}
# gssapi is CPython-only; skip it on PyPy
if [[ "${{ matrix.python }}" == pypy-* ]]; then
pip install .[tests] sqlalchemy${{ matrix.sqlalchemy }}
else
pip install .[tests,gssapi] sqlalchemy${{ matrix.sqlalchemy }}
fi
- name: Run tests
run: |
pytest -s tests/
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
all_require = kerberos_require + sqlalchemy_require

tests_require = all_require + gssapi_require + [
tests_require = all_require + [
# httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`
# https://github.com/gabrielfalcao/HTTPretty/issues/425
"httpretty < 1.1",
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/test_auth_gssapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext as does_not_raise
from typing import Any

import pytest
import requests

gssapi = pytest.importorskip("gssapi", exc_type=ImportError)

from trino.auth import GSSAPIAuthentication # noqa: E402


class MockGssapiCredentials:
def __init__(self, name: gssapi.Name, usage: str):
self.name = name
self.usage = usage

def __eq__(self, other: Any) -> bool:
if not isinstance(other, MockGssapiCredentials):
return False
return (
self.name == other.name,
self.usage == other.usage,
)


@pytest.fixture
def mock_gssapi_creds(monkeypatch):
monkeypatch.setattr("gssapi.Credentials", MockGssapiCredentials)


def _gssapi_uname(spn: str):
return gssapi.Name(spn, gssapi.NameType.user)


def _gssapi_sname(principal: str):
return gssapi.Name(principal, gssapi.NameType.hostbased_service)


@pytest.mark.parametrize(
"options, expected_credentials, expected_hostname, expected_exception",
[
(
{}, None, None, does_not_raise(),
),
(
{"hostname_override": "foo"}, None, "foo", does_not_raise(),
),
(
{"service_name": "bar"}, None, None,
pytest.raises(ValueError, match=r"must be used together with hostname_override"),
),
(
{"hostname_override": "foo", "service_name": "bar"}, None, _gssapi_sname("bar@foo"), does_not_raise(),
),
(
{"principal": "foo"}, MockGssapiCredentials(_gssapi_uname("foo"), "initial"), None, does_not_raise(),
),
]
)
def test_authentication_gssapi_init_arguments(
options,
expected_credentials,
expected_hostname,
expected_exception,
mock_gssapi_creds,
monkeypatch,
):
auth = GSSAPIAuthentication(**options)

session = requests.Session()

with expected_exception:
auth.set_http_session(session)

assert session.auth.target_name == expected_hostname
assert session.auth.creds == expected_credentials
101 changes: 25 additions & 76 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@
import time
import urllib
import uuid
from contextlib import nullcontext as does_not_raise
from typing import Any
from typing import Dict
from typing import Optional
from unittest import mock
from unittest import TestCase
from urllib.parse import urlparse
from zoneinfo import ZoneInfoNotFoundError

import gssapi
import httpretty
import keyring
try:
Expand All @@ -32,8 +29,6 @@
import pytest
import requests
from httpretty import httprettified
from requests_gssapi.exceptions import SPNEGOExchangeError
from requests_kerberos.exceptions import KerberosExchangeError
from tzlocal import get_localzone_name # type: ignore

import trino.exceptions
Expand All @@ -51,8 +46,6 @@
from trino import constants
from trino.auth import _OAuth2KeyRingTokenCache
from trino.auth import _OAuth2TokenBearer
from trino.auth import GSSAPIAuthentication
from trino.auth import KerberosAuthentication
from trino.client import _DelayExponential
from trino.client import _retry_with
from trino.client import _RetryWithExponentialBackoff
Expand All @@ -62,6 +55,29 @@
from trino.client import TrinoRequest
from trino.client import TrinoResult

try:
from requests_kerberos.exceptions import KerberosExchangeError
from trino.auth import KerberosAuthentication
except ImportError:
KerberosAuthentication = None
KerberosExchangeError = None

try:
from requests_gssapi.exceptions import SPNEGOExchangeError
from trino.auth import GSSAPIAuthentication
except ImportError:
GSSAPIAuthentication = None
SPNEGOExchangeError = None

requires_kerberos = pytest.mark.skipif(
KerberosAuthentication is None,
reason="requests_kerberos is not installed",
)
requires_gssapi = pytest.mark.skipif(
GSSAPIAuthentication is None,
reason="gssapi is not available (CPython-only)",
)


@mock.patch("trino.client.TrinoRequest.http")
def test_trino_initial_request(mock_requests, sample_post_response_data):
Expand Down Expand Up @@ -914,73 +930,6 @@ def __str__(self):
assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=changed"


class MockGssapiCredentials:
def __init__(self, name: gssapi.Name, usage: str):
self.name = name
self.usage = usage

def __eq__(self, other: Any) -> bool:
if not isinstance(other, MockGssapiCredentials):
return False
return (
self.name == other.name,
self.usage == other.usage,
)


@pytest.fixture
def mock_gssapi_creds(monkeypatch):
monkeypatch.setattr("gssapi.Credentials", MockGssapiCredentials)


def _gssapi_uname(spn: str):
return gssapi.Name(spn, gssapi.NameType.user)


def _gssapi_sname(principal: str):
return gssapi.Name(principal, gssapi.NameType.hostbased_service)


@pytest.mark.parametrize(
"options, expected_credentials, expected_hostname, expected_exception",
[
(
{}, None, None, does_not_raise(),
),
(
{"hostname_override": "foo"}, None, "foo", does_not_raise(),
),
(
{"service_name": "bar"}, None, None,
pytest.raises(ValueError, match=r"must be used together with hostname_override"),
),
(
{"hostname_override": "foo", "service_name": "bar"}, None, _gssapi_sname("bar@foo"), does_not_raise(),
),
(
{"principal": "foo"}, MockGssapiCredentials(_gssapi_uname("foo"), "initial"), None, does_not_raise(),
),
]
)
def test_authentication_gssapi_init_arguments(
options,
expected_credentials,
expected_hostname,
expected_exception,
mock_gssapi_creds,
monkeypatch,
):
auth = GSSAPIAuthentication(**options)

session = requests.Session()

with expected_exception:
auth.set_http_session(session)

assert session.auth.target_name == expected_hostname
assert session.auth.creds == expected_credentials


class RetryRecorder:
def __init__(self, error=None, result=None):
self.__name__ = "RetryRecorder"
Expand All @@ -1003,8 +952,8 @@ def retry_count(self):
@pytest.mark.parametrize(
"auth_class, retry_exception_class",
[
(KerberosAuthentication, KerberosExchangeError),
(GSSAPIAuthentication, SPNEGOExchangeError),
pytest.param(KerberosAuthentication, KerberosExchangeError, marks=requires_kerberos),
pytest.param(GSSAPIAuthentication, SPNEGOExchangeError, marks=requires_gssapi),
]
)
def test_authentication_fail_retry(auth_class, retry_exception_class, monkeypatch):
Expand Down
Loading