diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 9b488c6d..1623020c 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -7,11 +7,10 @@ import platform import struct import threading -from typing import Tuple, Dict, Optional, List +from typing import Tuple, Dict, Optional from mssql_python.logging import logger from mssql_python.constants import AuthType, ConstantsDDBC -from mssql_python.connection_string_parser import _ConnectionStringParser # Module-level credential instance cache. # Reusing credential objects allows the Azure Identity SDK's built-in @@ -23,6 +22,17 @@ _credential_cache: Dict[object, object] = {} _credential_cache_lock = threading.Lock() +# Canonical keys to strip when handing an Entra-token connection to ODBC. +_SENSITIVE_KEYS = frozenset({"UID", "PWD", "Trusted_Connection", "Authentication"}) + +# Map Authentication connection-string values to internal short names. +_AUTH_TYPE_MAP: Dict[str, str] = { + AuthType.INTERACTIVE.value: "interactive", + AuthType.DEVICE_CODE.value: "devicecode", + AuthType.DEFAULT.value: "default", + AuthType.MSI.value: "msi", +} + def _credential_cache_key(auth_type: str, credential_kwargs: Optional[Dict[str, str]]): """Build a hashable cache key from auth_type and optional credential kwargs. @@ -154,112 +164,36 @@ def _acquire_token( raise RuntimeError(f"Failed to create {credential_class.__name__}: {e}") from e -def _extract_msi_client_id(connection_string: str) -> Optional[str]: - """Pull UID out of a connection string for user-assigned MSI. - - For ActiveDirectoryMSI, UID (when present) carries the user-assigned - identity's ``client_id``. Returns None for system-assigned MSI. - - Uses the canonical ``_ConnectionStringParser`` so braced ODBC values - are handled correctly: a ``UID={hello=world}`` resolves to the value - ``hello=world`` (no surrounding braces, no false split on the inner - ``=``), and a semicolon inside a legitimate braced value (e.g. - ``Database={foo;uid=victim;bar}``) cannot spoof a top-level ``UID=``. +def process_auth_parameters(parsed_params: Dict[str, str]) -> Optional[str]: """ - # Connection.__init__ already parsed the same string through - # _ConnectionStringParser via _construct_connection_string, so by the - # time we get here the input is guaranteed parseable. No defensive - # try/except: a parse failure now means a real bug upstream and should - # propagate, not silently degrade user-assigned MSI to system-assigned. - parsed = _ConnectionStringParser(validate_keywords=False)._parse(connection_string) - uid = (parsed.get("uid") or "").strip() - return uid or None + Extract authentication type from parsed connection parameters. - -def process_auth_parameters(parameters: List[str]) -> Tuple[List[str], Optional[str]]: - """ - Process connection parameters and extract authentication type. + Returns the internal auth type string needed for token acquisition, + or None when the driver should handle authentication natively + (e.g. Windows Interactive). Args: - parameters: List of connection string parameters + parsed_params: Dictionary of normalized connection parameters Returns: - Tuple[list, Optional[str]]: Modified parameters and authentication type - - Raises: - ValueError: If an invalid authentication type is provided + Optional[str]: Authentication type string or None """ - logger.debug("process_auth_parameters: Processing %d connection parameters", len(parameters)) - modified_parameters = [] - auth_type = None - - for param in parameters: - param = param.strip() - if not param: - continue - - if "=" not in param: - modified_parameters.append(param) - continue - - key, value = param.split("=", 1) - key_lower = key.lower() - value_lower = value.lower() - - if key_lower == "authentication": - # Check for supported authentication types and set auth_type accordingly - if value_lower == AuthType.INTERACTIVE.value: - auth_type = "interactive" - logger.debug("process_auth_parameters: Interactive authentication detected") - # Interactive authentication (browser-based); only append parameter for non-Windows - if platform.system().lower() == "windows": - logger.debug( - "process_auth_parameters: Windows platform - using native AADInteractive" - ) - auth_type = None # Let Windows handle AADInteractive natively - - elif value_lower == AuthType.DEVICE_CODE.value: - # Device code authentication (for devices without browser) - logger.debug("process_auth_parameters: Device code authentication detected") - auth_type = "devicecode" - elif value_lower == AuthType.DEFAULT.value: - # Default authentication (uses DefaultAzureCredential) - logger.debug("process_auth_parameters: Default Azure authentication detected") - auth_type = "default" - elif value_lower == AuthType.MSI.value: - # Managed identity authentication (system- or user-assigned) - logger.debug("process_auth_parameters: Managed identity authentication detected") - auth_type = "msi" - modified_parameters.append(param) - - logger.debug( - "process_auth_parameters: Processing complete - auth_type=%s, param_count=%d", - auth_type, - len(modified_parameters), - ) - return modified_parameters, auth_type - - -def remove_sensitive_params(parameters: List[str]) -> List[str]: - """Remove sensitive parameters from connection string""" - logger.debug( - "remove_sensitive_params: Removing sensitive parameters - input_count=%d", len(parameters) - ) - exclude_keys = [ - "uid=", - "pwd=", - "trusted_connection=", - "authentication=", - ] - result = [ - param - for param in parameters - if not any(param.lower().startswith(exclude) for exclude in exclude_keys) - ] - logger.debug( - "remove_sensitive_params: Sensitive parameters removed - output_count=%d", len(result) - ) - return result + auth_type = extract_auth_type(parsed_params) + if not auth_type: + return None + + # On Windows, Interactive auth is handled natively by the ODBC driver. + if auth_type == "interactive" and platform.system().lower() == "windows": + logger.debug("process_auth_parameters: Windows platform - using native AADInteractive") + return None + + logger.debug("process_auth_parameters: auth_type=%s", auth_type) + return auth_type + + +def remove_sensitive_params(parsed_params: Dict[str, str]) -> Dict[str, str]: + """Return a copy of *parsed_params* without credentials / auth keys.""" + return {k: v for k, v in parsed_params.items() if k not in _SENSITIVE_KEYS} def get_auth_token( @@ -287,135 +221,13 @@ def get_auth_token( return None -def extract_auth_type(connection_string: str) -> Optional[str]: - """Extract Entra ID auth type from a connection string. - - Used as a fallback when process_connection_string does not propagate - auth_type (e.g. Windows Interactive where DDBC handles auth natively). - Bulkcopy still needs the auth type to acquire a token via Azure Identity. - """ - auth_map = { - AuthType.INTERACTIVE.value: "interactive", - AuthType.DEVICE_CODE.value: "devicecode", - AuthType.DEFAULT.value: "default", - AuthType.MSI.value: "msi", - } - for part in connection_string.split(";"): - key, _, value = part.strip().partition("=") - if key.strip().lower() == "authentication": - return auth_map.get(value.strip().lower()) - return None - - -def process_connection_string( - connection_string: str, -) -> Tuple[str, Optional[Dict[int, bytes]], Optional[str], Optional[Dict[str, str]]]: - """ - Process connection string and handle authentication. - - NOTE: Returns a 4-tuple. Callers must unpack all four elements. - Destructuring with three names raises ``ValueError: too many values - to unpack``. The fourth element (``credential_kwargs``) is needed by - Connection.__init__ to persist credential constructor args (e.g. the - user-assigned MSI ``client_id``) for the bulkcopy fresh-token path, - since UID is stripped from the sanitized connection string. - - Args: - connection_string: The connection string to process +def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]: + """Map the Authentication connection-string value to an internal type name. - Returns: - Tuple[str, Optional[Dict], Optional[str], Optional[Dict[str, str]]]: - Processed connection string, attrs_before dict if needed, auth_type - string for bulk copy token acquisition, and credential constructor - kwargs (e.g. user-assigned MSI ``client_id``) to be persisted on - the Connection so bulkcopy can re-use them when acquiring a fresh - token after sanitization has stripped UID from the connection - string. - - Raises: - ValueError: If the connection string is invalid or empty + Returns ``"interactive"``, ``"devicecode"``, ``"default"``, ``"msi"``, + or *None* for unrecognised / absent values. This is a pure mapping with + no platform checks — use :func:`process_auth_parameters` when you need + the Windows-Interactive suppression logic. """ - logger.debug( - "process_connection_string: Starting - conn_str_length=%d", - len(connection_string) if isinstance(connection_string, str) else 0, - ) - # Check type first - if not isinstance(connection_string, str): - logger.error( - "process_connection_string: Invalid type - expected str, got %s", - type(connection_string).__name__, - ) - raise ValueError("Connection string must be a string") - - # Then check if empty - if not connection_string: - logger.error("process_connection_string: Connection string is empty") - raise ValueError("Connection string cannot be empty") - - parameters = connection_string.split(";") - logger.debug( - "process_connection_string: Split connection string - parameter_count=%d", len(parameters) - ) - - # Validate that there's at least one valid parameter - if not any("=" in param for param in parameters): - logger.error( - "process_connection_string: Invalid connection string format - no key=value pairs found" - ) - raise ValueError("Invalid connection string format") - - modified_parameters, auth_type = process_auth_parameters(parameters) - - # Capture credential kwargs (e.g. user-assigned MSI client_id) before - # remove_sensitive_params strips UID from the parameter list. Pass the - # original connection_string (not modified_parameters) so the helper can - # use the canonical _ConnectionStringParser — handles braced values like - # UID={hello=world} correctly. - credential_kwargs: Dict[str, str] = {} - if auth_type == "msi": - client_id = _extract_msi_client_id(connection_string) - if client_id: - credential_kwargs["client_id"] = client_id - logger.debug( - "process_connection_string: ActiveDirectoryMSI with UID — " - "user-assigned managed identity selected (client_id length=%d)", - len(client_id), - ) - else: - logger.debug( - "process_connection_string: ActiveDirectoryMSI without UID — " - "system-assigned managed identity selected" - ) - - if auth_type: - logger.info( - "process_connection_string: Authentication type detected - auth_type=%s", auth_type - ) - modified_parameters = remove_sensitive_params(modified_parameters) - token_struct = get_auth_token(auth_type, credential_kwargs or None) - if token_struct: - logger.info( - "process_connection_string: Token authentication configured successfully - auth_type=%s", - auth_type, - ) - return ( - ";".join(modified_parameters) + ";", - {ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, - auth_type, - credential_kwargs or None, - ) - else: - logger.warning( - "process_connection_string: Token acquisition failed, proceeding without token" - ) - - logger.debug( - "process_connection_string: Connection string processing complete - has_auth=%s", - bool(auth_type), - ) - return ( - ";".join(modified_parameters) + ";", - None, - auth_type, - credential_kwargs or None, - ) + auth_value = parsed_params.get("Authentication", "").strip().lower() + return _AUTH_TYPE_MAP.get(auth_value) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0933560b..caf73fca 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -40,7 +40,12 @@ NotSupportedError, sqlstate_to_exception, ) -from mssql_python.auth import extract_auth_type, process_connection_string +from mssql_python.auth import ( + extract_auth_type, + process_auth_parameters, + remove_sensitive_params, + get_auth_token, +) from mssql_python.constants import ConstantsDDBC, GetInfoConstants from mssql_python.connection_string_parser import _ConnectionStringParser from mssql_python.connection_string_builder import _ConnectionStringBuilder @@ -287,7 +292,9 @@ def __init__( raise ValueError("native_uuid must be a boolean value or None") self._native_uuid = native_uuid - self.connection_str = self._construct_connection_string(connection_str, **kwargs) + self.connection_str, self._parsed_params = self._construct_connection_string( + connection_str, **kwargs + ) self._attrs_before = attrs_before or {} # Initialize encoding settings with defaults for Python 3 @@ -328,20 +335,32 @@ def __init__( # them because UID is already gone. self._credential_kwargs: Optional[Dict[str, str]] = None - # Check if the connection string contains authentication parameters - # This is important for processing the connection string correctly. - # If authentication is specified, it will be processed to handle - # different authentication types like interactive, device code, etc. - if re.search(r"authentication", self.connection_str, re.IGNORECASE): - connection_result = process_connection_string(self.connection_str) - self.connection_str = connection_result[0] - if connection_result[1]: - self._attrs_before.update(connection_result[1]) + # Handle Entra ID authentication if specified. + # The parsed dict is used directly — no re-parsing of the connection string. + if "Authentication" in self._parsed_params: + auth_type = process_auth_parameters(self._parsed_params) + + if auth_type: + # Capture credential kwargs (e.g. user-assigned MSI client_id) + # from the parsed dict *before* remove_sensitive_params strips UID. + credential_kwargs: Optional[Dict[str, str]] = None + if auth_type == "msi": + uid = (self._parsed_params.get("UID") or "").strip() + if uid: + credential_kwargs = {"client_id": uid} + + # Strip sensitive params and rebuild the connection string. + sanitized = remove_sensitive_params(self._parsed_params) + self.connection_str = _ConnectionStringBuilder(sanitized).build() + token = get_auth_token(auth_type, credential_kwargs) + if token: + self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token + self._credential_kwargs = credential_kwargs + # Store auth type so bulkcopy() can acquire a fresh token later. - # On Windows Interactive, process_connection_string returns None - # (DDBC handles auth natively), so fall back to the connection string. - self._auth_type = connection_result[2] or extract_auth_type(self.connection_str) - self._credential_kwargs = connection_result[3] + # On Windows Interactive, process_auth_parameters returns None + # (DDBC handles auth natively), so fall back to extract_auth_type. + self._auth_type = auth_type or extract_auth_type(self._parsed_params) self._closed = False self._timeout = timeout @@ -401,24 +420,25 @@ def __init__( f"Unexpected error during connection registration: {type(e).__name__}: {e}" ) - def _construct_connection_string(self, connection_str: str = "", **kwargs: Any) -> str: + def _construct_connection_string( + self, connection_str: str = "", **kwargs: Any + ) -> Tuple[str, Dict[str, str]]: """ Construct the connection string by parsing, validating, and merging parameters. - This method performs a 6-step process: 1. Parse and validate the base connection_str (validates against allowlist) 2. Normalize parameter names (e.g., addr/address -> Server, uid -> UID) 3. Merge kwargs (which override connection_str params after normalization) - 4. Build connection string from normalized, merged params - 5. Add Driver and APP parameters (always controlled by the driver) - 6. Return the final connection string + 4. Add Driver and APP (always controlled by the driver) + 5. Build and return the final connection string + parameter dictionary Args: connection_str (str): The base connection string. **kwargs: Additional key/value pairs for the connection string. Returns: - str: The constructed and validated connection string. + Tuple[str, Dict[str, str]]: The constructed connection string and + the normalized parameter dictionary. """ # Step 1: Parse base connection string with allowlist validation @@ -448,20 +468,16 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs: Any) else: logger.warning(f"Ignoring unknown connection parameter from kwargs: {key}") - # Step 4: Build connection string with merged params - builder = _ConnectionStringBuilder(normalized_params) - - # Step 5: Add Driver and APP parameters (always controlled by the driver) - # These maintain existing behavior: Driver is always hardcoded, APP is always MSSQL-Python - builder.add_param("Driver", "ODBC Driver 18 for SQL Server") - builder.add_param("APP", "MSSQL-Python") + # Step 4: Add Driver and APP (always controlled by the driver). + normalized_params["Driver"] = "ODBC Driver 18 for SQL Server" + normalized_params["APP"] = "MSSQL-Python" - # Step 6: Build final string - conn_str = builder.build() + # Step 5: Build final connection string + conn_str = _ConnectionStringBuilder(normalized_params).build() logger.info("Final connection string: %s", sanitize_connection_string(conn_str)) - return conn_str + return conn_str, normalized_params @property def timeout(self) -> int: diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index c6141ea7..590bc0c4 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -123,7 +123,7 @@ def test_connection(db_connection): def test_construct_connection_string(db_connection): # Check if the connection string is constructed correctly with kwargs # Using official ODBC parameter names - conn_str = db_connection._construct_connection_string( + conn_str, _ = db_connection._construct_connection_string( Server="localhost", UID="me", PWD="mypwd", @@ -149,7 +149,7 @@ def test_construct_connection_string(db_connection): def test_connection_string_with_attrs_before(db_connection): # Check if the connection string is constructed correctly with attrs_before # Using official ODBC parameter names - conn_str = db_connection._construct_connection_string( + conn_str, _ = db_connection._construct_connection_string( Server="localhost", UID="me", PWD="mypwd", @@ -177,7 +177,7 @@ def test_connection_string_with_attrs_before(db_connection): def test_connection_string_with_odbc_param(db_connection): # Check if the connection string is constructed correctly with ODBC parameters # Using lowercase synonyms that normalize to uppercase (uid->UID, pwd->PWD) - conn_str = db_connection._construct_connection_string( + conn_str, _ = db_connection._construct_connection_string( server="localhost", uid="me", pwd="mypwd", diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index f8df6f6f..c99ce36c 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -14,7 +14,6 @@ process_auth_parameters, remove_sensitive_params, get_auth_token, - process_connection_string, extract_auth_type, _credential_cache, _credential_cache_lock, @@ -303,144 +302,94 @@ def __init__(self): class TestProcessAuthParameters: def test_empty_parameters(self): - modified_params, auth_type = process_auth_parameters([]) - assert modified_params == [] + auth_type = process_auth_parameters({}) assert auth_type is None def test_interactive_auth_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Windows") - params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] - modified_params, auth_type = process_auth_parameters(params) - assert "Authentication=ActiveDirectoryInteractive" in modified_params + params = {"Authentication": "ActiveDirectoryInteractive", "Server": "test"} + auth_type = process_auth_parameters(params) assert auth_type is None def test_interactive_auth_non_windows(self, monkeypatch): monkeypatch.setattr(platform, "system", lambda: "Darwin") - params = ["Authentication=ActiveDirectoryInteractive", "Server=test"] - _, auth_type = process_auth_parameters(params) + params = {"Authentication": "ActiveDirectoryInteractive", "Server": "test"} + auth_type = process_auth_parameters(params) assert auth_type == "interactive" def test_device_code_auth(self): - params = ["Authentication=ActiveDirectoryDeviceCode", "Server=test"] - _, auth_type = process_auth_parameters(params) + params = {"Authentication": "ActiveDirectoryDeviceCode", "Server": "test"} + auth_type = process_auth_parameters(params) assert auth_type == "devicecode" def test_default_auth(self): - params = ["Authentication=ActiveDirectoryDefault", "Server=test"] - _, auth_type = process_auth_parameters(params) + params = {"Authentication": "ActiveDirectoryDefault", "Server": "test"} + auth_type = process_auth_parameters(params) assert auth_type == "default" def test_msi_auth(self): - params = ["Authentication=ActiveDirectoryMSI", "Server=test"] - _, auth_type = process_auth_parameters(params) + params = {"Authentication": "ActiveDirectoryMSI", "Server": "test"} + auth_type = process_auth_parameters(params) assert auth_type == "msi" def test_msi_auth_case_insensitive(self): - params = ["authentication=activedirectorymsi", "Server=test"] - _, auth_type = process_auth_parameters(params) + params = {"Authentication": "activedirectorymsi", "Server": "test"} + auth_type = process_auth_parameters(params) assert auth_type == "msi" class TestRemoveSensitiveParams: def test_remove_sensitive_parameters(self): - params = [ - "Server=test", - "UID=user", - "PWD=password", - "Encrypt=yes", - "TrustServerCertificate=yes", - "Authentication=ActiveDirectoryDefault", - "Trusted_Connection=yes", - "Database=testdb", - ] + params = { + "Server": "test", + "UID": "user", + "PWD": "password", + "Encrypt": "yes", + "TrustServerCertificate": "yes", + "Authentication": "ActiveDirectoryDefault", + "Trusted_Connection": "yes", + "Database": "testdb", + } filtered_params = remove_sensitive_params(params) - assert "Server=test" in filtered_params - assert "Database=testdb" in filtered_params - assert "UID=user" not in filtered_params - assert "PWD=password" not in filtered_params - assert "Encrypt=yes" in filtered_params - assert "TrustServerCertificate=yes" in filtered_params - assert "Trusted_Connection=yes" not in filtered_params - assert "Authentication=ActiveDirectoryDefault" not in filtered_params - - -class TestProcessConnectionString: - def test_process_connection_string_with_default_auth(self): - conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) - - assert "Server=test" in result_str - assert "Database=testdb" in result_str - assert attrs is not None - assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs - assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) - assert auth_type == "default" - assert credential_kwargs is None - - def test_process_connection_string_no_auth(self): - conn_str = "Server=test;Database=testdb;UID=user;PWD=password" - result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) - - assert "Server=test" in result_str - assert "Database=testdb" in result_str - assert "UID=user" in result_str - assert "PWD=password" in result_str - assert attrs is None - assert auth_type is None - assert credential_kwargs is None - - def test_process_connection_string_interactive_non_windows(self, monkeypatch): - monkeypatch.setattr(platform, "system", lambda: "Darwin") - conn_str = "Server=test;Authentication=ActiveDirectoryInteractive;Database=testdb" - result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) - - assert "Server=test" in result_str - assert "Database=testdb" in result_str - assert attrs is not None - assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in attrs - assert isinstance(attrs[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value], bytes) - assert auth_type == "interactive" - assert credential_kwargs is None - - -def test_error_handling(): - # Empty string should raise ValueError - with pytest.raises(ValueError, match="Connection string cannot be empty"): - process_connection_string("") - - # Invalid connection string should raise ValueError - with pytest.raises(ValueError, match="Invalid connection string format"): - process_connection_string("InvalidConnectionString") - - # Test non-string input - with pytest.raises(ValueError, match="Connection string must be a string"): - process_connection_string(None) + assert "Server" in filtered_params + assert "Database" in filtered_params + assert "UID" not in filtered_params + assert "PWD" not in filtered_params + assert "Encrypt" in filtered_params + assert "TrustServerCertificate" in filtered_params + assert "Trusted_Connection" not in filtered_params + assert "Authentication" not in filtered_params class TestExtractAuthType: def test_interactive(self): assert ( - extract_auth_type("Server=test;Authentication=ActiveDirectoryInteractive;") + extract_auth_type({"Server": "test", "Authentication": "ActiveDirectoryInteractive"}) == "interactive" ) def test_default(self): - assert extract_auth_type("Server=test;Authentication=ActiveDirectoryDefault;") == "default" + assert ( + extract_auth_type({"Server": "test", "Authentication": "ActiveDirectoryDefault"}) + == "default" + ) def test_devicecode(self): assert ( - extract_auth_type("Server=test;Authentication=ActiveDirectoryDeviceCode;") + extract_auth_type({"Server": "test", "Authentication": "ActiveDirectoryDeviceCode"}) == "devicecode" ) def test_msi(self): - assert extract_auth_type("Server=test;Authentication=ActiveDirectoryMSI;") == "msi" + assert ( + extract_auth_type({"Server": "test", "Authentication": "ActiveDirectoryMSI"}) == "msi" + ) def test_no_auth(self): - assert extract_auth_type("Server=test;Database=db;") is None + assert extract_auth_type({"Server": "test", "Database": "db"}) is None def test_unsupported_auth(self): - assert extract_auth_type("Server=test;Authentication=SqlPassword;") is None + assert extract_auth_type({"Server": "test", "Authentication": "SqlPassword"}) is None class TestManagedIdentity: @@ -481,60 +430,49 @@ def test_msi_separate_cache_entries_per_client_id(self): assert ("msi", (("client_id", "def"),)) in _credential_cache assert _credential_cache["msi"] is not _credential_cache[("msi", (("client_id", "abc"),))] - def test_process_connection_string_msi_strips_uid_and_returns_kwargs(self): - """MSI connection strings: UID is stripped from the ODBC connection - string but the client_id is captured as credential_kwargs (so it can - be persisted on the Connection for the bulkcopy fresh-token path).""" - az = sys.modules["azure.identity"] + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_msi_auth_type_stored_on_connection(self, mock_ddbc_conn): + """MSI with UID: Connection stores auth_type and credential_kwargs.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + az = sys.modules["azure.identity"] az.ManagedIdentityCredential.last_init_kwargs = None - conn_str = ( - "Server=test;Authentication=ActiveDirectoryMSI;" - "UID=11111111-2222-3333-4444-555555555555;Database=testdb" + + conn = connect( + "Server=test;Database=testdb;Authentication=ActiveDirectoryMSI;" + "UID=11111111-2222-3333-4444-555555555555" ) - result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) + assert conn._auth_type == "msi" + assert conn._credential_kwargs == {"client_id": "11111111-2222-3333-4444-555555555555"} + # UID must be stripped from the sanitized connection string + assert "UID=" not in conn.connection_str + conn.close() - assert auth_type == "msi" - assert "UID=" not in result_str - assert "Authentication=" not in result_str - assert "Server=test" in result_str - assert "Database=testdb" in result_str - assert attrs is not None - assert az.ManagedIdentityCredential.last_init_kwargs == { - "client_id": "11111111-2222-3333-4444-555555555555" - } - # client_id must be returned so Connection can persist it for the - # bulkcopy fresh-token path (UID is gone from result_str by then). - assert credential_kwargs == {"client_id": "11111111-2222-3333-4444-555555555555"} - - def test_process_connection_string_msi_system_assigned_no_kwargs(self): - """System-assigned MSI: no UID → credential_kwargs is None.""" - conn_str = "Server=test;Authentication=ActiveDirectoryMSI;Database=testdb" - _, _, auth_type, credential_kwargs = process_connection_string(conn_str) - assert auth_type == "msi" - assert credential_kwargs is None + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_msi_system_assigned_no_credential_kwargs(self, mock_ddbc_conn): + """System-assigned MSI: no UID -> credential_kwargs is None.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryMSI") + assert conn._auth_type == "msi" + assert conn._credential_kwargs is None + conn.close() - def test_msi_braced_uid_value_is_unwrapped(self): + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_msi_braced_uid_value_is_unwrapped(self, mock_ddbc_conn): """A braced UID value (UID={hello=world}) must be unwrapped by the - canonical _ConnectionStringParser; the inner '=' must NOT split the - value. Without parser-aware extraction the helper would return - '{hello=world}' verbatim and ManagedIdentityCredential would reject - it.""" - conn_str = "Server=test;Authentication=ActiveDirectoryMSI;UID={hello=world};Database=testdb" - _, _, auth_type, credential_kwargs = process_connection_string(conn_str) - assert auth_type == "msi" - assert credential_kwargs == {"client_id": "hello=world"} - - def test_msi_braced_uid_with_semicolon_is_preserved(self): - """A braced UID value containing a semicolon (legal under ODBC) must - be returned intact, not truncated at the inner ';'.""" - weird_id = "abc;def;ghi" - conn_str = ( - f"Server=test;Authentication=ActiveDirectoryMSI;" f"UID={{{weird_id}}};Database=testdb" + canonical _ConnectionStringParser; the inner '=' must NOT split.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect( + "Server=test;Authentication=ActiveDirectoryMSI;" "UID={hello=world};Database=testdb" ) - _, _, auth_type, credential_kwargs = process_connection_string(conn_str) - assert auth_type == "msi" - assert credential_kwargs == {"client_id": weird_id} + assert conn._auth_type == "msi" + assert conn._credential_kwargs == {"client_id": "hello=world"} + conn.close() def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): """Regression test (cursor.bulkcopy() end-to-end) for the silent @@ -740,20 +678,17 @@ def get_token(self, scope): class TestProcessAuthParametersEdgeCases: - """Cover empty-param and no-equals-sign branches.""" + """Cover edge cases for dict-based process_auth_parameters.""" - def test_empty_and_whitespace_params_skipped(self): - params = ["Server=test", "", " ", "Database=db"] - modified, auth_type = process_auth_parameters(params) - assert "Server=test" in modified - assert "Database=db" in modified + def test_no_authentication_key(self): + params = {"Server": "test", "Database": "db"} + auth_type = process_auth_parameters(params) assert auth_type is None - def test_param_without_equals_kept(self): - params = ["Server=test", "SomeFlag", "Database=db"] - modified, auth_type = process_auth_parameters(params) - assert "SomeFlag" in modified - assert "Server=test" in modified + def test_empty_authentication_value(self): + params = {"Server": "test", "Authentication": "", "Database": "db"} + auth_type = process_auth_parameters(params) + assert auth_type is None class TestGetAuthTokenEdgeCases: @@ -978,37 +913,3 @@ def test_token_output_correct_on_cache_miss_and_hit(self): # Same credential instance for both assert "default" in _credential_cache - - -class TestProcessConnectionStringTokenFailureFallthrough: - """Cover the path where get_auth_token returns None and - process_connection_string falls through without attrs.""" - - def test_returns_none_attrs_when_token_acquisition_fails(self): - """When auth type is detected but token acquisition fails, - process_connection_string should return (conn_str, None, auth_type, kwargs).""" - import sys - - azure_identity = sys.modules["azure.identity"] - original = azure_identity.DefaultAzureCredential - - class CredentialThatAlwaysFails: - def __init__(self): - raise RuntimeError("cannot create credential") - - try: - azure_identity.DefaultAzureCredential = CredentialThatAlwaysFails - conn_str = "Server=test;Authentication=ActiveDirectoryDefault;Database=testdb" - result_str, attrs, auth_type, credential_kwargs = process_connection_string(conn_str) - - # Auth type was detected - assert auth_type == "default" - # But token acquisition failed, so attrs is None - assert attrs is None - # Connection string is still returned (sensitive params removed) - assert "Server=test" in result_str - assert "Database=testdb" in result_str - # Default auth has no credential kwargs - assert credential_kwargs is None - finally: - azure_identity.DefaultAzureCredential = original