Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import argparse
import os
import shutil
from pathlib import Path

import nox

DIR = Path(__file__).parent.resolve()

nox.needs_version = ">=2024.3.2"
# Typing `nox` with no arguments will automatically run these two sessions
nox.options.sessions = ["typecheck", "tests"]
nox.options.default_venv_backend = "uv|virtualenv"

if os.environ.get("ENVIRONMENT") == "dev":
# Use existing venvs where possible in dev
nox.options.reuse_existing_virtualenvs = True
else:
# All other envs should have the nox venvs recreated.
nox.options.reuse_existing_virtualenvs = False

nox.options.stop_on_first_error = True


@nox.session(python="3.11")
def typecheck(session: nox.Session) -> None:
"""Run typechecker (mypy)."""
session.install("mypy", ".[flask,fastapi,starlette]")
run_args = session.posargs if session.posargs else ["src"]
session.run("mypy", *run_args)


@nox.session(python="3.11")
def tests(session: nox.Session) -> None:
"""Run all tests."""
session.install("pytest", ".")
run_args = session.posargs if session.posargs else ["tests"]
session.run("pytest", *run_args)
13 changes: 10 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,18 @@ authors = [
{ name = "Jeanette Clark", email = "jclark@nceas.ucsb.edu" },
{ name = "Matthew B. Jones", email = "jones@nceas.ucsb.edu" }
]
requires-python = ">=3.13"
requires-python = ">=3.11"
dependencies = [
"authlib>=1.7.2",
"flask>=3.1.3",
"httpx>=0.28.1",
"joserfc>=1.6.5",
"requests>=2.33.1",
"werkzeug>=3.1.8",
]

