Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 43 additions & 231 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import platform
import struct
import threading
from typing import Tuple, Dict, Optional, List
from typing import Tuple, Dict, Optional

from mssql_python.logging import logger
from mssql_python.constants import AuthType, ConstantsDDBC
from mssql_python.connection_string_parser import _ConnectionStringParser

# Module-level credential instance cache.
# Reusing credential objects allows the Azure Identity SDK's built-in
Expand All @@ -23,6 +22,17 @@
_credential_cache: Dict[object, object] = {}
_credential_cache_lock = threading.Lock()

# Canonical keys to strip when handing an Entra-token connection to ODBC.
_SENSITIVE_KEYS = frozenset({"UID", "PWD", "Trusted_Connection", "Authentication"})

# Map Authentication connection-string values to internal short names.
_AUTH_TYPE_MAP: Dict[str, str] = {
AuthType.INTERACTIVE.value: "interactive",
AuthType.DEVICE_CODE.value: "devicecode",
AuthType.DEFAULT.value: "default",
AuthType.MSI.value: "msi",
}


def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]):
"""Build a hashable cache key from auth_type and optional credential kwargs.
Expand Down Expand Up @@ -154,112 +164,36 @@ def _acquire_token(
raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e


def _extract_msi_client_id(connection_string: str) -> Optional[str]:
"""Pull UID out of a connection string for user-assigned MSI.

For ActiveDirectoryMSI, UID (when present) carries the user-assigned
identity's ``client_id``. Returns None for system-assigned MSI.

Uses the canonical ``_ConnectionStringParser`` so braced ODBC values
are handled correctly: a ``UID={hello=world}`` resolves to the value
``hello=world`` (no surrounding braces, no false split on the inner
``=``), and a semicolon inside a legitimate braced value (e.g.
``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``.
def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]:
"""
# Connection.__init__ already parsed the same string through
# _ConnectionStringParser via _construct_connection_string, so by the
# time we get here the input is guaranteed parseable. No defensive
# try/except: a parse failure now means a real bug upstream and should
# propagate, not silently degrade user-assigned MSI to system-assigned.
parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string)
uid = (parsed.get("uid") or "").strip()
return uid or None
Extract authentication type from parsed connection parameters.


def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]:
"""
Process connection parameters and extract authentication type.
Returns the internal auth type string needed for token acquisition,
or None when the driver should handle authentication natively
(e.g. Windows Interactive).

Args:
parameters: List of connection string parameters
parsed_params: Dictionary of normalized connection parameters

Returns:
Tuple[list, Optional[str]]: Modified parameters and authentication type

Raises:
ValueError: If an invalid authentication type is provided
Optional[str]: Authentication type string or None
"""
logger.debug("process_auth_parameters: Processing %d connection parameters", len(parameters))
modified_parameters = []
auth_type = None

for param in parameters:
param = param.strip()
if not param:
continue

if "=" not in param:
modified_parameters.append(param)
continue

key, value = param.split("=", 1)
key_lower = key.lower()
value_lower = value.lower()

if key_lower == "authentication":
# Check for supported authentication types and set auth_type accordingly
if value_lower == AuthType.INTERACTIVE.value:
auth_type = "interactive"
logger.debug("process_auth_parameters: Interactive authentication detected")
# Interactive authentication (browser-based); only append parameter for non-Windows
if platform.system().lower() == "windows":
logger.debug(
"process_auth_parameters: Windows platform - using native AADInteractive"
)
auth_type = None # Let Windows handle AADInteractive natively

elif value_lower == AuthType.DEVICE_CODE.value:
# Device code authentication (for devices without browser)
logger.debug("process_auth_parameters: Device code authentication detected")
auth_type = "devicecode"
elif value_lower == AuthType.DEFAULT.value:
# Default authentication (uses DefaultAzureCredential)
logger.debug("process_auth_parameters: Default Azure authentication detected")
auth_type = "default"
elif value_lower == AuthType.MSI.value:
# Managed identity authentication (system- or user-assigned)
logger.debug("process_auth_parameters: Managed identity authentication detected")
auth_type = "msi"
modified_parameters.append(param)

logger.debug(
"process_auth_parameters: Processing complete - auth_type=%s, param_count=%d",
auth_type,
len(modified_parameters),
)
return modified_parameters, auth_type


def remove_sensitive_params(parameters: List[str]) -> List[str]:
"""Remove sensitive parameters from connection string"""
logger.debug(
"remove_sensitive_params: Removing sensitive parameters - input_count=%d", len(parameters)
)
exclude_keys = [
"uid=",
"pwd=",
"trusted_connection=",
"authentication=",
]
result = [
param
for param in parameters
if not any(param.lower().startswith(exclude) for exclude in exclude_keys)
]
logger.debug(
"remove_sensitive_params: Sensitive parameters removed - output_count=%d", len(result)
)
return result
auth_type = extract_auth_type(parsed_params)
if not auth_type:
return None

# On Windows, Interactive auth is handled natively by the ODBC driver.
if auth_type == "interactive" and platform.system().lower() == "windows":
logger.debug("process_auth_parameters: Windows platform - using native AADInteractive")
return None

logger.debug("process_auth_parameters: auth_type=%s", auth_type)
return auth_type


def remove_sensitive_params(parsed_params: Dict[str, str]) -> Dict[str, str]:
"""Return a copy of *parsed_params* without credentials / auth keys."""
return {k: v for k, v in parsed_params.items() if k not in _SENSITIVE_KEYS}


def get_auth_token(
Expand Down Expand Up @@ -287,135 +221,13 @@ def get_auth_token(
return None


def extract_auth_type(connection_string: str) -> Optional[str]:
"""Extract Entra ID auth type from a connection string.