[project.optional-dependencies]
flask = [
"flask>=3.1.3",
"werkzeug>=3.1.8",
]
fastapi = [
"fastapi>=0.136.1",
Expand All @@ -42,6 +41,8 @@ build-backend = "hatchling.build"
dev = [
"pytest>=9.0.3",
"ruff>=0.15.12",
"nox",
"mypy",
]

[tool.ruff]
Expand All @@ -60,3 +61,9 @@ python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v"

[[tool.mypy.overrides]]
module = [
"authlib.*",
]
ignore_missing_imports = true
163 changes: 133 additions & 30 deletions src/dataone/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
web frameworks without hard dependencies on any particular framework.
"""

import base64
import datetime as dt
import functools
import json
import os
import re
from typing import Any

import httpx
import requests
Expand Down Expand Up @@ -97,7 +100,7 @@ def load_client_secrets(filepath: str | None = None) -> dict:
raise ConfigurationError(f"OIDC secrets file at {resolved} is not valid JSON")


def extract_token_from_header(auth_header: str):
def extract_token_from_header(auth_header: str | None):
"""Extracts and validates a Bearer token from an auth header string.

Args:
Expand Down Expand Up @@ -217,6 +220,96 @@ def decode_claims(token_str, jwks, client_id, issuer):

return claims

def is_token_valid(token: str | None, buffer_minutes: int = 1) -> bool:
"""Check if a JWT token unexpired.

Args:
token: The raw JWT string to validate.
buffer_minutes: A safety margin added to the current time to account for network
lag.

Returns:
True if the token is valid and unexpired, False otherwise.
"""
if not token:
return False
try:
parts = token.split(".")
if len(parts) < 2:
return False
payload = parts[1]
payload += "=" * ((4 - len(payload) % 4) % 4)
exp = json.loads(base64.urlsafe_b64decode(payload).decode("utf-8")).get("exp")
except Exception:
return False

if not exp:
return False
expiry_time = dt.datetime.fromtimestamp(exp, tz=dt.UTC)
e = expiry_time > (dt.datetime.now(dt.UTC) + dt.timedelta(minutes=buffer_minutes))
return e

def parse_tokens_dict(tokens: str | dict[str, Any]) -> dict[str, str]:
"""Parse and normalize a raw token payload into a validated dictionary.

Args:
tokens: A raw JSON string or dictionary containing OIDC tokens.

Returns:
A dictionary containing verified 'access_token' and/or 'refresh_token' keys.

Raises:
ValueError: If the input is malformed, missing key fields, or contains empty
strings.
"""
if isinstance(tokens, str):
try:
tokens = json.loads(tokens)
except json.JSONDecodeError as e:
raise ValueError(f"'tokens' could not be parsed as JSON: {e}")

if not isinstance(tokens, dict):
raise ValueError("'tokens' must be a dictionary or a JSON string")

if "token" in tokens and isinstance(tokens["token"], dict):
tokens = tokens["token"]

if not any(key in tokens for key in ("access_token", "refresh_token")):
raise ValueError(
"'tokens' must contain at least one of 'access_token' or 'refresh_token'"
)

normalized: dict[str, str] = {}
for key in ("access_token", "refresh_token"):
if key in tokens and tokens[key] is not None:
val = tokens[key]
if not isinstance(val, str) or len(val.strip()) == 0:
raise ValueError(f"'{key}' must be a non-empty string")
normalized[key] = val

return normalized

def refresh_tokens(refresh_url: str,
refresh_token: str,
session: requests.Session | None = None) -> dict:
"""Exchange a refresh token for a new token payload.

Args:
refresh_url: The API endpoint URL used for token renewal.
refresh_token: The OIDC refresh token string.
session: An optional requests session to use for the network request.

Returns:
A dictionary containing the fresh token payload.

Raises:
requests.exceptions.HTTPError: If the server returns an unsuccessful status
code.
"""
client = session or requests.Session()
response = client.post(refresh_url, json={"refresh_token": refresh_token})
response.raise_for_status()
return response.json()

### Factory

Expand Down Expand Up @@ -357,7 +450,7 @@ def _resolve_error(self, exc: Exception):

return "Internal authentication error", 500

def _verify_scope(self, claims: dict, required_scope: str | None):
def _verify_scope(self, claims: dict, required_scope: str | None = None):
"""Internal helper to check if the required scope exists in claims."""
if not required_scope:
return
Expand Down Expand Up @@ -433,7 +526,9 @@ def _decode_and_validate_token(self, token_str: str):

return decode_claims(token_str, jwks, client_id, issuer)

def validate_and_extract_claims(self, token_str: str, required_scope: str = None):
def validate_and_extract_claims(self,
token_str: str,
required_scope: str | None = None):
"""Validate a token string and optionally check required scope.

Args:
Expand All @@ -455,15 +550,15 @@ def validate_and_extract_claims(self, token_str: str, required_scope: str = None

return claims

def login(self, redirect_uri: str, request=None):
def login(self, redirect_uri: str, request=None) -> Any:
"""This is implemented by subclasses."""
raise NotImplementedError

def authorize(self, request=None):
def authorize(self, request=None) -> Any:
"""This is implemented by subclasses."""
raise NotImplementedError

def refresh(self, request_json: dict):
def refresh(self, request_json: dict) -> Any:
"""This is implemented by subclasses."""
raise NotImplementedError

Expand Down Expand Up @@ -582,9 +677,9 @@ async def _decode_and_validate_token(self, token_str: str):

return decode_claims(token_str, jwks, client_id, issuer)

async def validate_and_extract_claims(
self, token_str: str, required_scope: str = None
):
async def validate_and_extract_claims(self,
token_str: str,
required_scope: str | None = None):
"""Asynchronously decodes and validates a JWT using the provider's JWKS.

This overrides the base method to support Starlette/FastAPI's asynchronous
Expand All @@ -604,7 +699,7 @@ async def validate_and_extract_claims(

return claims

async def login(self, request, redirect_uri: str):
async def login(self, redirect_uri: str, request: Any = None) -> Any:
"""Asynchronously initiates the OIDC login flow.

Uses the Starlette OAuth client to generate a redirect response that
Expand All @@ -629,7 +724,7 @@ async def login(request: Request):
# The Starlette client's authorize_redirect is async
return await self.dataone_oidc.authorize_redirect(request, redirect_uri)

async def authorize(self, request):
async def authorize(self, request) -> Any: # type: ignore[override]
"""Asynchronously exchanges an authorization code for an access token.

This method is designed to be used in the OIDC callback route. It
Expand All @@ -655,7 +750,7 @@ async def authorize(request: Request):
except Exception as e:
return self._error_handler(e)

async def refresh(self, request_json: dict):
async def refresh(self, request_json: dict) -> Any:
"""Asynchronously exchanges a refresh token for new access tokens.

Overrides the synchronous base method to accommodate FastAPI's async
Expand Down Expand Up @@ -730,14 +825,18 @@ async def get_secure_data(
async def dependency(request: Request):
from fastapi import HTTPException

from .auth import extract_token_from_header

# Handle 'read_only' logic
if self.access_mode != "authenticated":
return None
# Handle 'open' logic
if self.access_mode == ACCESS_MODE_OPEN:
return {}

if methods is not None and request.method not in methods:
return None
# Handle 'read only' logic
if self.access_mode == ACCESS_MODE_READ_ONLY:
if request.method in ["POST", "PUT", "DELETE", "PATCH"]:
raise HTTPException(
status_code=403,
detail="This API is currently in read-only mode."
)
return {}

try:
auth_header = request.headers.get("Authorization")
Expand Down Expand Up @@ -790,14 +889,18 @@ async def get_secure_data(
async def dependency(request: Request):
from fastapi import HTTPException

from .auth import extract_token_from_header

# Handle 'read_only' logic
if self.access_mode != "authenticated":
return None

if methods is not None and request.method not in methods:
return None
# Handle 'open' logic
if self.access_mode == ACCESS_MODE_OPEN:
return {}

# Handle 'read only' logic
if self.access_mode == ACCESS_MODE_READ_ONLY:
if request.method in ["POST", "PUT", "DELETE", "PATCH"]:
raise HTTPException(
status_code=403,
detail="This API is currently in read-only mode."
)
return {}

try:
auth_header = request.headers.get("Authorization")
Expand Down Expand Up @@ -867,7 +970,7 @@ def _token_response(self, token: dict, message: str = "Success"):
}
), 200

def login(self, redirect_uri: str):
def login(self, redirect_uri: str, request: Any = None) -> Any:
"""Initiates the OIDC login flow for Flask.

Uses the Flask Authlib client to generate a redirect response that
Expand All @@ -889,7 +992,7 @@ def login():
"""
return self.dataone_oidc.authorize_redirect(redirect_uri)

def authorize(self):
def authorize(self) -> Any: # type: ignore[override]
"""Exchanges an authorization code for an access token in Flask.

This method should be called within the OIDC callback route. It
Expand All @@ -911,7 +1014,7 @@ def authorize():
except Exception as e:
return self._error_handler(e)

def refresh(self, request_json: dict):
def refresh(self, request_json: dict) -> Any:
"""Executes the synchronous token refresh request for Flask.

Args:
Expand Down
Empty file added src/dataone/py.typed
Empty file.
Loading