Used as a fallback when process_connection_string does not propagate
auth_type (e.g. Windows Interactive where DDBC handles auth natively).
Bulkcopy still needs the auth type to acquire a token via Azure Identity.
"""
auth_map = {
AuthType.INTERACTIVE.value: "interactive",
AuthType.DEVICE_CODE.value: "devicecode",
AuthType.DEFAULT.value: "default",
AuthType.MSI.value: "msi",
}
for part in connection_string.split(";"):
key, _, value = part.strip().partition("=")
if key.strip().lower() == "authentication":
return auth_map.get(value.strip().lower())
return None


def process_connection_string(
connection_string: str,
) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]:
"""
Process connection string and handle authentication.

NOTE: Returns a 4-tuple. Callers must unpack all four elements.
Destructuring with three names raises ``ValueError: too many values
to unpack``. The fourth element (``credential_kwargs``) is needed by
Connection.__init__ to persist credential constructor args (e.g. the
user-assigned MSI ``client_id``) for the bulkcopy fresh-token path,
since UID is stripped from the sanitized connection string.

Args:
connection_string: The connection string to process
def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]:
"""Map the Authentication connection-string value to an internal type name.

Returns:
Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]:
Processed connection string, attrs_before dict if needed, auth_type
string for bulk copy token acquisition, and credential constructor
kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on
the Connection so bulkcopy can re-use them when acquiring a fresh
token after sanitization has stripped UID from the connection
string.

Raises:
ValueError: If the connection string is invalid or empty
Returns ``"interactive"``, ``"devicecode"``, ``"default"``, ``"msi"``,
or *None* for unrecognised / absent values. This is a pure mapping with
no platform checks — use :func:`process_auth_parameters` when you need
the Windows-Interactive suppression logic.
"""
logger.debug(
"process_connection_string: Starting - conn_str_length=%d",
len(connection_string) if isinstance(connection_string, str) else 0,
)
# Check type first
if not isinstance(connection_string, str):
logger.error(
"process_connection_string: Invalid type - expected str, got %s",
type(connection_string).__name__,
)
raise ValueError("Connection string must be a string")

# Then check if empty
if not connection_string:
logger.error("process_connection_string: Connection string is empty")
raise ValueError("Connection string cannot be empty")

parameters = connection_string.split(";")
logger.debug(
"process_connection_string: Split connection string - parameter_count=%d", len(parameters)
)

# Validate that there's at least one valid parameter
if not any("=" in param for param in parameters):
logger.error(
"process_connection_string: Invalid connection string format - no key=value pairs found"
)
raise ValueError("Invalid connection string format")

modified_parameters, auth_type = process_auth_parameters(parameters)

# Capture credential kwargs (e.g. user-assigned MSI client_id) before
# remove_sensitive_params strips UID from the parameter list. Pass the
# original connection_string (not modified_parameters) so the helper can
# use the canonical _ConnectionStringParser — handles braced values like
# UID={hello=world} correctly.
credential_kwargs: Dict[str, str] = {}
if auth_type == "msi":
client_id = _extract_msi_client_id(connection_string)
if client_id:
credential_kwargs["client_id"] = client_id
logger.debug(
"process_connection_string: ActiveDirectoryMSI with UID — "
"user-assigned managed identity selected (client_id length=%d)",
len(client_id),
)
else:
logger.debug(
"process_connection_string: ActiveDirectoryMSI without UID — "
"system-assigned managed identity selected"
)

if auth_type:
logger.info(
"process_connection_string: Authentication type detected - auth_type=%s", auth_type
)
modified_parameters = remove_sensitive_params(modified_parameters)
token_struct = get_auth_token(auth_type, credential_kwargs or None)
if token_struct:
logger.info(
"process_connection_string: Token authentication configured successfully - auth_type=%s",
auth_type,
)
return (
";".join(modified_parameters) + ";",
{ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct},
auth_type,
credential_kwargs or None,
)
else:
logger.warning(
"process_connection_string: Token acquisition failed, proceeding without token"
)

logger.debug(
"process_connection_string: Connection string processing complete - has_auth=%s",
bool(auth_type),
)
return (
";".join(modified_parameters) + ";",
None,
auth_type,
credential_kwargs or None,
)
auth_value = parsed_params.get("Authentication", "").strip().lower()
return _AUTH_TYPE_MAP.get(auth_value)
Loading
Loading