From ec98d8fe855fa6778133d17b1da1c1132f7f141b Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Tue, 27 Jan 2026 21:40:54 +0530 Subject: [PATCH 01/23] SK-2504: Add support for custom tokenUri (#228) * SK-2504: add support for custom token uri --- skyflow/service_account/_utils.py | 80 +++++--- skyflow/utils/__init__.py | 2 +- skyflow/utils/_helpers.py | 9 +- skyflow/utils/_skyflow_messages.py | 5 + skyflow/utils/validations/_validations.py | 19 +- skyflow/vault/client/client.py | 5 +- tests/service_account/test__utils.py | 71 ++++++- tests/utils/test__helpers.py | 27 ++- tests/utils/test__utils.py | 190 ++++++++++++++++++- tests/utils/validations/test__validations.py | 75 +++++++- tests/vault/client/test__client.py | 22 ++- 11 files changed, 454 insertions(+), 51 deletions(-) diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 715716d8..a6044af2 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -2,10 +2,13 @@ import datetime import time import jwt +from urllib.parse import urlparse from skyflow.error import SkyflowError from skyflow.service_account.client.auth_client import AuthClient from skyflow.utils.logger import log_info, log_error_log from skyflow.utils import get_base_url, format_scope, SkyflowMessages +from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError +from skyflow.utils import is_valid_url invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value @@ -78,7 +81,14 @@ def get_service_account_token(credentials, options, logger): except: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) - + + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + if options and "token_uri" in options: + token_uri = options["token_uri"] + signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) auth_client = AuthClient(base_url) @@ -88,10 +98,17 @@ def get_service_account_token(credentials, options, logger): if options and "role_ids" in options: formatted_scope = format_scope(options.get("role_ids")) - response = auth_api.authentication_service_get_auth_token(assertion = signed_token, + try: + response = auth_api.authentication_service_get_auth_token(assertion = signed_token, grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", scope=formatted_scope) - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + except UnauthorizedError: + log_error_log(SkyflowMessages.ErrorLogs.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, invalid_input_error_code) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.FAILED_TO_GET_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value, invalid_input_error_code) return response.access_token, response.token_type def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): @@ -112,32 +129,41 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): - try: - expiry_time = int(time.time()) + options.get("time_to_live", 60) - prefix = "signed_token_" - - if options and options.get("data_tokens"): - for token in options["data_tokens"]: - claims = { - "iss": "sdk", - "key": credentials_obj.get("keyID"), - "exp": expiry_time, - "sub": credentials_obj.get("clientID"), - "tok": token, - "iat": int(time.time()), - } - - if "ctx" in options: - claims["ctx"] = options["ctx"] - - private_key = credentials_obj.get("privateKey") + expiry_time = int(time.time()) + options.get("time_to_live", 60) + prefix = "signed_token_" + + token_uri = credentials_obj.get("tokenURI") + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + if options and "token_uri" in options: + token_uri = options["token_uri"] + + + if options and options.get("data_tokens"): + for token in options["data_tokens"]: + claims = { + "iss": "sdk", + "key": credentials_obj.get("keyID"), + "exp": expiry_time, + "sub": credentials_obj.get("clientID"), + "tok": token, + "iat": int(time.time()), + } + + if "ctx" in options: + claims["ctx"] = options["ctx"] + + private_key = credentials_obj.get("privateKey") + try: signed_jwt = jwt.encode(claims, private_key, algorithm="RS256") - response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) - log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) - return response_object + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) + log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) + return response_object def generate_signed_data_tokens(credentials_file_path, options): diff --git a/skyflow/utils/__init__.py b/skyflow/utils/__init__.py index f2788b11..664cf65d 100644 --- a/skyflow/utils/__init__.py +++ b/skyflow/utils/__init__.py @@ -1,5 +1,5 @@ from ..utils.enums import LogLevel, Env, TokenType from ._skyflow_messages import SkyflowMessages from ._version import SDK_VERSION -from ._helpers import get_base_url, format_scope +from ._helpers import get_base_url, format_scope, is_valid_url from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_detected_entity_to_entity_info diff --git a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py index 97eecabc..090f3a2b 100644 --- a/skyflow/utils/_helpers.py +++ b/skyflow/utils/_helpers.py @@ -8,4 +8,11 @@ def get_base_url(url): def format_scope(scopes): if not scopes: return None - return " ".join([f"role:{scope}" for scope in scopes]) \ No newline at end of file + return " ".join([f"role:{scope}" for scope in scopes]) + +def is_valid_url(url): + try: + result = urlparse(url) + return all([result.scheme in ("http", "https"), result.netloc]) + except Exception: + return False \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..99329978 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -153,10 +153,13 @@ class Error(Enum): MISSING_CLIENT_ID = f"{error_prefix} Initialization failed. Unable to read client ID in credentials. Verify your client ID." MISSING_KEY_ID = f"{error_prefix} Initialization failed. Unable to read key ID in credentials. Verify your key ID." MISSING_TOKEN_URI = f"{error_prefix} Initialization failed. Unable to read token URI in credentials. Verify your token URI." + INVALID_TOKEN_URI = f"{error_prefix} Initialization failed. Invalid Skyflow credentials. The token URI must be a string and a valid URL." JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials." JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials." FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents." INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." INVALID_TEXT_IN_DEIDENTIFY= f"{error_prefix} Validation error. The text field is required and must be a non-empty string. Specify a valid text." INVALID_ENTITIES_IN_DEIDENTIFY= f"{error_prefix} Validation error. The entities field must be an array of DetectEntities enums. Specify a valid entities." @@ -332,6 +335,8 @@ class ErrorLogs(Enum): KEY_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Key ID is required." TOKEN_URI_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Token URI is required." INVALID_TOKEN_URI = f"{ERROR}: [{error_prefix}] Invalid value for token URI in credentials." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." TABLE_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Table is required." diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..611efdae 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -10,6 +10,7 @@ from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest from skyflow.vault.detect._file_input import FileInput +from skyflow.utils._helpers import is_valid_url valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] @@ -138,6 +139,15 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) + + if "token_uri" in credentials: + token_uri = credentials.get("token_uri") + if ( + token_uri is None + or not isinstance(token_uri, str) + or not is_valid_url(token_uri) + ): + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): @@ -202,10 +212,8 @@ def validate_update_vault_config(logger, config): if "env" in config and config.get("env") not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) - - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + if "credentials" in config and config.get("credentials"): + validate_credentials(logger, config.get("credentials"), "vault", vault_id) return True @@ -413,9 +421,6 @@ def validate_insert_request(logger, request): if key is None or key == "": log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger) - if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger) - if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..689fb6e9 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -1,3 +1,4 @@ +from skyflow.error import SkyflowError from skyflow.generated.rest.client import Skyflow from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages @@ -62,6 +63,8 @@ def get_bearer_token(self, credentials): "role_ids": self.__config.get("roles"), "ctx": self.__config.get("ctx") } + if "token_uri" in credentials and credentials.get("token_uri"): + options["token_uri"] = credentials.get("token_uri") if self.__bearer_token is None or self.__is_config_updated: if 'path' in credentials: @@ -85,7 +88,7 @@ def get_bearer_token(self, credentials): if is_expired(self.__bearer_token): self.__is_config_updated = True - raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) return self.__bearer_token diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index 7ffb36df..ca82527a 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -143,4 +143,73 @@ def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): credentials_string = '{' with self.assertRaises(SkyflowError) as context: result = generate_signed_data_tokens_from_creds(credentials_string, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) \ No newline at end of file + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + options = {"role_ids": ["role1", "role2"]} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {"access_token": "token", + "token_type": "bearer"}) + access_token, token_type = get_service_account_token(creds, options, None) + self.assertEqual(access_token, "token") + self.assertEqual(token_type, "bearer") + args, kwargs = mock_auth_api.authentication_service_get_auth_token.call_args + self.assertIn("scope", kwargs) + self.assertEqual(kwargs["scope"], "role:role1 role:role2") + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError + mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.side_effect = Exception("some error") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value) + + @patch("jwt.encode", side_effect=Exception("jwt error")) + def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com" + } + options = {"data_tokens": ["token1"]} + with self.assertRaises(SkyflowError) as context: + from skyflow.service_account._utils import get_signed_tokens + get_signed_tokens(creds, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) \ No newline at end of file diff --git a/tests/utils/test__helpers.py b/tests/utils/test__helpers.py index 8b55abf3..6758b62e 100644 --- a/tests/utils/test__helpers.py +++ b/tests/utils/test__helpers.py @@ -1,5 +1,5 @@ import unittest -from skyflow.utils import get_base_url, format_scope +from skyflow.utils import get_base_url, format_scope, is_valid_url VALID_URL = "https://example.com/path?query=1" BASE_URL = "https://example.com" @@ -35,4 +35,27 @@ def test_format_scope_single_scope(self): def test_format_scope_special_characters(self): scopes_with_special_chars = ["admin", "user:write", "read-only"] expected_result = "role:admin role:user:write role:read-only" - self.assertEqual(format_scope(scopes_with_special_chars), expected_result) \ No newline at end of file + self.assertEqual(format_scope(scopes_with_special_chars), expected_result) + + def test_is_valid_url_valid(self): + self.assertTrue(is_valid_url("https://example.com")) + self.assertTrue(is_valid_url("http://example.com/path")) + + def test_is_valid_url_invalid(self): + self.assertFalse(is_valid_url("ftp://example.com")) + self.assertFalse(is_valid_url("example.com")) + self.assertFalse(is_valid_url("invalid-url")) + self.assertFalse(is_valid_url("")) + + def test_is_valid_url_none(self): + self.assertFalse(is_valid_url(None)) + + def test_is_valid_url_no_scheme(self): + self.assertFalse(is_valid_url("www.example.com")) + + def test_is_valid_url_exception(self): + class BadStr: + def __str__(self): + raise Exception("bad str") + + self.assertFalse(is_valid_url(BadStr())) \ No newline at end of file diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6eaacf47..55d8c00e 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1,13 +1,15 @@ import unittest from unittest.mock import patch, Mock import os -import json from unittest.mock import MagicMock from urllib.parse import quote +import tempfile, json from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.generated.rest import ErrorResponse +from skyflow.service_account import generate_bearer_token, generate_signed_data_tokens, \ + generate_signed_data_tokens_from_creds, generate_bearer_token_from_creds from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ @@ -597,3 +599,189 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.text_index.end, 0) self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) + + def test_generate_bearer_token_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 # invalid type + } + + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + # Patch AuthClient and jwt.encode to avoid real HTTP and signing + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), + {"access_token": "token", + "token_type": "bearer"}) + generate_bearer_token(tmp.name, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_bearer_token_from_creds_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 + } + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"token_uri": "https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), + {"access_token": "token", + "token_type": "bearer"}) + generate_bearer_token_from_creds(creds_str, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_signed_data_tokens_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 + } + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 12345 + } + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'not_a_url' + } + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens(tmp.name, options) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0], "token1") + self.assertEqual(result[1], "signed_token_signed") + + def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self): + creds = { + 'privateKey': 'private_key', + 'clientID': 'client_id', + 'keyID': 'key_id', + 'tokenURI': 'https://valid-url.com' + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens_from_creds(creds_str, options) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0], "token1") + self.assertEqual(result[1], "signed_token_signed") diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 48332a55..5c3bb450 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -205,15 +205,6 @@ def test_validate_update_vault_config_valid(self): } self.assertTrue(validate_update_vault_config(self.logger, config)) - def test_validate_update_vault_config_missing_credentials(self): - config = { - "vault_id": "vault123", - "cluster_id": "cluster123" - } - with self.assertRaises(SkyflowError) as context: - validate_update_vault_config(self.logger, config) - self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123")) - def test_validate_update_vault_config_invalid_cluster_id(self): config = { "vault_id": "vault123", @@ -1044,3 +1035,69 @@ def test_validate_detokenize_request_invalid_redaction_type(self): with self.assertRaises(SkyflowError) as context: validate_detokenize_request(self.logger, request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) + + def test_validate_credentials_with_valid_token_uri(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + } + # Should not raise + validate_credentials(self.logger, credentials) + + def test_validate_credentials_with_invalid_token_uri_type(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 # Not a string + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_credentials_with_invalid_token_uri_url(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_valid_token_uri(self): + from skyflow.utils.enums import Env + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + }, + "env": Env.DEV + } + # Should not raise + self.assertTrue(validate_update_vault_config(self.logger, config)) + + def test_validate_update_vault_config_with_invalid_token_uri_type(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_invalid_token_uri_url(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 565b1e6f..b4d6ec42 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -97,4 +97,24 @@ def test_get_log_level(self): def test_get_logger(self): mock_logger = MagicMock() self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) \ No newline at end of file + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + def test_get_bearer_token_with_token(self): + credentials = {"token": "dummy_token"} + token = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token, "dummy_token") + + def test_get_bearer_token_with_token_uri_in_credentials(self): + credentials = { + "path": "dummy_path", + "token_uri": "https://valid-url.com" + } + with patch("skyflow.vault.client.client.generate_bearer_token") as mock_generate_bearer_token, \ + patch("skyflow.vault.client.client.is_expired", return_value=False): + mock_generate_bearer_token.return_value = ("bearer_token", "bearer") + token = self.vault_client.get_bearer_token(credentials) + mock_generate_bearer_token.assert_called_once() + args, kwargs = mock_generate_bearer_token.call_args + self.assertIn("token_uri", args[1]) + self.assertEqual(args[1]["token_uri"], "https://valid-url.com") + self.assertEqual(token, "bearer_token") \ No newline at end of file From 3294b09f19ca35e595ddae82ce24c3a70f308d0f Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Tue, 27 Jan 2026 16:11:14 +0000 Subject: [PATCH 02/23] [AUTOMATED] Private Release 2.0.0.dev0+ec98d8f --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 09f844d2..9ab81074 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0' +current_version = '2.0.0.dev0+ec98d8f' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 0d05fc30..16fcec0b 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0' \ No newline at end of file +SDK_VERSION = '2.0.0.dev0+ec98d8f' \ No newline at end of file From af0b616316c2b65766dd121aec71e1ccc7f92dd5 Mon Sep 17 00:00:00 2001 From: skyflow-himanshu Date: Wed, 28 Jan 2026 18:20:15 +0530 Subject: [PATCH 03/23] SK-2496: extract hard coded values to constants --- skyflow/client/skyflow.py | 55 ++++---- skyflow/service_account/_utils.py | 69 ++++----- skyflow/utils/_skyflow_messages.py | 3 + skyflow/utils/_utils.py | 109 ++++++++------- skyflow/utils/constants.py | 163 ++++++++++++++++++++++ skyflow/utils/logger/_log_helpers.py | 13 +- skyflow/utils/validations/_validations.py | 9 +- skyflow/vault/client/client.py | 29 ++-- skyflow/vault/controller/_connections.py | 5 +- skyflow/vault/controller/_detect.py | 39 +++--- skyflow/vault/controller/_vault.py | 12 +- 11 files changed, 340 insertions(+), 166 deletions(-) diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 9f0d9dbf..0bfde34e 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -3,6 +3,7 @@ from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow.utils.logger import log_info, Logger +from skyflow.utils.constants import OptionField from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \ validate_update_connection_config, validate_credentials, validate_log_level from skyflow.vault.client.client import VaultClient @@ -30,7 +31,7 @@ def update_vault_config(self,config): self.__builder.update_vault_config(config) def get_vault_config(self, vault_id): - return self.__builder.get_vault_config(vault_id).get("vault_client").get_config() + return self.__builder.get_vault_config(vault_id).get(OptionField.VAULT_CLIENT).get_config() def add_connection_config(self, config): self.__builder._Builder__add_connection_config(config) @@ -45,7 +46,7 @@ def update_connection_config(self, config): return self def get_connection_config(self, connection_id): - return self.__builder.get_connection_config(connection_id).get("vault_client").get_config() + return self.__builder.get_connection_config(connection_id).get(OptionField.VAULT_CLIENT).get_config() def add_skyflow_credentials(self, credentials): self.__builder._Builder__add_skyflow_credentials(credentials) @@ -66,15 +67,15 @@ def update_log_level(self, log_level): def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("vault_controller") + return vault_config.get(OptionField.VAULT_CONTROLLER) def connection(self, connection_id = None) -> Connection: connection_config = self.__builder.get_connection_config(connection_id) - return connection_config.get("controller") + return connection_config.get(OptionField.CONTROLLER) def detect(self, vault_id = None) -> Detect: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("detect_controller") + return vault_config.get(OptionField.DETECT_CONTROLLER) class Builder: def __init__(self): @@ -87,13 +88,13 @@ def __init__(self): self.__logger = Logger(LogLevel.ERROR) def add_vault_config(self, config): - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) if not isinstance(vault_id, str) or not vault_id: raise SkyflowError( SkyflowMessages.Error.INVALID_VAULT_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if vault_id in [vault.get("vault_id") for vault in self.__vault_list]: + if vault_id in [vault.get(OptionField.VAULT_ID) for vault in self.__vault_list]: log_info(SkyflowMessages.Info.VAULT_CONFIG_EXISTS.value.format(vault_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(vault_id), @@ -112,9 +113,9 @@ def remove_vault_config(self, vault_id): def update_vault_config(self, config): validate_update_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_config = self.__vault_configs[vault_id] - vault_config.get("vault_client").update_config(config) + vault_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_vault_config(self, vault_id): if vault_id is None: @@ -129,13 +130,13 @@ def get_vault_config(self, vault_id): def add_connection_config(self, config): - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) if not isinstance(connection_id, str) or not connection_id: raise SkyflowError( SkyflowMessages.Error.INVALID_CONNECTION_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if connection_id in [connection.get("connection_id") for connection in self.__connection_list]: + if connection_id in [connection.get(OptionField.CONNECTION_ID) for connection in self.__connection_list]: log_info(SkyflowMessages.Info.CONNECTION_CONFIG_EXISTS.value.format(connection_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id), @@ -153,9 +154,9 @@ def remove_connection_config(self, connection_id): def update_connection_config(self, config): validate_update_connection_config(self.__logger, config) - connection_id = config['connection_id'] + connection_id = config[OptionField.CONNECTION_ID] connection_config = self.__connection_configs[connection_id] - connection_config.get("vault_client").update_config(config) + connection_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_connection_config(self, connection_id): if connection_id is None: @@ -183,32 +184,32 @@ def get_logger(self): def __add_vault_config(self, config): validate_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_client = VaultClient(config) self.__vault_configs[vault_id] = { - "vault_client": vault_client, - "vault_controller": Vault(vault_client), - "detect_controller": Detect(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.VAULT_CONTROLLER: Vault(vault_client), + OptionField.DETECT_CONTROLLER: Detect(vault_client) } - log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) - log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) + log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) + log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) def __add_connection_config(self, config): validate_connection_config(self.__logger, config) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) vault_client = VaultClient(config) self.__connection_configs[connection_id] = { - "vault_client": vault_client, - "controller": Connection(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.CONTROLLER: Connection(vault_client) } - log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get("connection_id")), self.__logger) + log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.CONNECTION_ID)), self.__logger) def __update_vault_client_logger(self, log_level, logger): for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_logger(log_level,logger) + vault_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_logger(log_level,logger) + connection_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) def __set_log_level(self, log_level): validate_log_level(self.__logger, log_level) @@ -223,10 +224,10 @@ def __add_skyflow_credentials(self, credentials): self.__skyflow_credentials = credentials validate_credentials(self.__logger, credentials) for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_common_skyflow_credentials(credentials) + vault_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(credentials) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials) + connection_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(self.__skyflow_credentials) def build(self): validate_log_level(self.__logger, self.__log_level) self.__logger.set_log_level(self.__log_level) diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 715716d8..3f21ba21 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -6,6 +6,7 @@ from skyflow.service_account.client.auth_client import AuthClient from skyflow.utils.logger import log_info, log_error_log from skyflow.utils import get_base_url, format_scope, SkyflowMessages +from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value @@ -17,8 +18,8 @@ def is_expired(token, logger = None): try: decoded = jwt.decode( - token, options={"verify_signature": False, "verify_aud": False}) - if time.time() >= decoded['exp']: + token, options={OptionField.VERIFY_SIGNATURE: False, OptionField.VERIFY_AUD: False}) + if time.time() >= decoded[JwtField.EXP]: log_info(SkyflowMessages.Info.BEARER_TOKEN_EXPIRED.value, logger) log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -59,22 +60,22 @@ def generate_bearer_token_from_creds(credentials, options = None, logger = None) def get_service_account_token(credentials, options, logger): try: - private_key = credentials["privateKey"] + private_key = credentials[CredentialField.PRIVATE_KEY] except: log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: - client_id = credentials["clientID"] + client_id = credentials[CredentialField.CLIENT_ID] except: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: - key_id = credentials["keyID"] + key_id = credentials[CredentialField.KEY_ID] except: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: - token_uri = credentials["tokenURI"] + token_uri = credentials[CredentialField.TOKEN_URI] except: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) @@ -85,27 +86,27 @@ def get_service_account_token(credentials, options, logger): auth_api = auth_client.get_auth_api() formatted_scope = None - if options and "role_ids" in options: - formatted_scope = format_scope(options.get("role_ids")) + if options and OptionField.ROLE_IDS in options: + formatted_scope = format_scope(options.get(OptionField.ROLE_IDS)) response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) return response.access_token, response.token_type def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): payload = { - "iss": client_id, - "key": key_id, - "aud": token_uri, - "sub": client_id, - "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60) + JwtField.ISS: client_id, + JwtField.KEY: key_id, + JwtField.AUD: token_uri, + JwtField.SUB: client_id, + JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and "ctx" in options: - payload["ctx"] = options.get("ctx") + if options and JwtField.CTX in options: + payload[JwtField.CTX] = options.get(JwtField.CTX) try: - return jwt.encode(payload=payload, key=private_key, algorithm="RS256") + return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code) @@ -113,25 +114,25 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): try: - expiry_time = int(time.time()) + options.get("time_to_live", 60) - prefix = "signed_token_" + expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60) + prefix = JWT.SIGNED_TOKEN_PREFIX - if options and options.get("data_tokens"): - for token in options["data_tokens"]: + if options and options.get(OptionField.DATA_TOKENS): + for token in options[OptionField.DATA_TOKENS]: claims = { - "iss": "sdk", - "key": credentials_obj.get("keyID"), - "exp": expiry_time, - "sub": credentials_obj.get("clientID"), - "tok": token, - "iat": int(time.time()), + JwtField.ISS: JWT.ISSUER_SDK, + JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID), + JwtField.EXP: expiry_time, + JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID), + JwtField.TOK: token, + JwtField.IAT: int(time.time()), } - if "ctx" in options: - claims["ctx"] = options["ctx"] + if JwtField.CTX in options: + claims[JwtField.CTX] = options[JwtField.CTX] - private_key = credentials_obj.get("privateKey") - signed_jwt = jwt.encode(claims, private_key, algorithm="RS256") + private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) + signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) return response_object @@ -170,7 +171,7 @@ def generate_signed_data_tokens_from_creds(credentials, options): def get_signed_data_token_response_object(signed_token, actual_token): response_object = { - "token": actual_token, - "signed_token": signed_token + ResponseField.TOKEN: actual_token, + ResponseField.SIGNED_TOKEN: signed_token } - return response_object.get("token"), response_object.get("signed_token") + return response_object.get(ResponseField.TOKEN), response_object.get(ResponseField.SIGNED_TOKEN) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..1954ed4d 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -71,6 +71,9 @@ class Error(Enum): RESPONSE_NOT_JSON = f"{error_prefix} Response {{}} is not valid JSON." API_ERROR = f"{error_prefix} Server returned status code {{}}" + INVALID_JSON_RESPONSE = f"{error_prefix} Invalid JSON response received." + UNKNOWN_ERROR_DEFAULT_MESSAGE = f"{error_prefix} An unknown error occurred." + INVALID_FILE_INPUT = f"{error_prefix} Validation error. Invalid file input. Specify a valid file input." INVALID_DETECT_ENTITIES_TYPE = f"{error_prefix} Validation error. Invalid type of detect entities. Specify detect entities as list of DetectEntities enum." INVALID_TYPE_FOR_DEFAULT_TOKEN_TYPE = f"{error_prefix} Validation error. Invalid type of default token type. Specify default token type as TokenType enum." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 4278357e..c6f294cd 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -20,7 +20,8 @@ from skyflow.vault.detect import DeidentifyTextResponse, ReidentifyTextResponse from skyflow.vault.detect import EntityInfo, TextIndex from . import SkyflowMessages, SDK_VERSION -from .constants import PROTOCOL +from .constants import (PROTOCOL, HttpHeader, ApiKey, ContentType as ContentTypeConstants, + EncodingType, BooleanString, ResponseField, CredentialField) from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params @@ -44,7 +45,7 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg try: env_creds = env_skyflow_credentials.replace('\n', '\\n') return { - 'credentials_string': env_creds + CredentialField.CREDENTIALS_STRING: env_creds } except json.JSONDecodeError: raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) @@ -52,7 +53,7 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False api_key_pattern = re.compile(r'^sky-[a-zA-Z0-9]{5}-[a-fA-F0-9]{32}$') @@ -113,13 +114,13 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - if not 'Content-Type'.lower() in header: - header['content-type'] = ContentType.JSON.value + if not HttpHeader.CONTENT_TYPE.lower() in header: + header[HttpHeader.CONTENT_TYPE_LOWERCASE] = ContentType.JSON.value try: if isinstance(request.body, dict): json_data, files = get_data_from_content_type( - request.body, header["content-type"] + request.body, header[HttpHeader.CONTENT_TYPE_LOWERCASE] ) else: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) @@ -216,30 +217,30 @@ def parse_insert_response(api_response, continue_on_error): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) inserted_fields = [] errors = [] insert_response = InsertResponse() if continue_on_error: for idx, response in enumerate(api_response_data.responses): - if response['Status'] == 200: - body = response['Body'] - if 'records' in body: - for record in body['records']: + if response[ResponseField.STATUS] == 200: + body = response[ResponseField.BODY] + if ResponseField.RECORDS in body: + for record in body[ResponseField.RECORDS]: inserted_field = { - 'skyflow_id': record['skyflow_id'], - 'request_index': idx + ResponseField.SKYFLOW_ID: record[ResponseField.SKYFLOW_ID], + ResponseField.REQUEST_INDEX: idx } - if 'tokens' in record: - inserted_field.update(record['tokens']) + if ResponseField.TOKENS in record: + inserted_field.update(record[ResponseField.TOKENS]) inserted_fields.append(inserted_field) - elif response['Status'] == 400: + elif response[ResponseField.STATUS] == 400: error = { - 'request_index': idx, - 'request_id': request_id, - 'error': response['Body']['error'], - 'http_code': response['Status'], + ResponseField.REQUEST_INDEX: idx, + ResponseField.REQUEST_ID: request_id, + ResponseField.ERROR: response[ResponseField.BODY][ResponseField.ERROR], + ResponseField.HTTP_CODE: response[ResponseField.STATUS], } errors.append(error) @@ -248,7 +249,7 @@ def parse_insert_response(api_response, continue_on_error): else: for record in api_response_data.records: field_data = { - 'skyflow_id': record.skyflow_id + ResponseField.SKYFLOW_ID: record.skyflow_id } if record.tokens: @@ -263,7 +264,7 @@ def parse_insert_response(api_response, continue_on_error): def parse_update_record_response(api_response: V1UpdateRecordResponse): update_response = UpdateResponse() updated_field = dict() - updated_field['skyflow_id'] = api_response.skyflow_id + updated_field[ResponseField.SKYFLOW_ID] = api_response.skyflow_id if api_response.tokens is not None: updated_field.update(api_response.tokens) @@ -293,23 +294,23 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) detokenized_fields = [] errors = [] for record in api_response_data.records: if record.error: errors.append({ - "token": record.token, - "error": record.error, - "request_id": request_id + ResponseField.TOKEN: record.token, + ResponseField.ERROR: record.error, + ResponseField.REQUEST_ID: request_id }) else: value_type = record.value_type if record.value_type else None detokenized_fields.append({ - "token": record.token, - "value": record.value, - "type": value_type + ResponseField.TOKEN: record.token, + ResponseField.VALUE: record.value, + ResponseField.TYPE: value_type }) detokenized_fields = detokenized_fields @@ -322,7 +323,7 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): def parse_tokenize_response(api_response: V1TokenizeResponse): tokenize_response = TokenizeResponse() - tokenized_fields = [{"token": record.token} for record in api_response.records] + tokenized_fields = [{ResponseField.TOKEN: record.token} for record in api_response.records] tokenize_response.tokenized_fields = tokenized_fields @@ -334,7 +335,7 @@ def parse_query_response(api_response: V1GetQueryResponse): for record in api_response.records: field_object = { **record.fields, - "tokenized_data": {} + ResponseField.TOKENIZED_DATA: {} } fields.append(field_object) query_response.fields = fields @@ -344,14 +345,14 @@ def parse_invoke_connection_response(api_response: requests.Response): status_code = api_response.status_code content = api_response.content if isinstance(content, bytes): - content = content.decode('utf-8') + content = content.decode(EncodingType.UTF_8) try: api_response.raise_for_status() try: data = json.loads(content) metadata = {} - if 'x-request-id' in api_response.headers: - metadata['request_id'] = api_response.headers['x-request-id'] + if HttpHeader.X_REQUEST_ID in api_response.headers: + metadata['request_id'] = api_response.headers[HttpHeader.X_REQUEST_ID] return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) except Exception as e: @@ -360,19 +361,19 @@ def parse_invoke_connection_response(api_response: requests.Response): message = SkyflowMessages.Error.API_ERROR.value.format(status_code) try: error_response = json.loads(content) - request_id = api_response.headers['x-request-id'] - error_from_client = api_response.headers.get('error-from-client') - - status_code = error_response.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = error_response.get('error', {}).get('http_status') - grpc_code = error_response.get('error', {}).get('grpc_code') - details = error_response.get('error', {}).get('details') - message = error_response.get('error', {}).get('message', "An unknown error occurred.") + request_id = api_response.headers[HttpHeader.X_REQUEST_ID] + error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) + + status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) + grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) + details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS) + message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) if error_from_client is not None: if details is None: details = [] - error_from_client_bool = error_from_client.lower() == 'true' - details.append({'error_from_client': error_from_client_bool}) + error_from_client_bool = error_from_client.lower() == BooleanString.TRUE + details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) except json.JSONDecodeError: @@ -399,14 +400,14 @@ def handle_exception(error, logger): if (isinstance(error, httpx.ConnectError)): handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) - request_id = error.headers.get('x-request-id', 'unknown-request-id') - content_type = error.headers.get('content-type') + request_id = error.headers.get(HttpHeader.X_REQUEST_ID, 'unknown-request-id') + content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) data = error.body if content_type: - if 'application/json' in content_type: + if ContentTypeConstants.APPLICATION_JSON in content_type: handle_json_error(error, data, request_id, logger) - elif 'text/plain' in content_type: + elif ContentTypeConstants.TEXT_PLAIN in content_type: handle_text_error(error, data, request_id, logger) else: handle_generic_error(error, request_id, logger) @@ -421,15 +422,15 @@ def handle_json_error(err, data, request_id, logger): description = data.dict() else: description = json.loads(data) - status_code = description.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = description.get('error', {}).get('http_status') - grpc_code = description.get('error', {}).get('grpc_code') - details = description.get('error', {}).get('details', []) + status_code = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + http_status = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) + grpc_code = description.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) + details = description.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS, []) - description_message = description.get('error', {}).get('message', "An unknown error occurred.") + description_message = description.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger = logger) except json.JSONDecodeError: - log_and_reject_error("Invalid JSON response received.", err, request_id, logger = logger) + log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger = logger) def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index ef20faf8..30cb124d 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -2,3 +2,166 @@ PROTOCOL='https' SKY_META_DATA_HEADER='sky-metadata' +class SKYFLOW: + SKYFLOW_ID = 'skyflowId' + X_SKYFLOW_AUTHORIZATION = 'x-skyflow-authorization' + + +class HttpHeader: + CONTENT_TYPE = 'Content-Type' + CONTENT_TYPE_LOWERCASE = 'content-type' + X_REQUEST_ID = 'x-request-id' + ERROR_FROM_CLIENT = 'error-from-client' + AUTHORIZATION = 'Authorization' + + +class HttpStatusCode: + OK = 200 + BAD_REQUEST = 400 + INTERNAL_SERVER_ERROR = 500 + + +class ContentType: + APPLICATION_JSON = 'application/json' + APPLICATION_X_WWW_FORM_URLENCODED = 'application/x-www-form-urlencoded' + TEXT_PLAIN = 'text/plain' + + +class DetectStatus: + IN_PROGRESS = 'IN_PROGRESS' + SUCCESS = 'SUCCESS' + FAILED = 'FAILED' + UNKNOWN = 'UNKNOWN' + + +class FileExtension: + JSON = 'json' + MP3 = 'mp3' + WAV = 'wav' + PDF = 'pdf' + TXT = 'txt' + DOC = 'doc' + DOCX = 'docx' + JPG = 'jpg' + JPEG = 'jpeg' + PNG = 'png' + BMP = 'bmp' + TIF = 'tif' + TIFF = 'tiff' + PPT = 'ppt' + PPTX = 'pptx' + CSV = 'csv' + XLS = 'xls' + XLSX = 'xlsx' + XML = 'xml' + + +class FileProcessing: + PROCESSED_PREFIX = 'processed-' + DEIDENTIFIED_PREFIX = 'deidentified.' + ENTITIES = 'entities' + + +class EncodingType: + UTF8 = 'utf8' + UTF_8 = 'utf-8' + BASE64 = 'base64' + BINARY = 'binary' + + +class JWT: + ALGORITHM_RS256 = 'RS256' + GRANT_TYPE_JWT_BEARER = 'urn:ietf:params:oauth:grant-type:jwt-bearer' + ISSUER_SDK = 'sdk' + SIGNED_TOKEN_PREFIX = 'signed_token_' + ROLE_PREFIX = 'role:' + + +class ApiKey: + SKY_PREFIX = 'sky-' + LENGTH = 42 + + +class UrlProtocol: + HTTPS = 'https' + HTTP = 'http' + + +class BooleanString: + TRUE = 'true' + FALSE = 'false' + + +class ResponseField: + STATUS = 'Status' + BODY = 'Body' + RECORDS = 'records' + TOKENS = 'tokens' + ERROR = 'error' + SKYFLOW_ID = 'skyflow_id' + REQUEST_INDEX = 'request_index' + REQUEST_ID = 'request_id' + HTTP_CODE = 'http_code' + HTTP_STATUS = 'http_status' + GRPC_CODE = 'grpc_code' + DETAILS = 'details' + MESSAGE = 'message' + ERROR_FROM_CLIENT = 'error_from_client' + TOKEN = 'token' + VALUE = 'value' + TYPE = 'type' + TOKENIZED_DATA = 'tokenized_data' + SIGNED_TOKEN = 'signed_token' + + +class CredentialField: + PRIVATE_KEY = 'privateKey' + CLIENT_ID = 'clientID' + KEY_ID = 'keyID' + TOKEN_URI = 'tokenURI' + CREDENTIALS_STRING = 'credentials_string' + API_KEY = 'api_key' + TOKEN = 'token' + PATH = 'path' + + +class JwtField: + ISS = 'iss' + KEY = 'key' + AUD = 'aud' + SUB = 'sub' + EXP = 'exp' + CTX = 'ctx' + TOK = 'tok' + IAT = 'iat' + + +class OptionField: + ROLE_IDS = 'role_ids' + DATA_TOKENS = 'data_tokens' + TIME_TO_LIVE = 'time_to_live' + ROLES = 'roles' + CTX = 'ctx' + VAULT_ID = 'vault_id' + CONNECTION_ID = 'connection_id' + CONNECTION_URL = 'connection_url' + VAULT_CLIENT = 'vault_client' + VAULT_CONTROLLER = 'vault_controller' + DETECT_CONTROLLER = 'detect_controller' + CONTROLLER = 'controller' + VERIFY_SIGNATURE = 'verify_signature' + VERIFY_AUD = 'verify_aud' + + +class ConfigField: + CREDENTIALS = 'credentials' + CLUSTER_ID = 'cluster_id' + ENV = 'env' + VAULT_ID = 'vault_id' + + +class RequestParameter: + VALUE = 'value' + COLUMN_GROUP = 'column_group' + REDACTION = 'redaction' + diff --git a/skyflow/utils/logger/_log_helpers.py b/skyflow/utils/logger/_log_helpers.py index fdb11ea9..3fff980b 100644 --- a/skyflow/utils/logger/_log_helpers.py +++ b/skyflow/utils/logger/_log_helpers.py @@ -1,5 +1,6 @@ from ..enums import LogLevel from . import Logger +from ..constants import ResponseField def log_info(message, logger = None): @@ -18,17 +19,17 @@ def log_error(message, http_code, request_id=None, grpc_code=None, http_status=N logger = Logger(LogLevel.ERROR) log_data = { - 'http_code': http_code, - 'message': message + ResponseField.HTTP_CODE: http_code, + ResponseField.MESSAGE: message } if grpc_code is not None: - log_data['grpc_code'] = grpc_code + log_data[ResponseField.GRPC_CODE] = grpc_code if http_status is not None: - log_data['http_status'] = http_status + log_data[ResponseField.HTTP_STATUS] = http_status if request_id is not None: - log_data['request_id'] = request_id + log_data[ResponseField.REQUEST_ID] = request_id if details is not None: - log_data['details'] = details + log_data[ResponseField.DETAILS] = details logger.error(log_data) \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..779fdfcc 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -6,6 +6,7 @@ MaskingMethod from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages +from skyflow.utils.constants import ApiKey, ResponseField from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest @@ -50,11 +51,11 @@ def validate_required_field(logger, config, field_name, expected_type, empty_err raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if not api_key.startswith('sky-'): + if not api_key.startswith(ApiKey.SKY_PREFIX): log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger=logger) return False - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False @@ -582,10 +583,10 @@ def validate_get_request(logger, request): def validate_update_request(logger, request): skyflow_id = "" - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} try: - skyflow_id = request.data.get("skyflow_id") + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) except Exception: log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format("UPDATE"), logger=logger) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..2d77330e 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -2,6 +2,7 @@ from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages from skyflow.utils.logger import log_info +from skyflow.utils.constants import OptionField, CredentialField, ConfigField class VaultClient: @@ -23,11 +24,11 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger) + credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger = self.__logger) token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env"), - self.__config.get("vault_id"), + vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), logger = self.__logger) self.initialize_api_client(vault_url, token) @@ -50,29 +51,29 @@ def get_detect_file_api(self): return self.__api_client.files def get_vault_id(self): - return self.__config.get("vault_id") + return self.__config.get(ConfigField.VAULT_ID) def get_bearer_token(self, credentials): - if 'api_key' in credentials: - return credentials.get('api_key') - elif 'token' in credentials: - return credentials.get("token") + if CredentialField.API_KEY in credentials: + return credentials.get(CredentialField.API_KEY) + elif CredentialField.TOKEN in credentials: + return credentials.get(CredentialField.TOKEN) options = { - "role_ids": self.__config.get("roles"), - "ctx": self.__config.get("ctx") + OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), + OptionField.CTX: self.__config.get(OptionField.CTX) } if self.__bearer_token is None or self.__is_config_updated: - if 'path' in credentials: - path = credentials.get("path") + if CredentialField.PATH in credentials: + path = credentials.get(CredentialField.PATH) self.__bearer_token, _ = generate_bearer_token( path, options, self.__logger ) else: - credentials_string = credentials.get('credentials_string') + credentials_string = credentials.get(CredentialField.CREDENTIALS_STRING) log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, self.__logger) self.__bearer_token, _ = generate_bearer_token_from_creds( credentials_string, diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 81c6ea10..83b0ffbd 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,6 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW class Connection: @@ -23,9 +24,9 @@ def invoke(self, request: InvokeConnectionRequest): log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: - invoke_connection_request.headers['x-skyflow-authorization'] = bearer_token + invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token - invoke_connection_request.headers['sky-metadata'] = json.dumps(get_metrics()) + invoke_connection_request.headers[SKY_META_DATA_HEADER] = json.dumps(get_metrics()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 44ef2540..4f2f50f2 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,7 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -64,7 +65,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): while True: response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data status = response.status - if status == 'IN_PROGRESS': + if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS') else: @@ -76,7 +77,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): wait_time = next_wait_time current_wait_time = next_wait_time time.sleep(wait_time) - elif status == 'SUCCESS' or status == 'FAILED': + elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: raise e @@ -88,7 +89,7 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o if not os.path.exists(output_directory): return - deidentify_file_prefix = "processed-" + deidentify_file_prefix = FileProcessing.PROCESSED_PREFIX output_list = response.output base_original_filename = os.path.basename(original_file_name) @@ -159,7 +160,7 @@ def output_to_dict_list(output): output_list = output_to_dict_list(output) first_output = output_list[0] if output_list else {} - entities = [o for o in output_list if o.get("type") == "entities"] + entities = [o for o in output_list if o.get("type") == FileProcessing.ENTITIES] base64_string = first_output.get("file", None) extension = first_output.get("extension", None) @@ -167,14 +168,14 @@ def output_to_dict_list(output): if base64_string is not None: file_bytes = base64.b64decode(base64_string) file_obj = io.BytesIO(file_bytes) - file_obj.name = f"deidentified.{extension}" if extension else "processed_file" + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else "processed_file" else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get("type", "UNKNOWN"), + type=first_output.get("type", DetectStatus.UNKNOWN), extension=extension, word_count=word_count, char_count=char_count, @@ -282,11 +283,11 @@ def deidentify_file(self, request: DeidentifyFileRequest): file_name = getattr(file_obj, 'name', None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode('utf-8') + base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) try: - if file_extension == 'txt': - req_file = FileDataDeidentifyText(base_64=base64_string, data_format="txt") + if file_extension == FileExtension.TXT: + req_file = FileDataDeidentifyText(base_64=base64_string, data_format=FileExtension.TXT) api_call = files_api.deidentify_text api_kwargs = { 'vault_id': self.__vault_client.get_vault_id(), @@ -299,7 +300,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['mp3', 'wav']: + elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio api_kwargs = { @@ -319,7 +320,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension == 'pdf': + elif file_extension == FileExtension.PDF: req_file = FileDataDeidentifyPdf(base_64=base64_string) api_call = files_api.deidentify_pdf api_kwargs = { @@ -334,7 +335,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']: + elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: req_file = FileDataDeidentifyImage(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_image api_kwargs = { @@ -350,7 +351,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['ppt', 'pptx']: + elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: req_file = FileDataDeidentifyPresentation(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_presentation api_kwargs = { @@ -363,7 +364,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['csv', 'xls', 'xlsx']: + elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: req_file = FileDataDeidentifySpreadsheet(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_spreadsheet api_kwargs = { @@ -376,7 +377,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['doc', 'docx']: + elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: req_file = FileDataDeidentifyDocument(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_document api_kwargs = { @@ -389,7 +390,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['json', 'xml']: + elif file_extension in [FileExtension.JSON, FileExtension.XML]: req_file = FileDataDeidentifyStructuredText(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_structured_text api_kwargs = { @@ -423,7 +424,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): run_id = getattr(api_response.data, 'run_id', None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == 'SUCCESS': + if request.output_directory and processed_response.status == DetectStatus.SUCCESS: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -450,7 +451,7 @@ def get_detect_run(self, request: GetDetectRunRequest): vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers() ) - if response.data.status == 'IN_PROGRESS': + if response.data.status == DetectStatus.IN_PROGRESS: parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 7cc9ec77..a5cd94fd 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -8,7 +8,7 @@ from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter from skyflow.utils.enums import RequestMethod from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log @@ -125,7 +125,7 @@ def update(self, request: UpdateRequest): validate_update_request(self.__vault_client.get_logger(), request) log_info(SkyflowMessages.Info.UPDATE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} record = V1FieldRecords(fields=field, tokens = request.tokens) records_api = self.__vault_client.get_records_api() @@ -134,7 +134,7 @@ def update(self, request: UpdateRequest): api_response = records_api.record_service_update_record( self.__vault_client.get_vault_id(), request.table, - id=request.data.get("skyflow_id"), + id=request.data.get(ResponseField.SKYFLOW_ID), record=record, tokenization=request.return_tokens, byot=request.token_mode.value, @@ -225,8 +225,8 @@ def detokenize(self, request: DetokenizeRequest): self.__initialize() tokens_list = [ V1DetokenizeRecordRequest( - token=item.get('token'), - redaction=item.get('redaction', RedactionType.DEFAULT) + token=item.get(ResponseField.TOKEN), + redaction=item.get(RequestParameter.REDACTION, RedactionType.DEFAULT) ) for item in request.data ] @@ -253,7 +253,7 @@ def tokenize(self, request: TokenizeRequest): self.__initialize() records_list = [ - V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"]) + V1TokenizeRecordRequest(value=item[RequestParameter.VALUE], column_group=item[RequestParameter.COLUMN_GROUP]) for item in request.values ] tokens_api = self.__vault_client.get_tokens_api() From d17e71d2fd34134221232d9ad55506dd6b011e86 Mon Sep 17 00:00:00 2001 From: skyflow-himanshu Date: Mon, 2 Feb 2026 17:44:50 +0530 Subject: [PATCH 04/23] SK-2496: addressed review comments and suggestions --- ruff.toml | 2 +- skyflow/utils/_skyflow_messages.py | 1 + skyflow/utils/_utils.py | 23 +- skyflow/utils/constants.py | 115 +++++++++ skyflow/utils/validations/_validations.py | 288 ++++++++++++---------- skyflow/vault/controller/_connections.py | 4 +- skyflow/vault/controller/_detect.py | 274 ++++++++++---------- skyflow/vault/controller/_vault.py | 4 +- 8 files changed, 423 insertions(+), 288 deletions(-) diff --git a/ruff.toml b/ruff.toml index b6795704..8b0d5278 100644 --- a/ruff.toml +++ b/ruff.toml @@ -14,6 +14,6 @@ exclude = [ line-length = 120 [lint] -select = ["N"] +select = ["N", "PLR2004"] [lint.pep8-naming] diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 1954ed4d..21665972 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -389,6 +389,7 @@ class ErrorLogs(Enum): SAVING_DEIDENTIFY_FILE_FAILED = f"{ERROR}: [{error_prefix}] Error while saving deidentified file to output directory." REIDENTIFY_TEXT_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Reidentify text resulted in failure." DETECT_FILE_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Deidentify file resulted in failure." + EMPTY_FILE_COLUMN_NAME = f"{ERROR}: [{error_prefix}] Empty column name in FILE_UPLOAD" class Interface(Enum): INSERT = "INSERT" diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index c6f294cd..83c93b0c 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -21,7 +21,8 @@ from skyflow.vault.detect import EntityInfo, TextIndex from . import SkyflowMessages, SDK_VERSION from .constants import (PROTOCOL, HttpHeader, ApiKey, ContentType as ContentTypeConstants, - EncodingType, BooleanString, ResponseField, CredentialField) + EncodingType, BooleanString, ResponseField, CredentialField, SdkPrefix, + SdkMetricsKey, ErrorDefaults, HttpStatusCode) from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params @@ -129,7 +130,7 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep validate_invoke_connection_params(logger, request.query_params, request.path_params) - if not hasattr(request.method, 'value'): + if not hasattr(request.method, ResponseField.VALUE): raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_METHOD.value, invalid_input_error_code) try: @@ -187,7 +188,7 @@ def get_data_from_content_type(data, content_type): def get_metrics(): - sdk_name_version = "skyflow-python@" + SDK_VERSION + sdk_name_version = SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION try: sdk_client_device_model = platform.node() @@ -205,10 +206,10 @@ def get_metrics(): sdk_runtime_details = "" details_dic = { - 'sdk_name_version': sdk_name_version, - 'sdk_client_device_model': sdk_client_device_model, - 'sdk_client_os_details': sdk_client_os_details, - 'sdk_runtime_details': "Python " + sdk_runtime_details, + SdkMetricsKey.SDK_NAME_VERSION: sdk_name_version, + SdkMetricsKey.SDK_CLIENT_DEVICE_MODEL: sdk_client_device_model, + SdkMetricsKey.SDK_CLIENT_OS_DETAILS: sdk_client_os_details, + SdkMetricsKey.SDK_RUNTIME_DETAILS: SdkPrefix.PYTHON_RUNTIME + sdk_runtime_details, } return details_dic @@ -223,7 +224,7 @@ def parse_insert_response(api_response, continue_on_error): insert_response = InsertResponse() if continue_on_error: for idx, response in enumerate(api_response_data.responses): - if response[ResponseField.STATUS] == 200: + if response[ResponseField.STATUS] == HttpStatusCode.OK: body = response[ResponseField.BODY] if ResponseField.RECORDS in body: for record in body[ResponseField.RECORDS]: @@ -235,7 +236,7 @@ def parse_insert_response(api_response, continue_on_error): if ResponseField.TOKENS in record: inserted_field.update(record[ResponseField.TOKENS]) inserted_fields.append(inserted_field) - elif response[ResponseField.STATUS] == 400: + elif response[ResponseField.STATUS] == HttpStatusCode.BAD_REQUEST: error = { ResponseField.REQUEST_INDEX: idx, ResponseField.REQUEST_ID: request_id, @@ -352,7 +353,7 @@ def parse_invoke_connection_response(api_response: requests.Response): data = json.loads(content) metadata = {} if HttpHeader.X_REQUEST_ID in api_response.headers: - metadata['request_id'] = api_response.headers[HttpHeader.X_REQUEST_ID] + metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID] return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) except Exception as e: @@ -400,7 +401,7 @@ def handle_exception(error, logger): if (isinstance(error, httpx.ConnectError)): handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) - request_id = error.headers.get(HttpHeader.X_REQUEST_ID, 'unknown-request-id') + request_id = error.headers.get(HttpHeader.X_REQUEST_ID, ErrorDefaults.UNKNOWN_REQUEST_ID) content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) data = error.body diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 30cb124d..62aa4d11 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -13,11 +13,13 @@ class HttpHeader: X_REQUEST_ID = 'x-request-id' ERROR_FROM_CLIENT = 'error-from-client' AUTHORIZATION = 'Authorization' + X_SKYFLOW_AUTHORIZATION_HEADER = 'X-Skyflow-Authorization' class HttpStatusCode: OK = 200 BAD_REQUEST = 400 + UNAUTHORIZED = 401 INTERNAL_SERVER_ERROR = 500 @@ -123,6 +125,8 @@ class CredentialField: API_KEY = 'api_key' TOKEN = 'token' PATH = 'path' + CONTEXT = 'context' + ROLES = 'roles' class JwtField: @@ -165,3 +169,114 @@ class RequestParameter: COLUMN_GROUP = 'column_group' REDACTION = 'redaction' + +class FileUploadField: + TABLE = 'table' + SKYFLOW_ID = 'skyflow_id' + COLUMN_NAME = 'column_name' + FILE_PATH = 'file_path' + BASE64 = 'base64' + FILE_OBJECT = 'file_object' + FILE_NAME = 'file_name' + FILE = 'file' + NAME = 'name' + + +class DeidentifyFileRequestField: + ENTITIES = 'entities' + ALLOW_REGEX_LIST = 'allow_regex_list' + RESTRICT_REGEX_LIST = 'restrict_regex_list' + OUTPUT_PROCESSED_IMAGE = 'output_processed_image' + OUTPUT_OCR_TEXT = 'output_ocr_text' + MASKING_METHOD = 'masking_method' + PIXEL_DENSITY = 'pixel_density' + MAX_RESOLUTION = 'max_resolution' + OUTPUT_PROCESSED_AUDIO = 'output_processed_audio' + OUTPUT_TRANSCRIPTION = 'output_transcription' + BLEEP = 'bleep' + OUTPUT_DIRECTORY = 'output_directory' + WAIT_TIME = 'wait_time' + + +class DeidentifyField: + TEXT = 'text' + ENTITY_TYPES = 'entity_types' + TOKEN_TYPE = 'token_type' + ALLOW_REGEX = 'allow_regex' + RESTRICT_REGEX = 'restrict_regex' + TRANSFORMATIONS = 'transformations' + FORMAT = 'format' + OUTPUT = 'output' + STATUS = 'status' + RUN_ID = 'run_id' + WORD_CHARACTER_COUNT = 'word_character_count' + WORD_COUNT = 'word_count' + CHARACTER_COUNT = 'character_count' + SIZE = 'size' + DURATION = 'duration' + PAGES = 'pages' + SLIDES = 'slides' + PROCESSED_FILE = 'processed_file' + PROCESSED_FILE_TYPE = 'processed_file_type' + PROCESSED_FILE_EXTENSION = 'processed_file_extension' + REDACTED_FILE = 'redacted_file' + SHIFT_DATES = 'shift_dates' + DEFAULT = 'default' + ENTITY_UNQ_COUNTER = 'entity_unq_counter' + ENTITY_UNIQUE_COUNTER = 'entity_unique_counter' + ENTITY_ONLY = 'entity_only' + ENTITIES = 'entities' + MAX_DAYS = 'max_days' + MIN_DAYS = 'min_days' + MAX = 'max' + MIN = 'min' + FILE = 'file' + TYPE = 'type' + EXTENSION = 'extension' + IN_PROGRESS = 'IN_PROGRESS' + REQUEST_OPTIONS = 'request_options' + BLEEP_GAIN = 'bleep_gain' + BLEEP_FREQUENCY = 'bleep_frequency' + BLEEP_START_PADDING = 'bleep_start_padding' + BLEEP_STOP_PADDING = 'bleep_stop_padding' + DENSITY = 'density' + TOKEN_FORMAT = 'token_format' + PROCESSED_FILE_RESPONSE_KEY = 'processedFile' + PROCESSED_FILE_TYPE_RESPONSE_KEY = 'processedFileType' + PROCESSED_FILE_EXTENSION_RESPONSE_KEY = 'processedFileExtension' + + +class RequestOperation: + INSERT = 'INSERT' + DELETE = 'DELETE' + GET = 'GET' + UPDATE = 'UPDATE' + QUERY = 'QUERY' + TOKENIZE = 'TOKENIZE' + DETOKENIZE = 'DETOKENIZE' + FILE_UPLOAD = 'FILE_UPLOAD' + + +class ConfigType: + VAULT = 'vault' + CONNECTION = 'connection' + + +class SqlCommand: + SELECT = 'SELECT' + + +class SdkPrefix: + SKYFLOW_PYTHON = 'skyflow-python@' + PYTHON_RUNTIME = 'Python ' + + +class SdkMetricsKey: + SDK_NAME_VERSION = 'sdk_name_version' + SDK_CLIENT_DEVICE_MODEL = 'sdk_client_device_model' + SDK_CLIENT_OS_DETAILS = 'sdk_client_os_details' + SDK_RUNTIME_DETAILS = 'sdk_runtime_details' + + +class ErrorDefaults: + UNKNOWN_REQUEST_ID = 'unknown-request-id' diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 779fdfcc..2ac5783c 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -6,47 +6,66 @@ MaskingMethod from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages -from skyflow.utils.constants import ApiKey, ResponseField +from skyflow.utils.constants import ( + ApiKey, ResponseField, RequestParameter, + FileUploadField, + DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField +) from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest from skyflow.vault.detect._file_input import FileInput -valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] -valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] -valid_credentials_keys = ["path", "roles", "context", "token", "credentials_string"] +valid_vault_config_keys = [ + ConfigField.VAULT_ID, + ConfigField.CLUSTER_ID, + ConfigField.CREDENTIALS, + ConfigField.ENV +] +valid_connection_config_keys = [ + OptionField.CONNECTION_ID, + OptionField.CONNECTION_URL, + ConfigField.CREDENTIALS +] +valid_credentials_keys = [ + CredentialField.PATH, + CredentialField.ROLES, + CredentialField.CONTEXT, + CredentialField.TOKEN, + CredentialField.CREDENTIALS_STRING +] invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def validate_required_field(logger, config, field_name, expected_type, empty_error, invalid_error): field_value = config.get(field_name) if field_name not in config or not isinstance(field_value, expected_type): - if field_name == "vault_id": + if field_name == ConfigField.VAULT_ID: logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) - if field_name == "cluster_id": + if field_name == ConfigField.CLUSTER_ID: logger.error(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value) - if field_name == "connection_id": + if field_name == OptionField.CONNECTION_ID: logger.error(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value) - if field_name == "connection_url": + if field_name == OptionField.CONNECTION_URL: logger.error(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value) raise SkyflowError(invalid_error, invalid_input_error_code) if isinstance(field_value, str) and not field_value.strip(): - if field_name == "vault_id": + if field_name == ConfigField.VAULT_ID: logger.error(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value) - if field_name == "cluster_id": + if field_name == ConfigField.CLUSTER_ID: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value) - if field_name == "connection_id": + if field_name == OptionField.CONNECTION_ID: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value) - if field_name == "connection_url": + if field_name == OptionField.CONNECTION_URL: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value) - if field_name == "path": + if field_name == CredentialField.PATH: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value) - if field_name == "credentials_string": + if field_name == CredentialField.CREDENTIALS_STRING: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value) - if field_name == "token": + if field_name == CredentialField.TOKEN: logger.error(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value) - if field_name == "api_key": + if field_name == CredentialField.API_KEY: logger.error(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value) raise SkyflowError(empty_error, invalid_input_error_code) @@ -62,7 +81,7 @@ def validate_api_key(api_key: str, logger = None) -> bool: return True def validate_credentials(logger, credentials, config_id_type=None, config_id=None): - key_present = [k for k in ["path", "token", "credentials_string", "api_key"] if credentials.get(k)] + key_present = [k for k in [CredentialField.PATH, CredentialField.TOKEN, CredentialField.CREDENTIALS_STRING, CredentialField.API_KEY] if credentials.get(k)] if len(key_present) == 0: error_message = ( @@ -79,63 +98,63 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non ) raise SkyflowError(error_message, invalid_input_error_code) - if "roles" in credentials: + if CredentialField.ROLES in credentials: validate_required_field( - logger, credentials, "roles", list, + logger, credentials, CredentialField.ROLES, list, SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE.value, SkyflowMessages.Error.EMPTY_ROLES_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_ROLES.value ) - if "context" in credentials: + if CredentialField.CONTEXT in credentials: validate_required_field( - logger, credentials, "context", str, + logger, credentials, CredentialField.CONTEXT, str, SkyflowMessages.Error.EMPTY_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CONTEXT.value, SkyflowMessages.Error.INVALID_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CONTEXT.value ) - if "credentials_string" in credentials: + if CredentialField.CREDENTIALS_STRING in credentials: validate_required_field( - logger, credentials, "credentials_string", str, + logger, credentials, CredentialField.CREDENTIALS_STRING, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING.value, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value ) - elif "path" in credentials: + elif CredentialField.PATH in credentials: validate_required_field( - logger, credentials, "path", str, + logger, credentials, CredentialField.PATH, str, SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH.value, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value ) - elif "token" in credentials: + elif CredentialField.TOKEN in credentials: validate_required_field( - logger, credentials, "token", str, + logger, credentials, CredentialField.TOKEN, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) - if is_expired(credentials.get("token"), logger): + if is_expired(credentials.get(CredentialField.TOKEN), logger): raise SkyflowError( SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, invalid_input_error_code ) - elif "api_key" in credentials: + elif CredentialField.API_KEY in credentials: validate_required_field( - logger, credentials, "api_key", str, + logger, credentials, CredentialField.API_KEY, str, SkyflowMessages.Error.EMPTY_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_API_KEY.value, SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value ) - if not validate_api_key(credentials.get("api_key"), logger): + if not validate_api_key(credentials.get(CredentialField.API_KEY), logger): raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) @@ -158,27 +177,27 @@ def validate_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, "vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) # Validate cluster_id (string, not empty) validate_required_field( - logger, config, "cluster_id", str, + logger, config, ConfigField.CLUSTER_ID, str, SkyflowMessages.Error.EMPTY_CLUSTER_ID.value.format(vault_id), SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id) ) # Validate credentials (dict, not empty) - if "credentials" in config and not config.get("credentials"): - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS in config and not config.get(ConfigField.CREDENTIALS): + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - if "credentials" in config and config.get("credentials"): - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + if ConfigField.CREDENTIALS in config and config.get(ConfigField.CREDENTIALS): + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) # Validate env (optional, should be one of LogLevel values) - if "env" in config and config.get("env") not in Env: + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) @@ -190,23 +209,23 @@ def validate_update_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, "vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) - if "cluster_id" in config and not config.get("cluster_id"): + if ConfigField.CLUSTER_ID in config and not config.get(ConfigField.CLUSTER_ID): raise SkyflowError(SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id), invalid_input_error_code) - if "env" in config and config.get("env") not in Env: + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) return True @@ -215,23 +234,23 @@ def validate_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id" , str, + logger, config, OptionField.CONNECTION_ID , str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) return True @@ -240,22 +259,22 @@ def validate_update_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id", str, + logger, config, OptionField.CONNECTION_ID, str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials")) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS)) return True @@ -263,8 +282,8 @@ def validate_file_from_request(file_input: FileInput): if file_input is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - has_file = hasattr(file_input, 'file') and file_input.file is not None - has_file_path = hasattr(file_input, 'file_path') and file_input.file_path is not None + has_file = hasattr(file_input, FileUploadField.FILE) and file_input.file is not None + has_file_path = hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None # Must provide exactly one of file or file_path if (has_file and has_file_path) or (not has_file and not has_file_path): @@ -273,7 +292,7 @@ def validate_file_from_request(file_input: FileInput): if has_file: file = file_input.file # Validate file object has required attributes - if not hasattr(file, 'name') or not isinstance(file.name, str) or not file.name.strip(): + if not hasattr(file, FileUploadField.FILE_NAME) or not isinstance(file.name, str) or not file.name.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) # Validate file name @@ -290,14 +309,14 @@ def validate_file_from_request(file_input: FileInput): raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): - if not hasattr(request, 'file') or request.file is None: + if not hasattr(request, FileUploadField.FILE) or request.file is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) # Validate file input first validate_file_from_request(request.file) # Optional: entities - if hasattr(request, 'entities') and request.entities is not None: + if hasattr(request, DeidentifyFileRequestField.ENTITIES) and request.entities is not None: if not isinstance(request.entities, list): raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) @@ -305,12 +324,12 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) # Optional: allow_regex_list - if hasattr(request, 'allow_regex_list') and request.allow_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.ALLOW_REGEX_LIST) and request.allow_regex_list is not None: if not isinstance(request.allow_regex_list, list) or not all(isinstance(x, str) for x in request.allow_regex_list): raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Optional: restrict_regex_list - if hasattr(request, 'restrict_regex_list') and request.restrict_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.RESTRICT_REGEX_LIST) and request.restrict_regex_list is not None: if not isinstance(request.restrict_regex_list, list) or not all(isinstance(x, str) for x in request.restrict_regex_list): raise SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) @@ -323,43 +342,42 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) # Optional: output_processed_image - if hasattr(request, 'output_processed_image') and request.output_processed_image is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE) and request.output_processed_image is not None: if not isinstance(request.output_processed_image, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, invalid_input_error_code) # Optional: output_ocr_text - if hasattr(request, 'output_ocr_text') and request.output_ocr_text is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT) and request.output_ocr_text is not None: if not isinstance(request.output_ocr_text, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, invalid_input_error_code) # Optional: masking_method - # Optional: masking_method - if hasattr(request, 'masking_method') and request.masking_method is not None: + if hasattr(request, DeidentifyFileRequestField.MASKING_METHOD) and request.masking_method is not None: if not isinstance(request.masking_method, MaskingMethod): raise SkyflowError(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, invalid_input_error_code) # Optional: pixel_density - if hasattr(request, 'pixel_density') and request.pixel_density is not None: + if hasattr(request, DeidentifyFileRequestField.PIXEL_DENSITY) and request.pixel_density is not None: if not isinstance(request.pixel_density, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, invalid_input_error_code) # Optional: max_resolution - if hasattr(request, 'max_resolution') and request.max_resolution is not None: + if hasattr(request, DeidentifyFileRequestField.MAX_RESOLUTION) and request.max_resolution is not None: if not isinstance(request.max_resolution, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, invalid_input_error_code) # Optional: output_processed_audio - if hasattr(request, 'output_processed_audio') and request.output_processed_audio is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO) and request.output_processed_audio is not None: if not isinstance(request.output_processed_audio, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, invalid_input_error_code) # Optional: output_transcription - if hasattr(request, 'output_transcription') and request.output_transcription is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION) and request.output_transcription is not None: if not isinstance(request.output_transcription, DetectOutputTranscriptions): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, invalid_input_error_code) # Optional: bleep - if hasattr(request, 'bleep') and request.bleep is not None: + if hasattr(request, DeidentifyFileRequestField.BLEEP) and request.bleep is not None: if not isinstance(request.bleep, Bleep): raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, invalid_input_error_code) @@ -380,53 +398,53 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, invalid_input_error_code) # Optional: output_directory - if hasattr(request, 'output_directory') and request.output_directory is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_DIRECTORY) and request.output_directory is not None: if not isinstance(request.output_directory, str): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, invalid_input_error_code) if not os.path.isdir(request.output_directory): raise SkyflowError(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), invalid_input_error_code) # Optional: wait_time - if hasattr(request, 'wait_time') and request.wait_time is not None: + if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 and request.wait_time > 64: + if request.wait_time < 0 and request.wait_time > 64: # noqa: PLR2004 raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not isinstance(request.values, list) or not all(isinstance(v, dict) for v in request.values): - log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if not len(request.values): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format("INSERT"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code) for i, item in enumerate(request.values, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format(RequestOperation.INSERT, key), logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) if request.homogeneous is not None and not isinstance(request.homogeneous, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_HOMOGENEOUS_TYPE.value, invalid_input_error_code) if request.upsert and request.homogeneous: - log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), logger = logger) - raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), logger = logger) + raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), invalid_input_error_code) if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): @@ -442,15 +460,15 @@ def validate_insert_request(logger, request): for i, item in enumerate(request.tokens, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format("INSERT"), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format(RequestOperation.INSERT), logger=logger) if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format("INSERT", key), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format(RequestOperation.INSERT, key), logger=logger) if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -461,29 +479,29 @@ def validate_insert_request(logger, request): if request.token_mode == TokenMode.ENABLE_STRICT: if len(request.values) != len(request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) for v, t in zip(request.values, request.tokens): if set(v.keys()) != set(t.keys()): - log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format("INSERT"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) def validate_delete_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not request.ids: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code) def validate_query_request(logger, request): if not request.query: - log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format("QUERY"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) if not isinstance(request.query, str): @@ -491,10 +509,10 @@ def validate_query_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code) if not request.query.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format("QUERY"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not request.query.upper().startswith("SELECT"): + if not request.query.upper().startswith(SqlCommand.SELECT): command = request.query raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) @@ -509,23 +527,23 @@ def validate_get_request(logger, request): download_url = request.download_url if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not skyflow_ids and not column_name and not column_values: - log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) if skyflow_ids and (not isinstance(skyflow_ids, list) or not skyflow_ids): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_IDS_TYPE.value.format(type(skyflow_ids)), invalid_input_error_code) if skyflow_ids: for index, skyflow_id in enumerate(skyflow_ids): if skyflow_id is None or skyflow_id == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format("GET", index), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format(RequestOperation.GET, index), logger=logger) if not isinstance(request.return_tokens, bool): @@ -535,7 +553,7 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction_type)), invalid_input_error_code) if fields is not None and (not isinstance(fields, list) or not fields): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(fields)), invalid_input_error_code) if offset is not None and limit is not None: @@ -561,24 +579,24 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if request.return_tokens and redaction_type: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value, invalid_input_error_code) if (column_name or column_values) and request.return_tokens: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format("GET"), + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.TOKENS_GET_COLUMN_NOT_SUPPORTED.value, invalid_input_error_code) if column_values and not column_name: - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if column_name and not column_values: - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format(RequestOperation.GET), logger = logger) SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) if (column_name or column_values) and skyflow_ids: - log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code) def validate_update_request(logger, request): @@ -588,16 +606,16 @@ def validate_update_request(logger, request): try: skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) except Exception: - log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) if not skyflow_id.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format("UPDATE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger = logger) if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("UPDATE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not isinstance(request.return_tokens, bool): @@ -615,7 +633,7 @@ def validate_update_request(logger, request): if request.tokens: if not isinstance(request.tokens, dict) or not request.tokens: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -628,14 +646,14 @@ def validate_update_request(logger, request): if request.token_mode == TokenMode.ENABLE_STRICT: if len(field) != len(request.tokens): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) if set(field.keys()) != set(request.tokens.keys()): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError( SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, @@ -649,20 +667,20 @@ def validate_detokenize_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code) if not len(request.data): - log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format("DETOKENIZE"), logger = logger) - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("DETOKENIZE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format(RequestOperation.DETOKENIZE), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.DETOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code) for item in request.data: - if 'token' not in item: + if ResponseField.TOKEN not in item: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - token = item.get('token') - redaction = item.get('redaction', None) + token = item.get(ResponseField.TOKEN) + redaction = item.get(RequestParameter.REDACTION, None) if not isinstance(token, str) or not token: - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE"), + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format(RequestOperation.DETOKENIZE), invalid_input_error_code) if redaction is not None and not isinstance(redaction, RedactionType): @@ -681,16 +699,16 @@ def validate_tokenize_request(logger, request): if not isinstance(param, dict): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER.value.format(i, type(param)), invalid_input_error_code) - allowed_keys = {"value", "column_group"} + allowed_keys = {RequestParameter.VALUE, RequestParameter.COLUMN_GROUP} if set(param.keys()) != allowed_keys: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER_KEY.value.format(i), invalid_input_error_code) - if not param.get("value"): - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.VALUE): + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_VALUE.value.format(i), invalid_input_error_code) - if not param.get("column_group"): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.COLUMN_GROUP): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_COLUMN_GROUP.value.format(i), invalid_input_error_code) @@ -699,32 +717,32 @@ def validate_file_upload_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) # Table - table = getattr(request, "table", None) + table = getattr(request, FileUploadField.TABLE, None) if table is None: raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) elif table.strip() == "": raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) # Skyflow ID - skyflow_id = getattr(request, "skyflow_id", None) + skyflow_id = getattr(request, FileUploadField.SKYFLOW_ID, None) if skyflow_id is None: raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) elif skyflow_id.strip() == "": - raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD"), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format(RequestOperation.FILE_UPLOAD), invalid_input_error_code) # Column Name - column_name = getattr(request, "column_name", None) + column_name = getattr(request, FileUploadField.COLUMN_NAME, None) if column_name is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) elif column_name.strip() == "": - logger.error("Empty column name in FILE_UPLOAD") + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FILE_COLUMN_NAME.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) # File-related attributes - file_path = getattr(request, "file_path", None) - base64_str = getattr(request, "base64", None) - file_object = getattr(request, "file_object", None) - file_name = getattr(request, "file_name", None) + file_path = getattr(request, FileUploadField.FILE_PATH, None) + base64_str = getattr(request, FileUploadField.BASE64, None) + file_object = getattr(request, FileUploadField.FILE_OBJECT, None) + file_name = getattr(request, FileUploadField.FILE_NAME, None) # Check file_path first if present if not is_none_or_empty(file_path): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 83b0ffbd..ca8c7a1d 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,7 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest -from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader class Connection: @@ -23,7 +23,7 @@ def invoke(self, request: InvokeConnectionRequest): invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) - if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: + if not HttpHeader.X_SKYFLOW_AUTHORIZATION_HEADER.lower() in invoke_connection_request.headers: invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token invoke_connection_request.headers[SKY_META_DATA_HEADER] = json.dumps(get_metrics()) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 4f2f50f2..c6ef2fb1 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -9,7 +9,7 @@ from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, - FileProcessing, EncodingType) + FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -34,12 +34,12 @@ def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[ deidentify_text_body = {} parsed_entity_types = request.entities - deidentify_text_body['text'] = request.text - deidentify_text_body['entity_types'] = parsed_entity_types - deidentify_text_body['token_type'] = self.__get_token_format(request) - deidentify_text_body['allow_regex'] = request.allow_regex_list - deidentify_text_body['restrict_regex'] = request.restrict_regex_list - deidentify_text_body['transformations'] = self.__get_transformations(request) + deidentify_text_body[DeidentifyField.TEXT] = request.text + deidentify_text_body[DeidentifyField.ENTITY_TYPES] = parsed_entity_types + deidentify_text_body[DeidentifyField.TOKEN_TYPE] = self.__get_token_format(request) + deidentify_text_body[DeidentifyField.ALLOW_REGEX] = request.allow_regex_list + deidentify_text_body[DeidentifyField.RESTRICT_REGEX] = request.restrict_regex_list + deidentify_text_body[DeidentifyField.TRANSFORMATIONS] = self.__get_transformations(request) return deidentify_text_body @@ -50,8 +50,8 @@ def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[ plaintext=request.plain_text_entities ) reidentify_text_body = {} - reidentify_text_body['text'] = request.text - reidentify_text_body['format'] = parsed_format + reidentify_text_body[DeidentifyField.TEXT] = request.text + reidentify_text_body[DeidentifyField.FORMAT] = parsed_format return reidentify_text_body def _get_file_extension(self, filename: str): @@ -67,7 +67,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): status = response.status if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: - return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS') + return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: next_wait_time = current_wait_time * 2 if next_wait_time >= max_wait_time: @@ -83,7 +83,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): raise e def __save_deidentify_file_response_output(self, response: DetectRunsResponse, output_directory: str, original_file_name: str, name_without_ext: str): - if not response or not hasattr(response, 'output') or not response.output or not output_directory: + if not response or not hasattr(response, DeidentifyField.OUTPUT) or not response.output or not output_directory: return if not os.path.exists(output_directory): @@ -97,16 +97,16 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o for idx, output in enumerate(output_list): try: - processed_file = get_attribute(output, 'processedFile', 'processed_file') - processed_file_type = get_attribute(output, 'processedFileType', 'processed_file_type') - processed_file_extension = get_attribute(output, 'processedFileExtension', 'processed_file_extension') + processed_file = get_attribute(output, DeidentifyField.PROCESSED_FILE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE) + processed_file_type = get_attribute(output, DeidentifyField.PROCESSED_FILE_TYPE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_TYPE) + processed_file_extension = get_attribute(output, DeidentifyField.PROCESSED_FILE_EXTENSION_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_EXTENSION) if not processed_file: continue decoded_data = base64.b64decode(processed_file) - if idx == 0 or processed_file_type == 'redacted_file': + if idx == 0 or processed_file_type == DeidentifyField.REDACTED_FILE: output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) if processed_file_extension: output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") @@ -120,62 +120,62 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o handle_exception(e, self.__vault_client.get_logger()) def __parse_deidentify_file_response(self, data, run_id=None, status=None): - output = getattr(data, "output", []) - status_val = getattr(data, "status", None) or status - run_id_val = getattr(data, "run_id", None) or run_id + output = getattr(data, DeidentifyField.OUTPUT, []) + status_val = getattr(data, DeidentifyField.STATUS, None) or status + run_id_val = getattr(data, DeidentifyField.RUN_ID, None) or run_id word_count = None char_count = None - word_character_count = getattr(data, "word_character_count", None) + word_character_count = getattr(data, DeidentifyField.WORD_CHARACTER_COUNT, None) if word_character_count and isinstance(word_character_count, WordCharacterCount): - word_count = word_character_count.word_count - char_count = word_character_count.character_count + word_count = getattr(word_character_count, DeidentifyField.WORD_COUNT, None) + char_count = getattr(word_character_count, DeidentifyField.CHARACTER_COUNT, None) - size = getattr(data, "size", None) + size = getattr(data, DeidentifyField.SIZE, None) size = float(size) if size is not None else None - duration = getattr(data, "duration", None) - pages = getattr(data, "pages", None) - slides = getattr(data, "slides", None) + duration = getattr(data, DeidentifyField.DURATION, None) + pages = getattr(data, DeidentifyField.PAGES, None) + slides = getattr(data, DeidentifyField.SLIDES, None) def output_to_dict_list(output): result = [] for o in output: if isinstance(o, dict): result.append({ - "file": o.get("processed_file"), - "type": o.get("processed_file_type"), - "extension": o.get("processed_file_extension") + DeidentifyField.FILE: o.get(DeidentifyField.PROCESSED_FILE), + DeidentifyField.TYPE: o.get(DeidentifyField.PROCESSED_FILE_TYPE), + DeidentifyField.EXTENSION: o.get(DeidentifyField.PROCESSED_FILE_EXTENSION) }) else: result.append({ - "file": getattr(o, "processed_file", None), - "type": getattr(o, "processed_file_type", None), - "extension": getattr(o, "processed_file_extension", None) + DeidentifyField.FILE: getattr(o, DeidentifyField.PROCESSED_FILE, None), + DeidentifyField.TYPE: getattr(o, DeidentifyField.PROCESSED_FILE_TYPE, None), + DeidentifyField.EXTENSION: getattr(o, DeidentifyField.PROCESSED_FILE_EXTENSION, None) }) return result output_list = output_to_dict_list(output) first_output = output_list[0] if output_list else {} - entities = [o for o in output_list if o.get("type") == FileProcessing.ENTITIES] + entities = [o for o in output_list if o.get(DeidentifyField.TYPE) == FileProcessing.ENTITIES] - base64_string = first_output.get("file", None) - extension = first_output.get("extension", None) + base64_string = first_output.get(DeidentifyField.FILE, None) + extension = first_output.get(DeidentifyField.EXTENSION, None) if base64_string is not None: file_bytes = base64.b64decode(base64_string) file_obj = io.BytesIO(file_bytes) - file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else "processed_file" + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get("type", DetectStatus.UNKNOWN), + type=first_output.get(DeidentifyField.TYPE, DetectStatus.UNKNOWN), extension=extension, word_count=word_count, char_count=char_count, @@ -189,25 +189,25 @@ def output_to_dict_list(output): ) def __get_token_format(self, request): - if not hasattr(request, "token_format") or request.token_format is None: + if not hasattr(request, DeidentifyField.TOKEN_FORMAT) or request.token_format is None: return None return { - 'default': getattr(request.token_format, "default", None), - 'entity_unq_counter': getattr(request.token_format, "entity_unique_counter", None), - 'entity_only': getattr(request.token_format, "entity_only", None), + DeidentifyField.DEFAULT: getattr(request.token_format, DeidentifyField.DEFAULT, None), + DeidentifyField.ENTITY_UNQ_COUNTER: getattr(request.token_format, DeidentifyField.ENTITY_UNIQUE_COUNTER, None), + DeidentifyField.ENTITY_ONLY: getattr(request.token_format, DeidentifyField.ENTITY_ONLY, None), } def __get_transformations(self, request): - if not hasattr(request, "transformations") or request.transformations is None: + if not hasattr(request, DeidentifyField.TRANSFORMATIONS) or request.transformations is None: return None - shift_dates = getattr(request.transformations, "shift_dates", None) + shift_dates = getattr(request.transformations, DeidentifyField.SHIFT_DATES, None) if shift_dates is None: return None return { - 'shift_dates': { - 'max_days': getattr(shift_dates, "max", None), - 'min_days': getattr(shift_dates, "min", None), - 'entity_types': getattr(shift_dates, "entities", None) + DeidentifyField.SHIFT_DATES: { + DeidentifyField.MAX_DAYS: getattr(shift_dates, DeidentifyField.MAX, None), + DeidentifyField.MIN_DAYS: getattr(shift_dates, DeidentifyField.MIN, None), + DeidentifyField.ENTITY_TYPES: getattr(shift_dates, DeidentifyField.ENTITIES, None) } } @@ -223,12 +223,12 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.deidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=deidentify_text_body['text'], - entity_types=deidentify_text_body['entity_types'], - allow_regex=deidentify_text_body['allow_regex'], - restrict_regex=deidentify_text_body['restrict_regex'], - token_type=deidentify_text_body['token_type'], - transformations=deidentify_text_body['transformations'], + text=deidentify_text_body[DeidentifyField.TEXT], + entity_types=deidentify_text_body[DeidentifyField.ENTITY_TYPES], + allow_regex=deidentify_text_body[DeidentifyField.ALLOW_REGEX], + restrict_regex=deidentify_text_body[DeidentifyField.RESTRICT_REGEX], + token_type=deidentify_text_body[DeidentifyField.TOKEN_TYPE], + transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], request_options=self.__get_headers() ) deidentify_text_response = parse_deidentify_text_response(api_response) @@ -251,8 +251,8 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.reidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=reidentify_text_body['text'], - format=reidentify_text_body['format'], + text=reidentify_text_body[DeidentifyField.TEXT], + format=reidentify_text_body[DeidentifyField.FORMAT], request_options=self.__get_headers() ) reidentify_text_response = parse_reidentify_text_response(api_response) @@ -267,11 +267,11 @@ def __get_file_from_request(self, request: DeidentifyFileRequest): file_input = request.file # Check for file - if hasattr(file_input, 'file') and file_input.file is not None: + if hasattr(file_input, FileUploadField.FILE) and file_input.file is not None: return file_input.file # Check for file_path if file is not provided - if hasattr(file_input, 'file_path') and file_input.file_path is not None: + if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: return open(file_input.file_path, 'rb') def deidentify_file(self, request: DeidentifyFileRequest): @@ -280,7 +280,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): self.__initialize() files_api = self.__vault_client.get_detect_file_api().with_raw_response file_obj = self.__get_file_from_request(request) - file_name = getattr(file_obj, 'name', None) + file_name = getattr(file_obj, FileUploadField.NAME, None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) @@ -290,138 +290,138 @@ def deidentify_file(self, request: DeidentifyFileRequest): req_file = FileDataDeidentifyText(base_64=base64_string, data_format=FileExtension.TXT) api_call = files_api.deidentify_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'output_transcription': getattr(request, 'output_transcription', None), - 'output_processed_audio': getattr(request, 'output_processed_audio', None), - 'bleep_gain': getattr(request, 'bleep', None).gain if getattr(request, 'bleep', None) is not None else None, - 'bleep_frequency': getattr(request, 'bleep', None).frequency if getattr(request, 'bleep', None) is not None else None, - 'bleep_start_padding': getattr(request, 'bleep', None).start_padding if getattr(request, 'bleep', None) is not None else None, - 'bleep_stop_padding': getattr(request, 'bleep', None).stop_padding if getattr(request, 'bleep', None) is not None else None, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION: getattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO, None), + DeidentifyField.BLEEP_GAIN: getattr(request, DeidentifyFileRequestField.BLEEP, None).gain if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.BLEEP_FREQUENCY: getattr(request, DeidentifyFileRequestField.BLEEP, None).frequency if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.BLEEP_START_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).start_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.BLEEP_STOP_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).stop_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension == FileExtension.PDF: req_file = FileDataDeidentifyPdf(base_64=base64_string) api_call = files_api.deidentify_pdf api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'max_resolution': getattr(request, 'max_resolution', None), - 'density': getattr(request, 'pixel_density', None), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MAX_RESOLUTION: getattr(request, DeidentifyFileRequestField.MAX_RESOLUTION, None), + DeidentifyFileRequestField.PIXEL_DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: req_file = FileDataDeidentifyImage(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_image api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'masking_method': getattr(request, 'masking_method', None), - 'output_ocr_text': getattr(request, 'output_ocr_text', None), - 'output_processed_image': getattr(request, 'output_processed_image', None), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MASKING_METHOD: getattr(request, DeidentifyFileRequestField.MASKING_METHOD, None), + DeidentifyFileRequestField.OUTPUT_OCR_TEXT: getattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE, None), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: req_file = FileDataDeidentifyPresentation(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_presentation api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: req_file = FileDataDeidentifySpreadsheet(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_spreadsheet api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: req_file = FileDataDeidentifyDocument(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_document api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.JSON, FileExtension.XML]: req_file = FileDataDeidentifyStructuredText(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_structured_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } else: req_file = FileData(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_file api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) api_response = api_call(**api_kwargs) - run_id = getattr(api_response.data, 'run_id', None) + run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) if request.output_directory and processed_response.status == DetectStatus.SUCCESS: @@ -452,7 +452,7 @@ def get_detect_run(self, request: GetDetectRunRequest): request_options=self.__get_headers() ) if response.data.status == DetectStatus.IN_PROGRESS: - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) + parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index a5cd94fd..856a1961 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -8,7 +8,7 @@ from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics -from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter +from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter, FileUploadField from skyflow.utils.enums import RequestMethod from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log @@ -82,7 +82,7 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return (request.file_name, decoded_bytes) elif request.file_object is not None: - if hasattr(request.file_object, "name") and request.file_object.name: + if hasattr(request.file_object, FileUploadField.NAME) and request.file_object.name: file_name = os.path.basename(request.file_object.name) return (file_name, request.file_object) From 5eb3da98f28b72045345c1f465bc736a9f7fc347 Mon Sep 17 00:00:00 2001 From: skyflow-himanshu Date: Mon, 2 Feb 2026 17:50:14 +0530 Subject: [PATCH 05/23] SK-2496: added samples to ignore for linting --- ruff.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ruff.toml b/ruff.toml index 8b0d5278..aea6cce7 100644 --- a/ruff.toml +++ b/ruff.toml @@ -8,7 +8,8 @@ exclude = [ "venv", "build", "dist", - "tests" + "tests", + "samples" ] line-length = 120 From 449b191b148a8398f4a73e2b227f26b3ee4f9006 Mon Sep 17 00:00:00 2001 From: skyflow-himanshu Date: Tue, 3 Feb 2026 10:52:09 +0000 Subject: [PATCH 06/23] [AUTOMATED] Private Release 2.0.0.dev0+fbfaaa3 --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 09f844d2..c811e4b3 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0' +current_version = '2.0.0.dev0+fbfaaa3' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 0d05fc30..531b7f67 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0' \ No newline at end of file +SDK_VERSION = '2.0.0.dev0+fbfaaa3' \ No newline at end of file From 17399ea96e0a695dda450de631e0104e8976b655 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Thu, 5 Feb 2026 19:28:23 +0530 Subject: [PATCH 07/23] SK-2522: Fix Python SDK v2 issues reported in Bug Bash (#231) * SK-2522: fix identified bugs --- skyflow/utils/_skyflow_messages.py | 14 +- skyflow/utils/_utils.py | 154 +++- skyflow/utils/constants.py | 2 + skyflow/utils/enums/content_types.py | 3 +- skyflow/utils/validations/_validations.py | 21 +- skyflow/vault/client/client.py | 3 +- skyflow/vault/controller/_connections.py | 13 +- tests/utils/test__utils.py | 830 ++++++++++++++++++- tests/utils/validations/test__validations.py | 65 +- tests/vault/client/test__client.py | 86 +- tests/vault/controller/test__connection.py | 184 +++- 11 files changed, 1288 insertions(+), 87 deletions(-) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 21665972..6a31c078 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -42,12 +42,13 @@ class Error(Enum): EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Specify a valid file path." EMPTY_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Specify a valid file path." INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Expected file path to be a string." - INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a string." + INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a valid file path." EMPTY_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid token for {{}} with id {{}}.Specify a valid credentials token." EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token." INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." - EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." + EXPIRED_BEARER_TOKEN = f"{error_prefix} Initialization failed. Bearer token is invalid or expired." + EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string." @@ -118,10 +119,11 @@ class Error(Enum): INVALID_IDS_TYPE = f"{error_prefix} Validation error. 'ids' has a value of type {{}}. Specify 'ids' as list." INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction' has a value of type {{}}. Specify 'redaction' as type Skyflow.RedactionType." - INVALID_COLUMN_NAME = f"{error_prefix} Validation error. 'column' has a value of type {{}}. Specify 'column' as a string." - INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. columnValues key has a value of type {{}}. Specify columnValues key as list." + INVALID_COLUMN_NAME = f"{error_prefix} Validation error. column_name has a value of type {{}}. Specify 'column' as a string." + INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. column_values key has a value of type {{}}. Specify column_values key as list." + INVALID_COLUMN_VALUES = f"{error_prefix} Validation error. column_values key is an empty list. Specify at least one column value when column_name is passed." INVALID_FIELDS_VALUE = f"{error_prefix} Validation error. fields key has a value of type{{}}. Specify fields key as list." - BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"${error_prefix} Validation error. Both offset and limit cannot be present at the same time" + BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"{error_prefix} Validation error. Both offset and limit cannot be present at the same time" INVALID_OFF_SET_VALUE = f"{error_prefix} Validation error. offset key has a value of type {{}}. Specify offset key as integer." INVALID_LIMIT_VALUE = f"{error_prefix} Validation error. limit key has a value of type {{}}. Specify limit key as integer." INVALID_DOWNLOAD_URL_VALUE = f"{error_prefix} Validation error. download_url key has a value of type {{}}. Specify download_url key as boolean." @@ -366,7 +368,7 @@ class ErrorLogs(Enum): SKYFLOW_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id is required." EMPTY_SKYFLOW_ID = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id can not be empty." - COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. ColumnValues are required." + COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. column_values are required." EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Column group can not be null or empty in column values at index %s2." EMPTY_QUERY= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Query can not be empty." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 83c93b0c..567227f7 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -106,27 +106,41 @@ def convert_detected_entity_to_entity_info(detected_entity): def construct_invoke_connection_request(request, connection_url, logger) -> PreparedRequest: url = parse_path_params(connection_url.rstrip('/'), request.path_params) - try: - if isinstance(request.headers, dict): - header = to_lowercase_keys(json.loads( - json.dumps(request.headers))) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + header = None + content_type = None - if not HttpHeader.CONTENT_TYPE.lower() in header: - header[HttpHeader.CONTENT_TYPE_LOWERCASE] = ContentType.JSON.value + if request.headers is not None: + try: + if isinstance(request.headers, dict): + header = to_lowercase_keys(json.loads( + json.dumps(request.headers))) + + content_type = header.get(HttpHeader.CONTENT_TYPE_LOWERCASE) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + except SkyflowError: + raise + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - try: - if isinstance(request.body, dict): - json_data, files = get_data_from_content_type( - request.body, header[HttpHeader.CONTENT_TYPE_LOWERCASE] - ) - else: + json_data = None + files = {} + + if request.body is not None: + try: + if isinstance(request.body, dict): + json_data, files = get_data_from_content_type( + request.body, content_type + ) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + except SkyflowError: + raise + except Exception as e: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) - except Exception as e: - raise SkyflowError( SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + + if files and header and content_type == ContentType.FORMDATA.value: + header.pop(HttpHeader.CONTENT_TYPE_LOWERCASE, None) validate_invoke_connection_params(logger, request.query_params, request.path_params) @@ -176,16 +190,54 @@ def render_key(parents): def get_data_from_content_type(data, content_type): converted_data = data files = {} + if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) elif content_type == ContentType.FORMDATA.value: - converted_data = r_urlencode(list(), dict(), data) - files = {(None, None)} + print("Hello") + converted_data = None + files = {} + for key, value in data.items(): + files[key] = (None, str(value)) elif content_type == ContentType.JSON.value: converted_data = json.dumps(data) + elif content_type == ContentType.XML.value or content_type == 'application/xml' or content_type == 'text/xml': + if isinstance(data, dict): + converted_data = dict_to_xml(data) + else: + converted_data = str(data) + elif content_type == ContentType.HTML.value or content_type == 'text/html': + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) + else: + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) return converted_data, files +def dict_to_xml(data, root_tag='root'): + def build_xml(d, tag='item'): + if isinstance(d, dict): + xml_parts = [f'<{tag}>'] + for key, value in d.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + elif isinstance(d, list): + return ''.join([build_xml(item, tag) for item in d]) + else: + return f'<{tag}>{d}' + + xml_parts = [f'<{root_tag}>'] + for key, value in data.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + def get_metrics(): sdk_name_version = SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION @@ -347,39 +399,50 @@ def parse_invoke_connection_response(api_response: requests.Response): content = api_response.content if isinstance(content, bytes): content = content.decode(EncodingType.UTF_8) + try: api_response.raise_for_status() - try: - data = json.loads(content) - metadata = {} - if HttpHeader.X_REQUEST_ID in api_response.headers: - metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID] + + content_type = api_response.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE, '').lower() + + if ContentTypeConstants.APPLICATION_JSON in content_type or not content_type: + try: + data = json.loads(content) + except json.JSONDecodeError: + data = content + else: + data = content + + metadata = {} + if HttpHeader.X_REQUEST_ID in api_response.headers: + metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID] - return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) - except Exception as e: - raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code) + return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) + except HTTPError: message = SkyflowMessages.Error.API_ERROR.value.format(status_code) + request_id = api_response.headers.get(HttpHeader.X_REQUEST_ID) + try: - error_response = json.loads(content) - request_id = api_response.headers[HttpHeader.X_REQUEST_ID] + error_response = json.loads(content) error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) - status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, status_code) http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS) message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) if error_from_client is not None: - if details is None: details = [] + if details is None: + details = [] error_from_client_bool = error_from_client.lower() == BooleanString.TRUE details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) + except json.JSONDecodeError: - message = SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content) - raise SkyflowError(message, status_code) + raise SkyflowError(content if content else message, status_code, request_id) def parse_deidentify_text_response(api_response: DeidentifyStringResponse): entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities] @@ -397,9 +460,15 @@ def log_and_reject_error(description, status_code, request_id, http_status=None, raise SkyflowError(description, status_code, request_id, grpc_code, http_status, details) def handle_exception(error, logger): - # handle invalid cluster ID error scenario - if (isinstance(error, httpx.ConnectError)): - handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) + if isinstance(error, httpx.ConnectError): + description = SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=logger) + return + + if not hasattr(error, 'headers') or not hasattr(error, 'body') or error.headers is None or error.body is None: + description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=logger) + return request_id = error.headers.get(HttpHeader.X_REQUEST_ID, ErrorDefaults.UNKNOWN_REQUEST_ID) content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) @@ -411,9 +480,9 @@ def handle_exception(error, logger): elif ContentTypeConstants.TEXT_PLAIN in content_type: handle_text_error(error, data, request_id, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) def handle_json_error(err, data, request_id, logger): try: @@ -436,12 +505,9 @@ def handle_json_error(err, data, request_id, logger): def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) -def handle_generic_error(err, request_id, logger): - handle_generic_error(err, request_id, err.status, logger = logger) - -def handle_generic_error(err, request_id, status, logger): +def handle_generic_error_with_status(err, request_id, status, logger): description = SkyflowMessages.Error.GENERIC_API_ERROR.value - log_and_reject_error(description, status, request_id, logger = logger) + log_and_reject_error(description, status, request_id, logger=logger) def encode_column_values(get_request): encoded_column_values = list() diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 62aa4d11..401bffe5 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -35,6 +35,8 @@ class DetectStatus: FAILED = 'FAILED' UNKNOWN = 'UNKNOWN' +class Detect: + WAIT_TIME = 64 class FileExtension: JSON = 'json' diff --git a/skyflow/utils/enums/content_types.py b/skyflow/utils/enums/content_types.py index 362c286a..f2db5b92 100644 --- a/skyflow/utils/enums/content_types.py +++ b/skyflow/utils/enums/content_types.py @@ -5,4 +5,5 @@ class ContentType(Enum): PLAINTEXT = 'text/plain' XML = 'text/xml' URLENCODED = 'application/x-www-form-urlencoded' - FORMDATA = 'multipart/form-data' \ No newline at end of file + FORMDATA = 'multipart/form-data' + HTML = 'text/html' \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 2ac5783c..4e3ead8a 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -9,7 +9,7 @@ from skyflow.utils.constants import ( ApiKey, ResponseField, RequestParameter, FileUploadField, - DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField + DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField, Detect ) from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ @@ -142,8 +142,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non ) if is_expired(credentials.get(CredentialField.TOKEN), logger): raise SkyflowError( - SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) - if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, + SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value + if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value, invalid_input_error_code ) elif CredentialField.API_KEY in credentials: @@ -247,10 +247,8 @@ def validate_connection_config(logger, config): SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if ConfigField.CREDENTIALS not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) - - validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) + if "credentials" in config: + validate_credentials(logger, config.get("credentials"), "connection", connection_id) return True @@ -408,7 +406,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 and request.wait_time > 64: # noqa: PLR2004 + if request.wait_time < 0 or request.wait_time > Detect.WAIT_TIME: raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): @@ -432,9 +430,6 @@ def validate_insert_request(logger, request): if key is None or key == "": log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) - if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format(RequestOperation.INSERT, key), logger = logger) - if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) @@ -592,8 +587,8 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if column_name and not column_values: - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format(RequestOperation.GET), logger = logger) - SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format("GET"), logger = logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUES.value, invalid_input_error_code) if (column_name or column_values) and skyflow_ids: log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index 2d77330e..45234a40 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -1,3 +1,4 @@ +from skyflow.error import SkyflowError from skyflow.generated.rest.client import Skyflow from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages @@ -86,7 +87,7 @@ def get_bearer_token(self, credentials): if is_expired(self.__bearer_token): self.__is_config_updated = True - raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) return self.__bearer_token diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index ca8c7a1d..76dbfaeb 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -6,6 +6,7 @@ from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader +from skyflow.utils import get_credentials class Connection: @@ -13,15 +14,17 @@ def __init__(self, vault_client): self.__vault_client = vault_client def invoke(self, request: InvokeConnectionRequest): - session = requests.Session() - + log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) config = self.__vault_client.get_config() - bearer_token = self.__vault_client.get_bearer_token(config.get("credentials")) - connection_url = config.get("connection_url") - log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) + + credentials = get_credentials(config.get("credentials"), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) + + bearer_token = self.__vault_client.get_bearer_token(credentials) + + session = requests.Session() if not HttpHeader.X_SKYFLOW_AUTHORIZATION_HEADER.lower() in invoke_connection_request.headers: invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6eaacf47..09195b89 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -394,14 +394,18 @@ def test_parse_invoke_connection_response_successful(self, mock_response): @patch("requests.Response") def test_parse_invoke_connection_response_json_decode_error(self, mock_response): - + """Test that non-JSON content in successful response is returned as string.""" mock_response.status_code = 200 mock_response.content = "Non-JSON Content".encode('utf-8') + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() - with self.assertRaises(SkyflowError) as context: - parse_invoke_connection_response(mock_response) + result = parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Non-JSON Content")) + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Non-JSON Content") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) @patch("requests.Response") def test_parse_invoke_connection_response_http_error_with_json_error_message(self, mock_response): @@ -428,7 +432,9 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Internal Server Error")) + self.assertEqual(context.exception.message, "Internal Server Error") + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_json_error(self, mock_log_and_reject_error): @@ -597,3 +603,817 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.text_index.end, 0) self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) + + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_connect_error(self, mock_log_and_reject_error): + """Test handling httpx.ConnectError.""" + import httpx + mock_error = httpx.ConnectError("Connection refused") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value, + None, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_headers_attribute(self, mock_log_and_reject_error): + """Test handling error without headers attribute.""" + mock_error = Exception("Generic error") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Generic error", + SkyflowMessages.ErrorCodes.SERVER_ERROR.value, + None, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_body_attribute(self, mock_log_and_reject_error): + """Test handling error without body attribute.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "12345"} + delattr(mock_error, 'body') + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + self.assertEqual( + mock_log_and_reject_error.call_args[0][1], + SkyflowMessages.ErrorCodes.SERVER_ERROR.value + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_text_plain_error(self, mock_log_and_reject_error): + """Test handling text/plain content type error.""" + mock_error = Mock() + mock_error.headers = { + 'x-request-id': '1234', + 'content-type': 'text/plain' + } + mock_error.body = "Plain text error message" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Plain text error message", + 500, + "1234", + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_error): + """Test handling generic error with unknown content type.""" + mock_error = Mock() + mock_error.headers = { + 'x-request-id': '1234', + 'content-type': 'application/xml' + } + mock_error.body = "XML error" + mock_error.status = 503 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + 503, + "1234", + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_content_type(self, mock_log_and_reject_error): + """Test handling error without content-type header.""" + mock_error = Mock() + mock_error.headers = {'x-request-id': '1234'} + mock_error.body = "Some error" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + 500, + "1234", + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): + """Test handling JSON error when data is a JSON string.""" + error_json_string = json.dumps({ + "error": { + "message": "String JSON error", + "http_code": 422, + "http_status": "Unprocessable Entity", + "grpc_code": 3, + "details": ["validation failed"] + } + }) + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-3" + + handle_json_error(mock_error, error_json_string, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "String JSON error", + 422, + request_id, + "Unprocessable Entity", + 3, + ["validation failed"], + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_invalid_json(self, mock_log_and_reject_error): + """Test handling JSON decode error.""" + invalid_json = "This is not valid JSON" + mock_error = Mock() + mock_error.status = 500 + mock_logger = Mock() + request_id = "test-request-id-4" + + handle_json_error(mock_error, invalid_json, request_id, mock_logger) + + # Should call with INVALID_JSON_RESPONSE error + mock_log_and_reject_error.assert_called_once() + self.assertEqual( + mock_log_and_reject_error.call_args[0][0], + SkyflowMessages.Error.INVALID_JSON_RESPONSE.value + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_missing_error_field(self, mock_log_and_reject_error): + """Test handling JSON error with missing error field.""" + error_dict = { + "message": "Error without error wrapper" + } + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-5" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + # Should use defaults for missing fields + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + # Default message when error field is missing + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + # Default status code + self.assertEqual(args[1], 500) + self.assertEqual(args[2], request_id) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_text_error_with_status(self, mock_log_and_reject_error): + """Test handle_text_error extracts status correctly.""" + mock_error = Mock() + mock_error.status = 404 + mock_logger = Mock() + request_id = "test-request-id-6" + error_data = "Resource not found" + + from skyflow.utils._utils import handle_text_error + handle_text_error(mock_error, error_data, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Resource not found", + 404, + request_id, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_generic_error_with_status(self, mock_log_and_reject_error): + """Test handle_generic_error_with_status.""" + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-7" + status = 503 + + from skyflow.utils._utils import handle_generic_error_with_status + handle_generic_error_with_status(mock_error, request_id, status, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + 503, + request_id, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_none_error(self, mock_log_and_reject_error): + """Test handling None error object.""" + mock_logger = Mock() + + handle_exception(None, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + SkyflowMessages.ErrorCodes.SERVER_ERROR.value, + None, + logger=mock_logger + ) + + #failed + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_error): + """Test handling empty string error.""" + mock_logger = Mock() + mock_error = Mock() + mock_error.headers = None + mock_error.body = None + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + # Should use str(error) or default message + self.assertEqual( + mock_log_and_reject_error.call_args[0][1], + SkyflowMessages.ErrorCodes.SERVER_ERROR.value + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): + """Test handling JSON error when data is bytes.""" + error_dict = { + "error": { + "message": "Bytes error", + "http_code": 401, + "http_status": "Unauthorized" + } + } + error_bytes = json.dumps(error_dict).encode('utf-8') + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-8" + + handle_json_error(mock_error, error_bytes, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Bytes error", + 401, + request_id, + "Unauthorized", + None, + [], + logger=mock_logger + ) + + # Add these new test methods to the TestUtils class: + + def test_construct_invoke_connection_request_with_no_headers(self): + """Test construct_invoke_connection_request when headers are None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param1": "value1"} + mock_connection_request.headers = None + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {"query": "test"} + + connection_url = "https://example.com/{param1}/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Headers should be None when not provided + self.assertIsNone(result.headers.get('Content-Type')) + + def test_construct_invoke_connection_request_with_xml_content_type(self): + """Test construct_invoke_connection_request with XML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/xml"} + mock_connection_request.body = {"root": {"child": "value"}} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertEqual(result.headers['content-type'], 'application/xml') + # Body should be converted to XML + self.assertIn('', result.body) + self.assertIn('value', result.body) + + def test_construct_invoke_connection_request_with_html_content_type(self): + """Test construct_invoke_connection_request with HTML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "text/html"} + mock_connection_request.body = {"message": "Hello"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertEqual(result.headers['content-type'], 'text/html') + # Body should be JSON string for HTML + self.assertEqual(result.body, json.dumps({"message": "Hello"})) + + def test_construct_invoke_connection_request_multipart_removes_content_type(self): + """Test that Content-Type is removed for multipart/form-data.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} + mock_connection_request.body = {"field1": "value1", "field2": "value2"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Content-Type should be auto-generated by requests library + self.assertIn('multipart/form-data', result.headers.get('Content-Type', '')) + self.assertIn('boundary=', result.headers.get('Content-Type', '')) + + def test_construct_invoke_connection_request_with_no_body(self): + """Test construct_invoke_connection_request when body is None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertIsNone(result.body) + + def test_get_data_from_content_type_url_encoded(self): + """Test get_data_from_content_type with URL encoded content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key1": "value1", "key2": "value2"} + content_type = ContentType.URLENCODED.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, "key1=value1&key2=value2") + self.assertEqual(files, {}) + + def test_get_data_from_content_type_form_data(self): + """Test get_data_from_content_type with form data content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"field1": "value1", "field2": "value2"} + content_type = ContentType.FORMDATA.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIsNone(converted_data) + self.assertEqual(files["field1"], (None, "value1")) + self.assertEqual(files["field2"], (None, "value2")) + + def test_get_data_from_content_type_json(self): + """Test get_data_from_content_type with JSON content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = ContentType.JSON.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_xml_with_dict(self): + """Test get_data_from_content_type with XML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"root": {"child": "value"}} + content_type = "application/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIn("", converted_data) + self.assertIn("value", converted_data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_xml_with_string(self): + """Test get_data_from_content_type with XML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "value" + content_type = "text/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_dict(self): + """Test get_data_from_content_type with HTML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"message": "Hello"} + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_string(self): + """Test get_data_from_content_type with HTML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "Hello" + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_dict(self): + """Test get_data_from_content_type with unknown content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = "application/custom" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_string(self): + """Test get_data_from_content_type with unknown content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "plain text data" + content_type = "text/plain" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_dict_to_xml_simple_dict(self): + """Test dict_to_xml with simple dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"name": "John", "age": "30"} + result = dict_to_xml(data) + + self.assertIn("John", result) + self.assertIn("30", result) + self.assertTrue(result.startswith("")) + self.assertTrue(result.endswith("")) + + def test_dict_to_xml_nested_dict(self): + """Test dict_to_xml with nested dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"person": {"name": "John", "age": "30"}} + result = dict_to_xml(data) + + self.assertIn("", result) + self.assertIn("John", result) + self.assertIn("30", result) + + def test_dict_to_xml_with_list(self): + """Test dict_to_xml with list values.""" + from skyflow.utils._utils import dict_to_xml + + data = {"items": ["item1", "item2", "item3"]} + result = dict_to_xml(data) + + self.assertIn("item1", result) + self.assertIn("item2", result) + self.assertIn("item3", result) + + @patch("requests.Response") + def test_parse_invoke_connection_response_xml_content(self, mock_response): + """Test parsing XML response content.""" + mock_response.status_code = 200 + mock_response.content = b"success" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/xml" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_url_encoded_content(self, mock_response): + """Test parsing URL encoded response content.""" + mock_response.status_code = 200 + mock_response.content = b"card_number=4111111111111111&cvv=123" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/x-www-form-urlencoded" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "card_number=4111111111111111&cvv=123") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_content(self, mock_response): + """Test parsing HTML response content.""" + mock_response.status_code = 200 + mock_response.content = b"Success" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "text/html" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_error(self, mock_response): + """Test parsing HTML error response.""" + html_error = "

Error 500

" + mock_response.status_code = 500 + mock_response.content = html_error.encode('utf-8') + mock_response.headers = { + "x-request-id": "1234", + "content-type": "text/html" + } + mock_response.raise_for_status = Mock(side_effect=HTTPError("500 Error")) + + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + + self.assertEqual(context.exception.message, html_error) + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") + + @patch("requests.Response") + def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, mock_response): + """Test that JSON decode error falls back to returning string content.""" + mock_response.status_code = 200 + mock_response.content = b"Not valid JSON but still success" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/json" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Not valid JSON but still success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_no_content_type_with_json(self, mock_response): + """Test parsing response with no content-type but valid JSON.""" + mock_response.status_code = 200 + mock_response.content = json.dumps({"success": True}).encode('utf-8') + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, {"success": True}) + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_no_content_type_with_text(self, mock_response): + """Test parsing response with no content-type and non-JSON content.""" + mock_response.status_code = 200 + mock_response.content = b"Plain text response" + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Plain text response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_bytes_content(self, mock_response): + """Test parsing response with bytes content.""" + mock_response.status_code = 200 + mock_response.content = b"Binary data response" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/octet-stream" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Binary data response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + def test_construct_invoke_connection_request_headers_json_error(self): + """Test exception handling when json.dumps fails for headers.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + + class UnserializableObject: + def __repr__(self): + raise TypeError("Object is not JSON serializable") + + mock_connection_request.headers = {"key": UnserializableObject()} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('json.dumps', side_effect=TypeError("Object is not JSON serializable")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_headers_generic_exception(self): + """Test generic exception handling for headers processing.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/json"} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('skyflow.utils._utils.to_lowercase_keys', side_effect=Exception("Generic error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_processing_exception(self): + """Test exception handling when body processing fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('skyflow.utils._utils.get_data_from_content_type', side_effect=Exception("Body processing error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_json_dumps_exception(self): + """Test exception handling when json.dumps fails in get_data_from_content_type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + + class UnserializableObject: + pass + + mock_connection_request.body = {"key": UnserializableObject()} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_invalid_url_exception(self): + """Test exception handling when requests.Request.prepare() fails with invalid URL.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('requests.Request') as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Invalid URL structure") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_URL.value.format(connection_url) + ) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_prepare_exception(self): + """Test exception handling when prepare() method fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('requests.Request') as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Prepare failed") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_URL.value.format(connection_url) + ) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_not_dict_raises_error(self): + """Test that non-dict body raises SkyflowError which is caught and re-raised.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = "not a dict" # Invalid body type + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + @patch('skyflow.utils._utils.validate_invoke_connection_params') + def test_construct_invoke_connection_request_validation_exception(self, mock_validate): + """Test that validation exceptions are properly propagated.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param": "value"} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {"query": "value"} + + connection_url = "https://example.com/endpoint" + + mock_validate.side_effect = SkyflowError("Validation failed", 400) + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, "Validation failed") + self.assertEqual(context.exception.http_code, 400) diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 48332a55..4f3b5487 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -116,7 +116,7 @@ def test_validate_credentials_with_expired_token(self): with patch('skyflow.service_account.is_expired', return_value=True): with self.assertRaises(SkyflowError) as context: validate_credentials(self.logger, credentials) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value) def test_validate_credentials_empty_credentials(self): credentials = {} @@ -1044,3 +1044,66 @@ def test_validate_detokenize_request_invalid_redaction_type(self): with self.assertRaises(SkyflowError) as context: validate_detokenize_request(self.logger, request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) + + + def test_validate_deidentify_file_request_wait_time_negative(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=-1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_greater_than_64(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=65, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_lower(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=0, + entities=[DetectEntities.SSN] + ) + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_upper(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=64, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_float(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=32.5, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_float_out_of_range(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=64.1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 565b1e6f..619c15ec 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -1,5 +1,8 @@ import unittest from unittest.mock import patch, MagicMock + +from skyflow.error import SkyflowError +from skyflow.utils import SkyflowMessages from skyflow.vault.client.client import VaultClient CONFIG = { @@ -97,4 +100,85 @@ def test_get_log_level(self): def test_get_logger(self): mock_logger = MagicMock() self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) \ No newline at end of file + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token") + def test_get_bearer_token_expired_token_raises_error(self, mock_generate_bearer_token, mock_is_expired): + """Test that expired token raises SkyflowError.""" + credentials = {"path": "/path/to/credentials.json"} + mock_generate_bearer_token.return_value = ("expired_token", None) + mock_is_expired.return_value = True + + with self.assertRaises(SkyflowError) as context: + self.vault_client.get_bearer_token(credentials) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + self.assertTrue(self.vault_client._VaultClient__is_config_updated) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + def test_get_bearer_token_expired_token_from_creds_string_raises_error(self, mock_generate_bearer_token_from_creds, mock_is_expired): + """Test that expired token from credentials string raises SkyflowError.""" + credentials = {"credentials_string": '{"key": "value"}'} + mock_generate_bearer_token_from_creds.return_value = ("expired_token", None) + mock_is_expired.return_value = True + + with self.assertRaises(SkyflowError) as context: + self.vault_client.get_bearer_token(credentials) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + self.assertTrue(self.vault_client._VaultClient__is_config_updated) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token") + def test_get_bearer_token_reuses_valid_token(self, mock_generate_bearer_token, mock_is_expired): + """Test that valid bearer token is reused.""" + credentials = {"path": "/path/to/credentials.json"} + mock_generate_bearer_token.return_value = ("valid_token", None) + mock_is_expired.return_value = False + + token1 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token1, "valid_token") + mock_generate_bearer_token.assert_called_once() + + token2 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token2, "valid_token") + mock_generate_bearer_token.assert_called_once() + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token") + def test_get_bearer_token_regenerates_after_config_update(self, mock_generate_bearer_token, mock_is_expired): + """Test that bearer token is regenerated after config update.""" + credentials = {"path": "/path/to/credentials.json"} + mock_generate_bearer_token.side_effect = [("first_token", None), ("second_token", None)] + mock_is_expired.return_value = False + + token1 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token1, "first_token") + + self.vault_client.update_config({"new_key": "new_value"}) + + token2 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token2, "second_token") + self.assertEqual(mock_generate_bearer_token.call_count, 2) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_with_credentials_string(self, mock_log_info, mock_generate_bearer_token_from_creds, mock_is_expired): + """Test get_bearer_token with credentials_string.""" + credentials = {"credentials_string": '{"clientID": "test", "clientName": "test"}'} + mock_generate_bearer_token_from_creds.return_value = ("token_from_creds", None) + mock_is_expired.return_value = False + + token = self.vault_client.get_bearer_token(credentials) + + self.assertEqual(token, "token_from_creds") + mock_generate_bearer_token_from_creds.assert_called_once() + mock_log_info.assert_called_with( + SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, + None + ) diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 4ccad1c7..35a13716 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock import requests from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages, parse_invoke_connection_response @@ -30,10 +30,16 @@ def setUp(self): self.mock_vault_client = Mock() self.mock_vault_client.get_config.return_value = VAULT_CONFIG self.mock_vault_client.get_bearer_token.return_value = VALID_BEARER_TOKEN + self.mock_vault_client.get_logger.return_value = Mock() + self.mock_vault_client.get_common_skyflow_credentials.return_value = None self.connection = Connection(self.mock_vault_client) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_success(self, mock_send): + def test_invoke_success(self, mock_send, mock_get_credentials): + # Mock get_credentials to return credentials + mock_get_credentials.return_value = {"api_key": "test_api_key"} + # Mocking successful response mock_response = Mock() mock_response.status_code = SUCCESS_STATUS_CODE @@ -60,9 +66,36 @@ def test_invoke_success(self, mock_send): } self.assertEqual(vars(response), expected_response) self.mock_vault_client.get_bearer_token.assert_called_once() + mock_get_credentials.assert_called_once() + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_invalid_headers(self, mock_send): + def test_invoke_with_x_skyflow_authorization_already_present(self, mock_send, mock_get_credentials): + """Test that X-Skyflow-Authorization is not overwritten if already present in headers.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + custom_auth = "custom_bearer_token" + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers={"x-skyflow-authorization": custom_auth} + ) + + response = self.connection.invoke(request) + + # Verify bearer token from vault_client is NOT used + self.assertIsNotNone(response) + + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_headers(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=VALID_BODY, @@ -75,8 +108,10 @@ def test_invoke_invalid_headers(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) - @patch('requests.Session.send') - def test_invoke_invalid_body(self, mock_send): + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_body(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=INVALID_BODY, @@ -89,11 +124,16 @@ def test_invoke_invalid_body(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_request_error(self, mock_send): + def test_invoke_request_error(self, mock_send, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_response = Mock() mock_response.status_code = FAILURE_STATUS_CODE - mock_response.content = ERROR_RESPONSE_CONTENT + mock_response.content = ERROR_RESPONSE_CONTENT.encode('utf-8') # Convert to bytes + mock_response.headers = {"x-request-id": "test-request-id"} + mock_response.raise_for_status.side_effect = requests.HTTPError("400 Error") mock_send.return_value = mock_response request = InvokeConnectionRequest( @@ -106,9 +146,99 @@ def test_invoke_request_error(self, mock_send): with self.assertRaises(SkyflowError) as context: self.connection.invoke(request) - self.assertEqual(context.exception.message, f'Skyflow Python SDK {SDK_VERSION} Response {ERROR_RESPONSE_CONTENT} is not valid JSON.') - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(ERROR_RESPONSE_CONTENT)) - self.assertEqual(context.exception.http_code, 400) + + self.assertEqual(context.exception.message, ERROR_RESPONSE_CONTENT) + self.assertEqual(context.exception.http_code, FAILURE_STATUS_CODE) + self.assertEqual(context.exception.request_id, "test-request-id") + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_send_exception(self, mock_send, mock_get_credentials): + """Test handling of generic exception from session.send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_send.side_effect = Exception("Network error") + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_skyflow_error_re_raised(self, mock_send, mock_get_credentials): + """Test that SkyflowError is re-raised without wrapping.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + original_error = SkyflowError("Original error", 401) + mock_send.side_effect = original_error + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + # Should be the same original error + self.assertEqual(context.exception.message, "Original error") + self.assertEqual(context.exception.http_code, 401) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_close_called(self, mock_send, mock_get_credentials): + """Test that session.close() is called after send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + with patch('requests.Session.close') as mock_close: + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + response = self.connection.invoke(request) + + # Verify close was called + mock_close.assert_called_once() + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.get_metrics') + @patch('requests.Session.send') + def test_invoke_adds_sky_metadata_header(self, mock_send, mock_get_metrics, mock_get_credentials): + """Test that sky-metadata header is added to request.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_get_metrics.return_value = {"sdk_version": SDK_VERSION} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + response = self.connection.invoke(request) + + # Verify get_metrics was called + mock_get_metrics.assert_called_once() + self.assertIsNotNone(response) def test_parse_invoke_connection_response_error_from_client(self): mock_response = Mock(spec=requests.Response) @@ -128,3 +258,37 @@ def test_parse_invoke_connection_response_error_from_client(self): self.assertTrue(any(detail.get('error_from_client') == True for detail in exception.details)) self.assertEqual(exception.request_id, '12345') + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.construct_invoke_connection_request') + def test_invoke_construct_request_called(self, mock_construct, mock_get_credentials): + """Test that construct_invoke_connection_request is called with correct parameters.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_prepared_request = Mock(spec=requests.PreparedRequest) + mock_prepared_request.headers = {} + mock_construct.return_value = mock_prepared_request + + with patch('requests.Session.send') as mock_send: + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + self.connection.invoke(request) + + # Verify construct was called with connection_url from config + mock_construct.assert_called_once_with( + request, + VAULT_CONFIG["connection_url"], + self.mock_vault_client.get_logger() + ) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 9bcdfe0a2a53106b37908ad01f3f8508855644b4 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Thu, 5 Feb 2026 13:58:44 +0000 Subject: [PATCH 08/23] [AUTOMATED] Private Release 2.0.0.dev0+17399ea --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c811e4b3..6f75f969 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0.dev0+fbfaaa3' +current_version = '2.0.0.dev0+17399ea' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 531b7f67..40830338 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0.dev0+fbfaaa3' \ No newline at end of file +SDK_VERSION = '2.0.0.dev0+17399ea' \ No newline at end of file From 1c26b2e093c91cebdd55a1013e0882721e3e46c8 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:56:48 +0530 Subject: [PATCH 09/23] SK-2526: Upgrate urllib3 and setuptools libraries (#233) (#234) * SK-2526: upgrate urllib3 and setuptools libraries --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 6f75f969..3e348085 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,8 @@ long_description=open('README.rst').read(), install_requires=[ 'python_dateutil >= 2.5.3', - 'setuptools >= 21.0.0', - 'urllib3 >= 1.25.3, < 2.1.0', + 'setuptools >= 75.3.3', + 'urllib3 >= 1.25.3, <= 2.6.3', 'pydantic >= 2', 'typing-extensions >= 4.7.1', 'DateTime~=5.5', From e33ae928dc665af383cf8f806f19954326bd0375 Mon Sep 17 00:00:00 2001 From: skyflow-bharti <118584001+skyflow-bharti@users.noreply.github.com> Date: Thu, 19 Feb 2026 19:01:58 +0530 Subject: [PATCH 10/23] SK-2548 fix config validation (#235) * SK-2548 update config validation * SK-2548 fix the unit test cases --- skyflow/error/_skyflow_error.py | 1 - skyflow/utils/_skyflow_messages.py | 4 ++-- skyflow/utils/_utils.py | 4 ++-- tests/utils/test__utils.py | 8 ++++---- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index 7b917fae..fca43935 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -15,5 +15,4 @@ def __init__(self, self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value self.details = details self.request_id = request_id - log_error(message, http_code, request_id, grpc_code, http_status, details) super().__init__() \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 6a31c078..ab2da94b 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -16,7 +16,7 @@ class ErrorCodes(Enum): REDACTION_WITH_TOKENS_NOT_SUPPORTED = 400 class Error(Enum): - GENERIC_API_ERROR = f"{error_prefix} Validation error. Invalid configuration. Please add a valid vault configuration." + GENERIC_API_ERROR = f"{error_prefix} API error. Error occurred." EMPTY_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id." INVALID_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id as a string." @@ -283,7 +283,6 @@ class Info(Enum): VALIDATING_FILE_UPLOAD_REQUEST = f"{INFO}: [{error_prefix}] Validating file upload request." FILE_UPLOAD_REQUEST_RESOLVED = f"{INFO}: [{error_prefix}] File upload request resolved." FILE_UPLOAD_SUCCESS = f"{INFO}: [{error_prefix}] File uploaded successfully." - FILE_UPLOAD_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] File upload failed." INVOKE_CONNECTION_TRIGGERED = f"{INFO}: [{error_prefix}] Invoke connection method triggered." VALIDATING_INVOKE_CONNECTION_REQUEST = f"{INFO}: [{error_prefix}] Validating invoke connection request." @@ -351,6 +350,7 @@ class ErrorLogs(Enum): EMPTY_OR_NULL_VALUE_IN_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Value can not be null or empty in tokens for key {{}}." EMPTY_OR_NULL_KEY_IN_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Key can not be null or empty in tokens." MISMATCH_OF_FIELDS_AND_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Keys for values and tokens are not matching." + FILE_UPLOAD_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] File upload failed." EMPTY_IDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Ids can not be empty." EMPTY_OR_NULL_ID_IN_IDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Id can not be null or empty in ids at index {{}}." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 567227f7..5d83cbcc 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -461,7 +461,7 @@ def log_and_reject_error(description, status_code, request_id, http_status=None, def handle_exception(error, logger): if isinstance(error, httpx.ConnectError): - description = SkyflowMessages.Error.GENERIC_API_ERROR.value + description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value log_and_reject_error(description, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=logger) return @@ -506,7 +506,7 @@ def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) def handle_generic_error_with_status(err, request_id, status, logger): - description = SkyflowMessages.Error.GENERIC_API_ERROR.value + description = str(err) if err else SkyflowMessages.Error.GENERIC_API_ERROR.value log_and_reject_error(description, status, request_id, logger=logger) def encode_column_values(get_request): diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 09195b89..a0c9e3b2 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -615,7 +615,7 @@ def test_handle_exception_connect_error(self, mock_log_and_reject_error): handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - SkyflowMessages.Error.GENERIC_API_ERROR.value, + 'Connection refused', SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=mock_logger @@ -688,7 +688,7 @@ def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_er handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - SkyflowMessages.Error.GENERIC_API_ERROR.value, + str(mock_error), 503, "1234", logger=mock_logger @@ -706,7 +706,7 @@ def test_handle_exception_no_content_type(self, mock_log_and_reject_error): handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - SkyflowMessages.Error.GENERIC_API_ERROR.value, + str(mock_error), 500, "1234", logger=mock_logger @@ -812,7 +812,7 @@ def test_handle_generic_error_with_status(self, mock_log_and_reject_error): handle_generic_error_with_status(mock_error, request_id, status, mock_logger) mock_log_and_reject_error.assert_called_once_with( - SkyflowMessages.Error.GENERIC_API_ERROR.value, + str(mock_error), 503, request_id, logger=mock_logger From 87cf3ae0edef8d50795bd184b87f43e8f7fb18a8 Mon Sep 17 00:00:00 2001 From: skyflow-bharti Date: Thu, 19 Feb 2026 13:32:19 +0000 Subject: [PATCH 11/23] [AUTOMATED] Private Release 2.0.0.dev0+e33ae92 --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 3e348085..dbc8864c 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0.dev0+17399ea' +current_version = '2.0.0.dev0+e33ae92' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 40830338..8a811e58 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0.dev0+17399ea' \ No newline at end of file +SDK_VERSION = '2.0.0.dev0+e33ae92' \ No newline at end of file From bbeeeafa1175b787150372cb72259a0a3cd42ca4 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Tue, 24 Mar 2026 18:15:13 +0000 Subject: [PATCH 12/23] [AUTOMATED] Private Release 2.0.0.dev0+f7d26df --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index dbc8864c..aa38463d 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0.dev0+e33ae92' +current_version = '2.0.0.dev0+f7d26df' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 11beae55..bd8e63ec 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0.dev0+e33ae92' +SDK_VERSION = '2.0.0.dev0+f7d26df' From e564e922697aa7d49619dc3831476230d9c0a01d Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Wed, 20 May 2026 15:19:03 +0530 Subject: [PATCH 13/23] =?UTF-8?q?SK-2813:=20Python=20SDK=20v2=20=E2=80=94?= =?UTF-8?q?=20code=20quality,=20security=20hardening,=20and=20message=20fi?= =?UTF-8?q?xes=20(#242)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SK-2833: Add backward-compatible deprecation shims for update_log_level and FileUploadRequest (#244) * SK-2813: Fix and clean up SDK sample files (#243) --- CHANGELOG.md | 53 ++ README.md | 59 +- samples/detect_api/deidentify_file.py | 124 ++-- samples/detect_api/deidentify_file_async.py | 124 ++++ .../signed_token_generation_example.py | 84 ++- .../token_generation_example.py | 7 +- .../token_generation_with_context_example.py | 46 +- samples/vault_api/credentials_options.py | 28 +- samples/vault_api/get_records.py | 15 +- setup.py | 2 +- skyflow/client/skyflow.py | 18 +- skyflow/error/_skyflow_error.py | 7 +- skyflow/service_account/_utils.py | 144 ++-- skyflow/utils/_helpers.py | 2 +- skyflow/utils/_skyflow_messages.py | 23 +- skyflow/utils/_utils.py | 109 +-- skyflow/utils/_version.py | 2 +- skyflow/utils/constants.py | 6 + .../enums/detect_output_transcriptions.py | 3 +- skyflow/utils/validations/_validations.py | 188 +++-- skyflow/vault/client/client.py | 47 +- skyflow/vault/controller/_connections.py | 6 +- skyflow/vault/controller/_detect.py | 102 +-- skyflow/vault/controller/_vault.py | 23 +- skyflow/vault/data/_file_upload_request.py | 28 +- skyflow/vault/data/_get_response.py | 2 +- .../vault/detect/_deidentify_file_response.py | 31 +- .../vault/detect/_deidentify_text_response.py | 12 +- .../vault/detect/_reidentify_text_response.py | 7 +- tests/client/test_skyflow.py | 288 ++++++-- tests/service_account/test__utils.py | 497 +++++++++++-- tests/utils/test__helpers.py | 3 +- tests/utils/test__utils.py | 672 +++++++++--------- tests/utils/validations/test__validations.py | 306 +++++++- tests/vault/client/test__client.py | 396 +++++++---- tests/vault/controller/test__connection.py | 387 +++++++++- tests/vault/controller/test__detect.py | 110 ++- tests/vault/controller/test__vault.py | 55 ++ 38 files changed, 2980 insertions(+), 1036 deletions(-) create mode 100644 samples/detect_api/deidentify_file_async.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f564510..f63ab2d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,59 @@ All notable changes to this project will be documented in this file. +## [2.0.2] - 2026-05-06 +### Added +- Dict context support for Conditional Data Access. + +## [2.0.1] - 2026-04-29 +### Fixed +- Fern client re-initialisation on token refresh. + +## [2.0.0] - 2025-11-11 +### Added +- Multi-vault and multi-connection support via fluent builder (`Skyflow.builder()`). +- New typed request and response classes for all vault operations (`InsertRequest`, `GetRequest`, `UpdateRequest`, `DeleteRequest`, `QueryRequest`, `DetokenizeRequest`, `TokenizeRequest`, `FileUploadRequest`). +- Detect API: `deidentify_text`, `reidentify_text`, `deidentify_file`, and `get_detect_run`. +- File upload support via `vault().upload_file()`. +- Flexible credential types: API key, static bearer token, service account credentials string, credentials file path, and `SKYFLOW_CREDENTIALS` environment variable. +- `SkyflowError` now includes `http_code`, `grpc_code`, `http_status`, `request_id`, and `details` fields. +- `set_log_level()` on the client for runtime log level changes. + +### Changed +- Complete rewrite of the SDK public API. See [docs/migrate_to_v2.md](docs/migrate_to_v2.md) for migration instructions. + +## [1.16.0] - 2025-09-23 +### Fixed +- Remote disconnect error in vault operations. + +## [1.15.8] - 2025-09-30 +### Fixed +- Retry logic when `continue_on_error` is set to `true` in insert. + +## [1.15.7] - 2025-09-23 +### Fixed +- Retry handling for errors in insert method. + +## [1.15.6] - 2025-09-22 +### Fixed +- Added retry logic for transient errors. + +## [1.15.5] - 2025-09-18 +### Fixed +- Remote disconnected errors in vault operations. + +## [1.15.4] - 2025-09-12 +### Fixed +- Retry on exception during vault requests. + +## [1.15.3] - 2025-09-12 +### Fixed +- Retry on exception during vault requests. + +## [1.15.2] - 2025-09-12 +### Fixed +- Retry on connection error in insert method. + ## [1.15.1] - 2023-12-07 ## Fixed - Not receiving tokens when calling Get with options tokens as true. diff --git a/README.md b/README.md index b9d4ca86..23326cca 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Skyflow Python SDK +> **This is the current, recommended version of the Skyflow SDK.** V2.1.0 brings flexible auth, multi-vault support, native data types, and rich error diagnostics. +> +> Migrating from v1? See the **[Migration Guide](https://github.com/skyflowapi/skyflow-python/blob/main/docs/migrate_to_v2.md)** for step-by-step instructions. V1 is in maintenance mode and will reach End of Life on October 31, 2026. + The Skyflow Python SDK is designed to help with integrating Skyflow into a Python backend. ## Table of Contents @@ -703,18 +707,65 @@ options = { Embed context values into a bearer token during generation so you can reference those values in your policies. This enables more flexible access controls, such as tracking end-user identity when making API calls using service accounts, and facilitates using signed data tokens during detokenization. -Generate bearer tokens containing context information using a service account with the context_id identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. Tokens generated from such service accounts include a context_identifier claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. +Generate bearer tokens containing context information using a service account with the `context_id` identifier. Context information is represented as a JWT claim in a Skyflow-generated bearer token. Tokens generated from such service accounts include a `context_identifier` claim, are valid for 60 minutes, and can be used to make API calls to the Data and Management APIs, depending on the service account's permissions. + +The `ctx` parameter accepts either a **string** or a **dict**: + +**String context** — use when your policy references a single context value: + +```python +options = {'ctx': 'user_12345'} +token, _ = generate_bearer_token(filepath, options) +``` + +**Dict context** — use when your policy needs multiple context values for conditional data access. Each key in the dict maps to a Skyflow CEL policy variable under `request.context.*`: + +```python +options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } +} +token, _ = generate_bearer_token(filepath, options) +``` + +With the dict above, your Skyflow policies can reference `request.context.role`, `request.context.department`, and `request.context.user_id` to make conditional access decisions. + +Dict keys must contain only alphanumeric characters and underscores (`[a-zA-Z0-9_]`). Invalid keys will raise a `SkyflowError`. > [!TIP] -> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) -> See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. +> See the full example in the samples directory: [token_generation_with_context_example.py](samples/service_account/token_generation_with_context_example.py) +> See Skyflow's [context-aware authorization](https://docs.skyflow.com) and [conditional data access](https://docs.skyflow.com) docs for policy variable syntax like `request.context.*`. #### Generate signed data tokens: `generate_signed_data_tokens(filepath, options)` Digitally sign data tokens with a service account's private key to add an extra layer of protection. Skyflow generates data tokens when sensitive data is inserted into the vault. Detokenize signed tokens only by providing the signed data token along with a bearer token generated from the service account's credentials. The service account must have the necessary permissions and context to successfully detokenize the signed data tokens. +The `ctx` parameter on signed data tokens also accepts either a **string** or a **dict**, using the same format as bearer tokens: + +```python +# String context +options = { + 'ctx': 'user_12345', + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} + +# Dict context +options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + }, + 'data_tokens': ['dataToken1', 'dataToken2'], + 'time_to_live': 90, +} +``` + > [!TIP] -> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) +> See the full example in the samples directory: [signed_token_generation_example.py](samples/service_account/signed_token_generation_example.py) > See [docs.skyflow.com](https://docs.skyflow.com) for more details on authentication, access control, and governance for Skyflow. ## Logging diff --git a/samples/detect_api/deidentify_file.py b/samples/detect_api/deidentify_file.py index 99b4b26e..88f012c9 100644 --- a/samples/detect_api/deidentify_file.py +++ b/samples/detect_api/deidentify_file.py @@ -1,7 +1,14 @@ from skyflow.error import SkyflowError from skyflow import Env, Skyflow, LogLevel from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions -from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep, FileInput +from skyflow.vault.detect import ( + DeidentifyFileRequest, + TokenFormat, + Transformations, + DateTransformation, + Bleep, + FileInput, +) """ * Skyflow Deidentify File Example @@ -11,6 +18,7 @@ * spreadsheets, presentations, structured text. """ + def perform_file_deidentification(): try: # Step 1: Configure Credentials @@ -23,7 +31,7 @@ def perform_file_deidentification(): 'vault_id': '', # Replace with your vault ID 'cluster_id': '', # Replace with your cluster ID 'env': Env.PROD, # Deployment environment - 'credentials': credentials + 'credentials': credentials, } # Step 3: Configure & Initialize Skyflow Client @@ -36,70 +44,66 @@ def perform_file_deidentification(): # Step 4: Create File Object file_path = '' # Replace with your file path - file = open(file_path, 'rb') - # Step 5: Configure Deidentify File Request with all options - deidentify_request = DeidentifyFileRequest( - file=FileInput(file), # File to de-identify (can also provide a file path) - entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect - allow_regex_list=[''], # Optional: Patterns to allow - restrict_regex_list=[''], # Optional: Patterns to restrict - - # Token format configuration - token_format=TokenFormat( - vault_token=[DetectEntities.SSN], # Use vault tokens for these entities - ), - - # Optional: Custom transformations - # transformations=Transformations( - # shift_dates=DateTransformation( - # max_days=30, - # min_days=10, - # entities=[DetectEntities.DOB] - # ) - # ), - - # Output configuration - output_directory='', # Where to save processed file - wait_time=15, # Max wait time in seconds (max 64) - - # Image-specific options - output_processed_image=True, # Include processed image in output - output_ocr_text=True, # Include OCR text in response - masking_method=MaskingMethod.BLACKBOX, # Masking method for images - - # PDF-specific options - pixel_density=15, # Pixel density for PDF processing - max_resolution=2000, # Max resolution for PDF - # Audio-specific options - output_processed_audio=True, # Include processed audio - output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type - - # Audio bleep configuration - - # bleep=Bleep( - # gain=5, # Loudness in dB - # frequency=1000, # Pitch in Hz - # start_padding=0.1, # Padding at start (seconds) - # stop_padding=0.2 # Padding at end (seconds) - # ) - ) - - # Step 6: Call deidentifyFile API - response = skyflow_client.detect().deidentify_file(deidentify_request) + # Step 5: Configure Deidentify File Request and call API + with open(file_path, 'rb') as file: + deidentify_request = DeidentifyFileRequest( + file=FileInput(file), # File to de-identify (can also provide a file path) + entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect + allow_regex_list=[''], # Optional: Patterns to allow + restrict_regex_list=[''], # Optional: Patterns to restrict + # Token format configuration + token_format=TokenFormat( + vault_token=[DetectEntities.SSN], # Use vault tokens for these entities + ), + # Optional: Custom transformations + # transformations=Transformations( + # shift_dates=DateTransformation( + # max_days=30, + # min_days=10, + # entities=[DetectEntities.DOB] + # ) + # ), + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) + # Image-specific options + output_processed_image=True, # Include processed image in output + output_ocr_text=True, # Include OCR text in response + masking_method=MaskingMethod.BLACKBOX, # Masking method for images + # PDF-specific options + pixel_density=15, # Pixel density for PDF processing + max_resolution=2000, # Max resolution for PDF + # Audio-specific options + output_processed_audio=True, # Include processed audio + output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type + # Audio bleep configuration + # bleep=Bleep( + # gain=5, # Loudness in dB + # frequency=1000, # Pitch in Hz + # start_padding=0.1, # Padding at start (seconds) + # stop_padding=0.2 # Padding at end (seconds) + # ) + ) + + # Step 6: Call deidentifyFile API + response = skyflow_client.detect().deidentify_file(deidentify_request) # Handle Successful Response - print("\nDeidentify File Response:", response) + print('\nDeidentify File Response:', response) except SkyflowError as error: # Handle Skyflow-specific errors - print('\nSkyflow Error:', { - 'http_code': error.http_code, - 'grpc_code': error.grpc_code, - 'http_status': error.http_status, - 'message': error.message, - 'details': error.details - }) + print( + '\nSkyflow Error:', + { + 'http_code': error.http_code, + 'grpc_code': error.grpc_code, + 'http_status': error.http_status, + 'message': error.message, + 'details': error.details, + }, + ) except Exception as error: # Handle unexpected errors print('Unexpected Error:', error) diff --git a/samples/detect_api/deidentify_file_async.py b/samples/detect_api/deidentify_file_async.py new file mode 100644 index 00000000..7ff2ac13 --- /dev/null +++ b/samples/detect_api/deidentify_file_async.py @@ -0,0 +1,124 @@ +from skyflow.error import SkyflowError +from skyflow import Env, Skyflow, LogLevel +from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions +from skyflow.vault.detect import ( + DeidentifyFileRequest, + TokenFormat, + Transformations, + DateTransformation, + Bleep, + FileInput, +) +from concurrent.futures import ThreadPoolExecutor + +""" + * Skyflow Deidentify File Example + * + * This sample demonstrates how to use all available options for deidentifying files + * using an asynchronous approach. + * Supported file types: images (jpg, png, etc.), pdf, audio (mp3, wav), documents, + * spreadsheets, presentations, structured text. +""" + + +def perform_file_deidentification_async(): + try: + # Step 1: Configure Credentials + credentials = { + 'path': '/path/to/credentials.json' # Path to credentials file + } + + # Step 2: Configure Vault + vault_config = { + 'vault_id': '', # Replace with your vault ID + 'cluster_id': '', # Replace with your cluster ID + 'env': Env.PROD, # Deployment environment + 'credentials': credentials, + } + + # Step 3: Configure & Initialize Skyflow Client + skyflow_client = ( + Skyflow.builder() + .add_vault_config(vault_config) + .set_log_level(LogLevel.INFO) # Use LogLevel.ERROR in production + .build() + ) + + # Step 4: Create File Object + file_path = '' # Replace with your file path + + deidentify_request = DeidentifyFileRequest( + file=FileInput(file_path=file_path), # File to de-identify + # entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect + allow_regex_list=[''], # Optional: Patterns to allow + restrict_regex_list=[''], # Optional: Patterns to restrict + # Token format configuration + token_format=TokenFormat( + vault_token=[DetectEntities.SSN], # Use vault tokens for these entities + ), + # Optional: Custom transformations + # transformations=Transformations( + # shift_dates=DateTransformation( + # max_days=30, + # min_days=10, + # entities=[DetectEntities.DOB] + # ) + # ), + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) + # Image-specific options + output_processed_image=True, # Include processed image in output + output_ocr_text=True, # Include OCR text in response + masking_method=MaskingMethod.BLACKBOX, # Masking method for images + # PDF-specific options + pixel_density=15, # Pixel density for PDF processing + max_resolution=2000, # Max resolution for PDF + # Audio-specific options + output_processed_audio=True, # Include processed audio + output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type + # Audio bleep configuration + # bleep=Bleep( + # gain=5, # Loudness in dB + # frequency=1000, # Pitch in Hz + # start_padding=0.1, # Padding at start (seconds) + # stop_padding=0.2 # Padding at end (seconds) + # ) + ) + + # Create a thread pool executor + executor = ThreadPoolExecutor(max_workers=1) + + future = executor.submit(lambda: skyflow_client.detect().deidentify_file(deidentify_request)) + + def handle_response(future): + exception = future.exception() + if exception is not None: + if isinstance(exception, SkyflowError): + # Handle Skyflow-specific errors + print( + '\nSkyflow Error:', + { + 'http_code': exception.http_code, + 'grpc_code': exception.grpc_code, + 'http_status': exception.http_status, + 'message': exception.message, + 'details': exception.details, + }, + ) + else: + # Handle unexpected errors + print('Unexpected Error:', exception) + return + + # Handle Successful Response + result = future.result() + print('\nDeidentify File Response:', result) + + future.add_done_callback(handle_response) + + executor.shutdown(wait=True) + + except Exception as error: + # Handle unexpected errors + print('Unexpected Error:', error) diff --git a/samples/service_account/signed_token_generation_example.py b/samples/service_account/signed_token_generation_example.py index 32140ada..7ae175cd 100644 --- a/samples/service_account/signed_token_generation_example.py +++ b/samples/service_account/signed_token_generation_example.py @@ -1,12 +1,10 @@ import json from skyflow.service_account import ( - is_expired, generate_signed_data_tokens, generate_signed_data_tokens_from_creds, ) -file_path = 'CREDENTIALS_FILE_PATH' -bearer_token = '' +file_path = '' skyflow_credentials = { 'clientID': '', @@ -18,42 +16,64 @@ credentials_string = json.dumps(skyflow_credentials) -options = { - 'ctx': 'CONTEXT_ID', - 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], - 'time_to_live': 90, # in seconds -} +# Approach 1: Signed data tokens with string context +# Returns: [('', ''), ...] +def get_signed_tokens_with_string_context(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['', ''], + 'time_to_live': 90, # in seconds + } + try: + results = generate_signed_data_tokens(file_path, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results + except Exception as e: + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_file_path(): - # Generate signed bearer token from credentials file path. - global bearer_token +# Approach 2: Signed data tokens with JSON object context (dict) +# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "analyst" and request.context.department == "research" +def get_signed_tokens_with_object_context(): + options = { + 'ctx': { + 'role': 'analyst', + 'department': 'research', + 'user_id': 'user_67890', + }, + 'data_tokens': ['', ''], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens(file_path, options) - return data_token, signed_data_token - + results = generate_signed_data_tokens(file_path, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error: {str(e)}') -def get_signed_bearer_token_from_credentials_string(): - # Generate signed bearer token from credentials string. - global bearer_token - +# Approach 3: Signed data tokens from credentials string +def get_signed_tokens_from_credentials_string(): + options = { + 'ctx': 'user_12345', + 'data_tokens': ['', ''], + 'time_to_live': 90, + } try: - if not is_expired(bearer_token): - return bearer_token - else: - data_token, signed_data_token = generate_signed_data_tokens_from_creds(credentials_string, options) - return data_token, signed_data_token - + results = generate_signed_data_tokens_from_creds(credentials_string, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results except Exception as e: - print(f'Error generating token from credentials string: {str(e)}') - + print(f'Error: {str(e)}') -print(get_signed_bearer_token_from_file_path()) -print(get_signed_bearer_token_from_credentials_string()) +print('String context:') +get_signed_tokens_with_string_context() +print('Object context:') +get_signed_tokens_with_object_context() +print('Creds string:') +get_signed_tokens_from_credentials_string() diff --git a/samples/service_account/token_generation_example.py b/samples/service_account/token_generation_example.py index 34db4c37..32fa022b 100644 --- a/samples/service_account/token_generation_example.py +++ b/samples/service_account/token_generation_example.py @@ -5,7 +5,7 @@ is_expired, ) -file_path = 'CREDENTIALS_FILE_PATH' +file_path = '' bearer_token = '' # To generate Bearer Token from credentials string. @@ -46,10 +46,9 @@ def get_bearer_token_from_credentials_string(): bearer_token = token return bearer_token except Exception as e: - print(f"Error generating token from credentials string: {str(e)}") - + print(f'Error generating token from credentials string: {str(e)}') print(get_bearer_token_from_file_path()) -print(get_bearer_token_from_credentials_string()) \ No newline at end of file +print(get_bearer_token_from_credentials_string()) diff --git a/samples/service_account/token_generation_with_context_example.py b/samples/service_account/token_generation_with_context_example.py index a43a072a..03aa9f06 100644 --- a/samples/service_account/token_generation_with_context_example.py +++ b/samples/service_account/token_generation_with_context_example.py @@ -18,11 +18,13 @@ } credentials_string = json.dumps(skyflow_credentials) -options = {'ctx': ''} -def get_bearer_token_with_context_from_file_path(): - # Generate bearer token with context from credentials file path. +# Approach 1: Bearer token with string context +# Use a simple string identifier when your policy references a single context value. +# In your Skyflow policy, reference this as: request.context +def get_bearer_token_with_string_context(): global bearer_token + options = {'ctx': 'user_12345'} try: if not is_expired(bearer_token): @@ -31,14 +33,40 @@ def get_bearer_token_with_context_from_file_path(): token, _ = generate_bearer_token(file_path, options) bearer_token = token return bearer_token + except Exception as e: + print(f'Error generating token: {str(e)}') + + +# Approach 2: Bearer token with JSON object context (dict) +# Use a dict when your policy needs multiple context values for conditional data access. +# Each key maps to a Skyflow CEL policy variable under request.context.* +# For example: request.context.role == "admin" and request.context.department == "finance" +def get_bearer_token_with_object_context(): + global bearer_token + options = { + 'ctx': { + 'role': 'admin', + 'department': 'finance', + 'user_id': 'user_12345', + } + } + try: + if not is_expired(bearer_token): + return bearer_token + else: + token, _ = generate_bearer_token(file_path, options) + bearer_token = token + return bearer_token except Exception as e: - print(f'Error generating token from file path: {str(e)}') + print(f'Error generating token: {str(e)}') +# Approach 3: Bearer token with string context from credentials string def get_bearer_token_with_context_from_credentials_string(): - # Generate bearer token with context from credentials string. global bearer_token + options = {'ctx': 'user_12345'} + try: if not is_expired(bearer_token): return bearer_token @@ -47,9 +75,9 @@ def get_bearer_token_with_context_from_credentials_string(): bearer_token = token return bearer_token except Exception as e: - print(f"Error generating token from credentials string: {str(e)}") - + print(f"Error generating token: {str(e)}") -print(get_bearer_token_with_context_from_file_path()) -print(get_bearer_token_with_context_from_credentials_string()) \ No newline at end of file +print("String context:", get_bearer_token_with_string_context()) +print("Object context:", get_bearer_token_with_object_context()) +print("Creds string:", get_bearer_token_with_context_from_credentials_string()) diff --git a/samples/vault_api/credentials_options.py b/samples/vault_api/credentials_options.py index db792042..2155f99d 100644 --- a/samples/vault_api/credentials_options.py +++ b/samples/vault_api/credentials_options.py @@ -13,6 +13,7 @@ 4. Handle response and errors """ + def perform_secure_data_deletion(): try: # Step 1: Configure Bearer Token Credentials @@ -31,10 +32,10 @@ def perform_secure_data_deletion(): } secondary_vault_config = { - 'vault_id': 'YOUR_SECONDARY_VAULT_ID', # Secondary vault - 'cluster_id': 'YOUR_SECONDARY_CLUSTER_ID', # Cluster ID from your vault URL + 'vault_id': '', # Secondary vault + 'cluster_id': '', # Cluster ID from your vault URL 'env': Env.PROD, # Deployment environment - 'credentials': credentials + 'credentials': credentials, } # Step 3: Configure & Initialize Skyflow Client @@ -51,13 +52,10 @@ def perform_secure_data_deletion(): primary_table_name = '' # Replace with actual table name - primary_delete_request = DeleteRequest( - table=primary_table_name, - ids=primary_delete_ids - ) + primary_delete_request = DeleteRequest(table=primary_table_name, ids=primary_delete_ids) # Perform Delete Operation for Primary Vault - primary_delete_response = skyflow_client.vault('').delete(primary_delete_request) + primary_delete_response = skyflow_client.vault('').delete(primary_delete_request) # Handle Successful Response print('Primary Vault Deletion Successful:', primary_delete_response) @@ -67,10 +65,7 @@ def perform_secure_data_deletion(): secondary_table_name = '' # Replace with actual table name - secondary_delete_request = DeleteRequest( - table=secondary_table_name, - ids=secondary_delete_ids - ) + secondary_delete_request = DeleteRequest(table=secondary_table_name, ids=secondary_delete_ids) # Perform Delete Operation for Secondary Vault secondary_delete_response = skyflow_client.vault('').delete(secondary_delete_request) @@ -78,17 +73,12 @@ def perform_secure_data_deletion(): # Handle Successful Response print('Secondary Vault Deletion Successful:', secondary_delete_response) - except SkyflowError as error: # Comprehensive Error Handling - print('Skyflow Specific Error: ', { - 'code': error.http_code, - 'message': error.message, - 'details': error.details - }) + print('Skyflow Specific Error: ', {'code': error.http_code, 'message': error.message, 'details': error.details}) except Exception as error: print('Unexpected Error:', error) # Invoke the secure data deletion function -perform_secure_data_deletion() \ No newline at end of file +perform_secure_data_deletion() diff --git a/samples/vault_api/get_records.py b/samples/vault_api/get_records.py index b2fd445f..9e4d031a 100644 --- a/samples/vault_api/get_records.py +++ b/samples/vault_api/get_records.py @@ -4,6 +4,7 @@ from skyflow import Skyflow, LogLevel from skyflow.vault.data import GetRequest + def perform_secure_data_retrieval(): try: # Step 1: Configure Credentials @@ -28,7 +29,7 @@ def perform_secure_data_retrieval(): 'vault_id': '', # primary vault 'cluster_id': '', # Cluster ID from your vault URL 'env': Env.PROD, # Deployment environment (PROD by default) - 'credentials': credentials # Authentication method + 'credentials': credentials, # Authentication method } # Step 3: Configure & Initialize Skyflow Client @@ -42,10 +43,10 @@ def perform_secure_data_retrieval(): # Step 4: Prepare Retrieval Data - get_ids = ['', 'SKYFLOW_ID2'] + get_ids = ['', ''] get_request = GetRequest( - table='', # Replace with your actual table name + table='', # Replace with your actual table name ids=get_ids, ) @@ -57,15 +58,11 @@ def perform_secure_data_retrieval(): except SkyflowError as error: # Comprehensive Error Handling - print('Skyflow Specific Error: ', { - 'code': error.http_code, - 'message': error.message, - 'details': error.details - }) + print('Skyflow Specific Error: ', {'code': error.http_code, 'message': error.message, 'details': error.details}) except Exception as error: print('Unexpected Error:', error) # Invoke the secure data retrieval function -perform_secure_data_retrieval() \ No newline at end of file +perform_secure_data_retrieval() diff --git a/setup.py b/setup.py index aa38463d..8f76225e 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0.dev0+f7d26df' +current_version = '2.0.2' setup( name='skyflow', diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 0bfde34e..2255ee50 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -1,4 +1,6 @@ +import warnings from collections import OrderedDict +from typing_extensions import deprecated from skyflow import LogLevel from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages @@ -59,12 +61,18 @@ def set_log_level(self, log_level): self.__builder._Builder__set_log_level(log_level) return self + @deprecated("[DEPRECATED] Use set_log_level() instead.") + def update_log_level(self, log_level): + warnings.warn( + SkyflowMessages.Warning.UPDATE_LOG_LEVEL_DEPRECATED.value, + DeprecationWarning, + stacklevel=2, + ) + return self.set_log_level(log_level) + def get_log_level(self): return self.__builder._Builder__log_level - def update_log_level(self, log_level): - self.__builder._Builder__set_log_level(log_level) - def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) return vault_config.get(OptionField.VAULT_CONTROLLER) @@ -114,6 +122,8 @@ def remove_vault_config(self, vault_id): def update_vault_config(self, config): validate_update_vault_config(self.__logger, config) vault_id = config.get(OptionField.VAULT_ID) + if vault_id not in self.__vault_configs: + raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value) vault_config = self.__vault_configs[vault_id] vault_config.get(OptionField.VAULT_CLIENT).update_config(config) @@ -155,6 +165,8 @@ def remove_connection_config(self, connection_id): def update_connection_config(self, config): validate_update_connection_config(self.__logger, config) connection_id = config[OptionField.CONNECTION_ID] + if connection_id not in self.__connection_configs: + raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value) connection_config = self.__connection_configs[connection_id] connection_config.get(OptionField.VAULT_CLIENT).update_config(config) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index fca43935..bf472177 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -1,5 +1,4 @@ from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_error class SkyflowError(Exception): def __init__(self, @@ -8,11 +7,11 @@ def __init__(self, request_id = None, grpc_code = None, http_status = None, - details = []): + details = None): self.message = message self.http_code = http_code self.grpc_code = grpc_code self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value - self.details = details + self.details = details if details else None self.request_id = request_id - super().__init__() \ No newline at end of file + super().__init__(message) \ No newline at end of file diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index f4c98faf..deccf973 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -1,5 +1,6 @@ import json import datetime +import re import time import jwt from urllib.parse import urlparse @@ -10,11 +11,56 @@ from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError from skyflow.utils import is_valid_url +from skyflow.utils.constants import CTX_KEY_REGEX invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value +_CTX_KEY_PATTERN = re.compile(CTX_KEY_REGEX) + +_SNAKE_TO_CAMEL_CRED_MAP = { + 'private_key': CredentialField.PRIVATE_KEY, + 'client_id': CredentialField.CLIENT_ID, + 'key_id': CredentialField.KEY_ID, + 'token_uri': CredentialField.TOKEN_URI, + 'client_name': CredentialField.CLIENT_NAME, +} + + +def _normalize_credentials(credentials): + return {_SNAKE_TO_CAMEL_CRED_MAP.get(k, k): v for k, v in credentials.items()} + + +def _validate_and_resolve_ctx(ctx): + """Validate ctx value and return resolved value for JWT claims. + Returns None if ctx should be omitted, the value if valid, or raises SkyflowError if invalid. + """ + if ctx is None: + return None + if isinstance(ctx, str): + if ctx.strip() == '': + return None + return ctx + if isinstance(ctx, dict): + if len(ctx) == 0: + return None + for key in ctx: + if not isinstance(key, str) or not _CTX_KEY_PATTERN.match(key): + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_MAP_KEY.value.format(key), + invalid_input_error_code + ) + return ctx + if isinstance(ctx, (bool, int, float)): + return ctx + raise SkyflowError( + SkyflowMessages.Error.INVALID_CTX_TYPE.value, + invalid_input_error_code + ) + def is_expired(token, logger = None): + if token is None: + return True if len(token) == 0: log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -34,20 +80,18 @@ def is_expired(token, logger = None): return True def generate_bearer_token(credentials_file_path, options = None, logger = None): + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) try: - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) - credentials_file =open(credentials_file_path, 'r') + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) + except SkyflowError: + raise except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger) - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) - - finally: - credentials_file.close() result = get_service_account_token(credentials, options, logger) return result @@ -62,24 +106,25 @@ def generate_bearer_token_from_creds(credentials, options = None, logger = None) return result def get_service_account_token(credentials, options, logger): + credentials = _normalize_credentials(credentials) try: private_key = credentials[CredentialField.PRIVATE_KEY] - except: - log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) + except KeyError: + log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: client_id = credentials[CredentialField.CLIENT_ID] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: key_id = credentials[CredentialField.KEY_ID] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: token_uri = credentials[CredentialField.TOKEN_URI] - except: + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) @@ -87,9 +132,12 @@ def get_service_account_token(credentials, options, logger): log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) - if options and "token_uri" in options: - token_uri = options["token_uri"] - + if options and CredentialField.TOKEN_URI_OPTION in options: + token_uri = options[CredentialField.TOKEN_URI_OPTION] + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) auth_client = AuthClient(base_url) @@ -101,7 +149,7 @@ def get_service_account_token(credentials, options, logger): try: response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) except UnauthorizedError: @@ -120,8 +168,10 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): JwtField.SUB: client_id, JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and JwtField.CTX in options: - payload[JwtField.CTX] = options.get(JwtField.CTX) + if options and OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options.get(OptionField.CTX)) + if resolved_ctx is not None: + payload[JwtField.CTX] = resolved_ctx try: return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: @@ -130,18 +180,21 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): + options = options if options is not None else {} + credentials_obj = _normalize_credentials(credentials_obj) expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60) prefix = JWT.SIGNED_TOKEN_PREFIX - token_uri = credentials_obj.get("tokenURI") + token_uri = credentials_obj.get(CredentialField.TOKEN_URI) if not isinstance(token_uri, str) or not is_valid_url(token_uri): log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) - - if options and "token_uri" in options: - token_uri = options["token_uri"] + resolved_ctx = None + if OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options[OptionField.CTX]) + results = [] if options and options.get(OptionField.DATA_TOKENS): for token in options[OptionField.DATA_TOKENS]: claims = { @@ -152,37 +205,31 @@ def get_signed_tokens(credentials_obj, options): JwtField.TOK: token, JwtField.IAT: int(time.time()), } - - if JwtField.CTX in options: - claims[JwtField.CTX] = options[JwtField.CTX] - + if resolved_ctx is not None: + claims[JwtField.CTX] = resolved_ctx private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) - try: + try: signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) - - response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) + results.append(get_signed_data_token_response_object(prefix + signed_jwt, token)) log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) - return response_object + return results def generate_signed_data_tokens(credentials_file_path, options): log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value) try: - credentials_file =open(credentials_file_path, 'r') + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), + invalid_input_error_code) + except SkyflowError: + raise except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), - invalid_input_error_code) - - finally: - credentials_file.close() - return get_signed_tokens(credentials, options) def generate_signed_data_tokens_from_creds(credentials, options): @@ -195,9 +242,6 @@ def generate_signed_data_tokens_from_creds(credentials, options): raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value, invalid_input_error_code) return get_signed_tokens(json_credentials, options) + def get_signed_data_token_response_object(signed_token, actual_token): - response_object = { - ResponseField.TOKEN: actual_token, - ResponseField.SIGNED_TOKEN: signed_token - } - return response_object.get(ResponseField.TOKEN), response_object.get(ResponseField.SIGNED_TOKEN) + return actual_token, signed_token diff --git a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py index 090f3a2b..12ff1257 100644 --- a/skyflow/utils/_helpers.py +++ b/skyflow/utils/_helpers.py @@ -13,6 +13,6 @@ def format_scope(scopes): def is_valid_url(url): try: result = urlparse(url) - return all([result.scheme in ("http", "https"), result.netloc]) + return all([result.scheme == "https", result.netloc]) except Exception: return False \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 989aa298..01e15579 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -61,6 +61,8 @@ class Error(Enum): EMPTY_CONTEXT = f"{error_prefix} Initialization failed. Invalid context provided. Specify context as type Context." INVALID_CONTEXT_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid context for {{}} with id {{}}. Specify a valid context." INVALID_CONTEXT = f"{error_prefix} Initialization failed. Invalid context. Specify a valid context." + INVALID_CTX_TYPE = f"{error_prefix} Initialization failed. Invalid ctx type. Specify ctx as a string or a dict." + INVALID_CTX_MAP_KEY = f"{error_prefix} Initialization failed. Invalid key '{{}}' in ctx dict. Keys must contain only alphanumeric characters and underscores." INVALID_LOG_LEVEL = f"{error_prefix} Initialization failed. Invalid log level. Specify a valid log level." EMPTY_LOG_LEVEL = f"{error_prefix} Initialization failed. Specify a valid log level." @@ -88,14 +90,15 @@ class Error(Enum): INVALID_TABLE_NAME_IN_INSERT = f"{error_prefix} Validation error. Invalid table name in insert request. Specify a valid table name." INVALID_TYPE_OF_DATA_IN_INSERT = f"{error_prefix} Validation error. Invalid type of data in insert request. Specify data as a object array." EMPTY_DATA_IN_INSERT = f"{error_prefix} Validation error. Data array cannot be empty. Specify data in insert request." - INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. 'upsert' key cannot be empty in options. At least one object of table and column is required." + INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. Invalid 'upsert' value in options. Specify 'upsert' as a non-empty string containing the column name." INVALID_HOMOGENEOUS_TYPE = f"{error_prefix} Validation error. Invalid type of homogeneous. Specify homogeneous as a string." INVALID_TOKEN_MODE_TYPE = f"{error_prefix} Validation error. Invalid type of token mode. Specify token mode as a TokenMode enum." INVALID_RETURN_TOKENS_TYPE = f"{error_prefix} Validation error. Invalid type of return tokens. Specify return tokens as a boolean." INVALID_CONTINUE_ON_ERROR_TYPE = f"{error_prefix} Validation error. Invalid type of continue on error. Specify continue on error as a boolean." TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE = f"{error_prefix} Validation error. 'token_mode' wasn't specified. Set 'token_mode' to 'ENABLE' to insert tokens." INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT = f"{error_prefix} Validation error. 'token_mode' is set to 'ENABLE_STRICT', but some fields are missing tokens. Specify tokens for all fields." - NO_TOKENS_IN_INSERT = f"{error_prefix} Validation error. Tokens weren't specified for records while 'token_strict' was {{}}. Specify tokens." + MISMATCH_OF_FIELDS_AND_TOKENS = f"{error_prefix} Validation error. Keys for values and tokens are not matching. Ensure each values entry and its corresponding tokens entry have the same keys." + NO_TOKENS_IN_INSERT = f"{error_prefix} Validation error. Tokens weren't specified for records while 'token_mode' was {{}}. Specify tokens." BATCH_INSERT_FAILURE = f"{error_prefix} Insert operation failed." GET_FAILURE = f"{error_prefix} Get operation failed." HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT = f"{error_prefix} Validation error. Homogenous is not supported when upsert is passed." @@ -315,6 +318,8 @@ class Info(Enum): DETECT_REQUEST_RESOLVED = f"{INFO}: [{error_prefix}] Detect request is resolved." class ErrorLogs(Enum): + INVALID_LOG_LEVEL = f"{ERROR}: [{error_prefix}] Invalid log level. Specify a valid log level." + INVALID_KEY = f"{ERROR}: [{error_prefix}] Invalid key {{}} in config." VAULTID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid vault config. Vault ID is required." EMPTY_VAULTID = f"{ERROR}: [{error_prefix}] Invalid vault config. Vault ID can not be empty." CLUSTER_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid vault config. Cluster ID is required." @@ -361,8 +366,8 @@ class ErrorLogs(Enum): EMPTY_OR_NULL_ID_IN_IDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Id can not be null or empty in ids at index {{}}." TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokenization is not supported when redaction is applied." TOKENIZATION_SUPPORTED_ONLY_WITH_IDS=f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokenization is not supported when column name and values are passed." - TOKENS_NOT_ALLOWED_WITH_BYOT_DISABLE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are not allowed when token_strict is DISABLE." - INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT =f"{ERROR}: [{error_prefix}] Invalid {{}} request. For tokenStrict as ENABLE_STRICT, tokens should be passed for all fields." + TOKENS_NOT_ALLOWED_WITH_BYOT_DISABLE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are not allowed when token_mode is DISABLE." + INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT =f"{ERROR}: [{error_prefix}] Invalid {{}} request. For token_mode as ENABLE_STRICT, tokens should be passed for all fields." TOKENS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are required." EMPTY_FIELDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Fields can not be empty." EMPTY_OFFSET = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Offset ca not be empty." @@ -411,7 +416,15 @@ class HttpStatus(Enum): BAD_REQUEST = "Bad Request" class Warning(Enum): - WARNING_MESSAGE = "WARNING MESSAGE" + UPDATE_LOG_LEVEL_DEPRECATED = ( + "[DEPRECATED] Skyflow.update_log_level() is deprecated. " + "Use Skyflow.set_log_level() instead — identical behavior." + ) + FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED = ( + "[DEPRECATED] FileUploadRequest: argument order changed. " + "Old positional order: (table, skyflow_id, column_name). " + "New order: FileUploadRequest(table, column_name=..., skyflow_id=...)." + ) diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 5d83cbcc..e3b8eea9 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -32,26 +32,18 @@ invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): - dotenv.load_dotenv() + if config_level_creds is not None: + return config_level_creds + if common_skyflow_creds is not None: + return common_skyflow_creds dotenv_path = dotenv.find_dotenv(usecwd=True) if dotenv_path: load_dotenv(dotenv_path) env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") - if config_level_creds: - return config_level_creds - if common_skyflow_creds: - return common_skyflow_creds if env_skyflow_credentials: - env_skyflow_credentials.strip() - try: - env_creds = env_skyflow_credentials.replace('\n', '\\n') - return { - CredentialField.CREDENTIALS_STRING: env_creds - } - except json.JSONDecodeError: - raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + env_creds = env_skyflow_credentials.strip().replace('\n', '\\n') + return {'credentials_string': env_creds} + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: if len(api_key) != ApiKey.LENGTH: @@ -80,9 +72,9 @@ def parse_path_params(url, path_params): return result -def to_lowercase_keys(dict): +def to_lowercase_keys(data): result = {} - for key, value in dict.items(): + for key, value in data.items(): result[key.lower()] = value return result @@ -136,7 +128,7 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) except SkyflowError: raise - except Exception as e: + except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) if files and header and content_type == ContentType.FORMDATA.value: @@ -194,7 +186,6 @@ def get_data_from_content_type(data, content_type): if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) elif content_type == ContentType.FORMDATA.value: - print("Hello") converted_data = None files = {} for key, value in data.items(): @@ -239,8 +230,11 @@ def build_xml(d, tag='item'): return ''.join(xml_parts) +_CACHED_METRICS: dict = {} + def get_metrics(): - sdk_name_version = SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION + if _CACHED_METRICS: + return _CACHED_METRICS try: sdk_client_device_model = platform.node() @@ -257,13 +251,13 @@ def get_metrics(): except Exception: sdk_runtime_details = "" - details_dic = { - SdkMetricsKey.SDK_NAME_VERSION: sdk_name_version, + _CACHED_METRICS.update({ + SdkMetricsKey.SDK_NAME_VERSION: SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION, SdkMetricsKey.SDK_CLIENT_DEVICE_MODEL: sdk_client_device_model, SdkMetricsKey.SDK_CLIENT_OS_DETAILS: sdk_client_os_details, SdkMetricsKey.SDK_RUNTIME_DETAILS: SdkPrefix.PYTHON_RUNTIME + sdk_runtime_details, - } - return details_dic + }) + return _CACHED_METRICS def parse_insert_response(api_response, continue_on_error): # Retrieve the headers and data from the API response @@ -427,22 +421,30 @@ def parse_invoke_connection_response(api_response: requests.Response): error_response = json.loads(content) error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) - status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, status_code) - http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) - grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) - details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS) - message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) - + http_status = None + grpc_code = None + details = None + + error_obj = error_response.get(ResponseField.ERROR) if isinstance(error_response, dict) else None + if isinstance(error_obj, dict): + status_code = error_obj.get(ResponseField.HTTP_CODE, status_code) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS) + message = error_obj.get(ResponseField.MESSAGE, message) + elif isinstance(error_obj, str) and error_obj: + message = error_obj + if error_from_client is not None: - if details is None: + if details is None: details = [] error_from_client_bool = error_from_client.lower() == BooleanString.TRUE details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) - + except json.JSONDecodeError: - raise SkyflowError(content if content else message, status_code, request_id) + raise SkyflowError(message, status_code, request_id) def parse_deidentify_text_response(api_response: DeidentifyStringResponse): entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities] @@ -486,21 +488,46 @@ def handle_exception(error, logger): def handle_json_error(err, data, request_id, logger): try: - if isinstance(data, dict): # If data is already a dict + if isinstance(data, dict): description = data elif isinstance(data, ErrorResponse): description = data.dict() else: description = json.loads(data) - status_code = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found - http_status = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) - grpc_code = description.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) - details = description.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS, []) - description_message = description.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) - log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger = logger) + if ResponseField.ERROR in description: + error_obj = description.get(ResponseField.ERROR, {}) + status_code = error_obj.get(ResponseField.HTTP_CODE, HttpStatusCode.INTERNAL_SERVER_ERROR) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS, []) + description_message = error_obj.get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + elif ResponseField.RESPONSES in description: + responses = description.get(ResponseField.RESPONSES, []) + messages = [] + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + for resp in responses: + resp_status = resp.get(ResponseField.STATUS, HttpStatusCode.INTERNAL_SERVER_ERROR) + resp_body = resp.get(ResponseField.BODY, {}) + if isinstance(resp_status, int) and resp_status >= HttpStatusCode.BAD_REQUEST: + status_code = resp_status + error_msg = resp_body.get(ResponseField.ERROR) + if error_msg: + messages.append(str(error_msg)) + description_message = '; '.join(messages) if messages else SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + http_status = None + grpc_code = None + details = [] + else: + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + http_status = None + grpc_code = None + details = [] + description_message = SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + + log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger=logger) except json.JSONDecodeError: - log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger = logger) + log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger=logger) def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index bd8e63ec..bc50f210 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0.dev0+f7d26df' +SDK_VERSION = '2.0.2' diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 401bffe5..17ba96e2 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -1,6 +1,7 @@ OPTIONAL_TOKEN='token' PROTOCOL='https' SKY_META_DATA_HEADER='sky-metadata' +CTX_KEY_REGEX=r'^[a-zA-Z0-9_]+$' class SKYFLOW: SKYFLOW_ID = 'skyflowId' @@ -116,6 +117,7 @@ class ResponseField: TYPE = 'type' TOKENIZED_DATA = 'tokenized_data' SIGNED_TOKEN = 'signed_token' + RESPONSES = 'responses' class CredentialField: @@ -123,6 +125,8 @@ class CredentialField: CLIENT_ID = 'clientID' KEY_ID = 'keyID' TOKEN_URI = 'tokenURI' + TOKEN_URI_OPTION = 'token_uri' + CLIENT_NAME = 'clientName' CREDENTIALS_STRING = 'credentials_string' API_KEY = 'api_key' TOKEN = 'token' @@ -192,6 +196,7 @@ class DeidentifyFileRequestField: OUTPUT_OCR_TEXT = 'output_ocr_text' MASKING_METHOD = 'masking_method' PIXEL_DENSITY = 'pixel_density' + DENSITY = 'density' MAX_RESOLUTION = 'max_resolution' OUTPUT_PROCESSED_AUDIO = 'output_processed_audio' OUTPUT_TRANSCRIPTION = 'output_transcription' @@ -227,6 +232,7 @@ class DeidentifyField: ENTITY_UNQ_COUNTER = 'entity_unq_counter' ENTITY_UNIQUE_COUNTER = 'entity_unique_counter' ENTITY_ONLY = 'entity_only' + VAULT_TOKEN = 'vault_token' ENTITIES = 'entities' MAX_DAYS = 'max_days' MIN_DAYS = 'min_days' diff --git a/skyflow/utils/enums/detect_output_transcriptions.py b/skyflow/utils/enums/detect_output_transcriptions.py index 4e14f911..a398a3d8 100644 --- a/skyflow/utils/enums/detect_output_transcriptions.py +++ b/skyflow/utils/enums/detect_output_transcriptions.py @@ -4,4 +4,5 @@ class DetectOutputTranscriptions(Enum): DIARIZED_TRANSCRIPTION = "diarized_transcription" MEDICAL_DIARIZED_TRANSCRIPTION = "medical_diarized_transcription" MEDICAL_TRANSCRIPTION = "medical_transcription" - TRANSCRIPTION = "transcription" \ No newline at end of file + TRANSCRIPTION = "transcription" + PLAINTEXT_TRANSCRIPTION = "plaintext_transcription" \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 08d4905b..6cc2c811 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -42,32 +42,32 @@ def validate_required_field(logger, config, field_name, expected_type, empty_err if field_name not in config or not isinstance(field_value, expected_type): if field_name == ConfigField.VAULT_ID: - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value, logger) if field_name == ConfigField.CLUSTER_ID: - logger.error(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value, logger) if field_name == OptionField.CONNECTION_ID: - logger.error(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value, logger) if field_name == OptionField.CONNECTION_URL: - logger.error(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value) + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value, logger) raise SkyflowError(invalid_error, invalid_input_error_code) if isinstance(field_value, str) and not field_value.strip(): if field_name == ConfigField.VAULT_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value, logger) if field_name == ConfigField.CLUSTER_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value, logger) if field_name == OptionField.CONNECTION_ID: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value, logger) if field_name == OptionField.CONNECTION_URL: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value, logger) if field_name == CredentialField.PATH: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value, logger) if field_name == CredentialField.CREDENTIALS_STRING: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value, logger) if field_name == CredentialField.TOKEN: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value, logger) if field_name == CredentialField.API_KEY: - logger.error(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value, logger) raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: @@ -90,6 +90,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) elif len(key_present) > 1: error_message = ( @@ -97,6 +98,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) if CredentialField.ROLES in credentials: @@ -142,6 +144,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) if is_expired(credentials.get(CredentialField.TOKEN), logger): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value, logger) raise SkyflowError( SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value, @@ -160,25 +163,25 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) - if "token_uri" in credentials: - token_uri = credentials.get("token_uri") + if CredentialField.TOKEN_URI_OPTION in credentials: + token_uri = credentials.get(CredentialField.TOKEN_URI_OPTION) if ( token_uri is None or not isinstance(token_uri, str) or not is_valid_url(token_uri) ): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): - raise SkyflowError( SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) - - if log_level is None: - raise SkyflowError(SkyflowMessages.Error.EMPTY_LOG_LEVEL.value, invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.INVALID_LOG_LEVEL.value, logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) def validate_keys(logger, config, config_keys): for key in config.keys(): if key not in config_keys: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_KEY.value.format(key), logger) raise SkyflowError(SkyflowMessages.Error.INVALID_KEY.value.format(key), invalid_input_error_code) def validate_vault_config(logger, config): @@ -208,7 +211,7 @@ def validate_vault_config(logger, config): # Validate env (optional, should be one of LogLevel values) if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + log_error_log(SkyflowMessages.ErrorLogs.ENV_IS_REQUIRED.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) return True @@ -232,8 +235,10 @@ def validate_update_vault_config(logger, config): if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if ConfigField.CREDENTIALS in config and config.get(ConfigField.CREDENTIALS): - validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) + + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) return True @@ -255,8 +260,10 @@ def validate_connection_config(logger, config): SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" in config: - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) + + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) return True @@ -286,135 +293,164 @@ def validate_update_connection_config(logger, config): def validate_file_from_request(file_input: FileInput): if file_input is None: + log_error_log(SkyflowMessages.Error.INVALID_FILE_INPUT.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - + has_file = hasattr(file_input, FileUploadField.FILE) and file_input.file is not None has_file_path = hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None - + # Must provide exactly one of file or file_path if (has_file and has_file_path) or (not has_file and not has_file_path): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value, invalid_input_error_code) - + if has_file: file = file_input.file # Validate file object has required attributes - if not hasattr(file, FileUploadField.FILE_NAME) or not isinstance(file.name, str) or not file.name.strip(): + if not hasattr(file, FileUploadField.NAME) or not isinstance(file.name, str) or not file.name.strip(): + log_error_log(SkyflowMessages.Error.INVALID_FILE_TYPE.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) - + # Validate file name file_name, _ = os.path.splitext(os.path.basename(file.name)) if not file_name or not file_name.strip(): + log_error_log(SkyflowMessages.Error.INVALID_FILE_NAME.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_NAME.value, invalid_input_error_code) - + elif has_file_path: file_path = file_input.file_path if not isinstance(file_path, str) or not file_path.strip(): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) - + if not os.path.exists(file_path) or not os.path.isfile(file_path): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if not hasattr(request, FileUploadField.FILE) or request.file is None: + log_error_log(SkyflowMessages.Error.INVALID_FILE_INPUT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - + # Validate file input first validate_file_from_request(request.file) # Optional: entities if hasattr(request, DeidentifyFileRequestField.ENTITIES) and request.entities is not None: if not isinstance(request.entities, list): + log_error_log(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) if not all(isinstance(entity, DetectEntities) for entity in request.entities): + log_error_log(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) # Optional: allow_regex_list if hasattr(request, DeidentifyFileRequestField.ALLOW_REGEX_LIST) and request.allow_regex_list is not None: if not isinstance(request.allow_regex_list, list) or not all(isinstance(x, str) for x in request.allow_regex_list): + log_error_log(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Optional: restrict_regex_list if hasattr(request, DeidentifyFileRequestField.RESTRICT_REGEX_LIST) and request.restrict_regex_list is not None: if not isinstance(request.restrict_regex_list, list) or not all(isinstance(x, str) for x in request.restrict_regex_list): + log_error_log(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) # Optional: token_format if request.token_format is not None and not isinstance(request.token_format, TokenFormat): + log_error_log(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, invalid_input_error_code) # Optional: transformations if request.transformations is not None and not isinstance(request.transformations, Transformations): + log_error_log(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) # Optional: output_processed_image if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE) and request.output_processed_image is not None: if not isinstance(request.output_processed_image, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, invalid_input_error_code) # Optional: output_ocr_text if hasattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT) and request.output_ocr_text is not None: if not isinstance(request.output_ocr_text, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, invalid_input_error_code) # Optional: masking_method if hasattr(request, DeidentifyFileRequestField.MASKING_METHOD) and request.masking_method is not None: if not isinstance(request.masking_method, MaskingMethod): + log_error_log(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, invalid_input_error_code) # Optional: pixel_density if hasattr(request, DeidentifyFileRequestField.PIXEL_DENSITY) and request.pixel_density is not None: if not isinstance(request.pixel_density, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, invalid_input_error_code) # Optional: max_resolution if hasattr(request, DeidentifyFileRequestField.MAX_RESOLUTION) and request.max_resolution is not None: if not isinstance(request.max_resolution, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, invalid_input_error_code) # Optional: output_processed_audio if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO) and request.output_processed_audio is not None: if not isinstance(request.output_processed_audio, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, invalid_input_error_code) # Optional: output_transcription if hasattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION) and request.output_transcription is not None: if not isinstance(request.output_transcription, DetectOutputTranscriptions): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, invalid_input_error_code) # Optional: bleep if hasattr(request, DeidentifyFileRequestField.BLEEP) and request.bleep is not None: if not isinstance(request.bleep, Bleep): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, invalid_input_error_code) - + # Validate gain if request.bleep.gain is not None and not isinstance(request.bleep.gain, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_GAIN.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_GAIN.value, invalid_input_error_code) - + # Validate frequency if request.bleep.frequency is not None and not isinstance(request.bleep.frequency, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value, invalid_input_error_code) - + # Validate start_padding if request.bleep.start_padding is not None and not isinstance(request.bleep.start_padding, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value, invalid_input_error_code) - + # Validate stop_padding if request.bleep.stop_padding is not None and not isinstance(request.bleep.stop_padding, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, invalid_input_error_code) # Optional: output_directory if hasattr(request, DeidentifyFileRequestField.OUTPUT_DIRECTORY) and request.output_directory is not None: if not isinstance(request.output_directory, str): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, invalid_input_error_code) if not os.path.isdir(request.output_directory): + log_error_log(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), logger) raise SkyflowError(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), invalid_input_error_code) # Optional: wait_time if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_WAIT_TIME.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) if request.wait_time < 0 or request.wait_time > Detect.WAIT_TIME: + log_error_log(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, logger) raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): @@ -429,7 +465,7 @@ def validate_insert_request(logger, request): log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) - if not len(request.values): + if not request.values: log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code) @@ -439,7 +475,7 @@ def validate_insert_request(logger, request): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value(RequestOperation.INSERT), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) if request.homogeneous is not None and not isinstance(request.homogeneous, bool): @@ -471,7 +507,7 @@ def validate_insert_request(logger, request): logger=logger) if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value(RequestOperation.INSERT), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -481,14 +517,14 @@ def validate_insert_request(logger, request): raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE_STRICT: - if len(request.values) != len(request.tokens): + if not request.tokens or len(request.values) != len(request.tokens): log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) for v, t in zip(request.values, request.tokens): if set(v.keys()) != set(t.keys()): log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format(RequestOperation.INSERT), logger=logger) - raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.MISMATCH_OF_FIELDS_AND_TOKENS.value, invalid_input_error_code) def validate_delete_request(logger, request): if not isinstance(request.table, str): @@ -503,21 +539,21 @@ def validate_delete_request(logger, request): raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code) def validate_query_request(logger, request): - if not request.query: - log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger = logger) - raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not isinstance(request.query, str): query_type = str(type(request.query)) raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code) + if not request.query: + log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger=logger) + raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) + if not request.query.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) if not request.query.upper().startswith(SqlCommand.SELECT): command = request.query - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) def validate_get_request(logger, request): redaction_type = request.redaction_type @@ -565,13 +601,13 @@ def validate_get_request(logger, request): invalid_input_error_code) if offset is not None and not isinstance(offset, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value(type(offset)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value.format(type(offset)), invalid_input_error_code) if limit is not None and not isinstance(limit, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value(type(limit)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value.format(type(limit)), invalid_input_error_code) if download_url is not None and not isinstance(download_url, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value(type(download_url)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value.format(type(download_url)), invalid_input_error_code) if column_name is not None and (not isinstance(column_name, str) or not column_name.strip()): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) @@ -603,33 +639,30 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code) def validate_update_request(logger, request): - skyflow_id = "" + if not isinstance(request.data, dict): + raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value.format(type(request.data)), invalid_input_error_code) + + if not len(request.data.items()): + raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} - try: - skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) - except Exception: + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) + if skyflow_id is None: log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) - - if not skyflow_id.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger = logger) + elif not skyflow_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger=logger) if not isinstance(request.table, str): log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not isinstance(request.return_tokens, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code) - if not isinstance(request.data, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value(type(request.data)), invalid_input_error_code) - - if not len(request.data.items()): - raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) - if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value, invalid_input_error_code) @@ -667,9 +700,9 @@ def validate_detokenize_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code) if not isinstance(request.data, list): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - if not len(request.data): + if not request.data: log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format(RequestOperation.DETOKENIZE), logger = logger) log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.DETOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code) @@ -695,7 +728,7 @@ def validate_tokenize_request(logger, request): if not isinstance(parameters, list): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(parameters)), invalid_input_error_code) - if not len(parameters): + if not parameters: raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value, invalid_input_error_code) for i, param in enumerate(parameters): @@ -728,9 +761,7 @@ def validate_file_upload_request(logger, request): # Skyflow ID skyflow_id = getattr(request, FileUploadField.SKYFLOW_ID, None) - if skyflow_id is None: - raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) - elif skyflow_id.strip() == "": + if skyflow_id is not None and skyflow_id.strip() == "": raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format(RequestOperation.FILE_UPLOAD), invalid_input_error_code) # Column Name @@ -797,46 +828,57 @@ def validate_invoke_connection_params(logger, query_params, path_params): except TypeError: raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code) -def validate_deidentify_text_request(self, request: DeidentifyTextRequest): +def validate_deidentify_text_request(logger, request: DeidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): + log_error_log(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, invalid_input_error_code) # Validate entities if present if request.entities is not None and not isinstance(request.entities, list): + log_error_log(SkyflowMessages.Error.INVALID_ENTITIES_IN_DEIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENTITIES_IN_DEIDENTIFY.value, invalid_input_error_code) # Validate allowed_regex_list if present if request.allow_regex_list is not None and not isinstance(request.allow_regex_list, list): + log_error_log(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Validate restricted_regex_list if present if request.restrict_regex_list is not None and not isinstance(request.restrict_regex_list, list): + log_error_log(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) # Validate token_format if present if request.token_format is not None and not isinstance(request.token_format, TokenFormat): + log_error_log(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, invalid_input_error_code) # Validate transformations if present if request.transformations is not None and not isinstance(request.transformations, Transformations): + log_error_log(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) -def validate_reidentify_text_request(self, request: ReidentifyTextRequest): +def validate_reidentify_text_request(logger, request: ReidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): + log_error_log(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, invalid_input_error_code) # Validate redacted_entities if present if request.redacted_entities is not None and not isinstance(request.redacted_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) # Validate masked_entities if present if request.masked_entities is not None and not isinstance(request.masked_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) # Validate plain_text_entities if present if request.plain_text_entities is not None and not isinstance(request.plain_text_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) -def validate_get_detect_run_request(self, request: GetDetectRunRequest): - if not request.run_id or not isinstance(request.run_id, str) or not request.run_id.strip(): +def validate_get_detect_run_request(logger, request: GetDetectRunRequest): + if request.run_id is None or not isinstance(request.run_id, str) or not request.run_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_RUN_ID.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RUN_ID.value, invalid_input_error_code) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index c64e8c6a..8023646c 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -16,6 +16,9 @@ def __init__(self, config): self.__logger = None self.__is_config_updated = False self.__bearer_token = None + self.__credentials = None + self.__vault_url = None + self.__is_static_token = None def set_common_skyflow_credentials(self, credentials): self.__common_skyflow_credentials = credentials @@ -25,16 +28,27 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger = self.__logger) - token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), - self.__config.get(ConfigField.ENV), - self.__config.get(ConfigField.VAULT_ID), - logger = self.__logger) - self.initialize_api_client(vault_url, token) - - def initialize_api_client(self, vault_url, token): - self.__api_client = Skyflow(base_url=vault_url, token=token) + if self.__api_client is not None and not self.__is_config_updated: + if self.__is_static_token: + return + if self.__bearer_token is not None and not is_expired(self.__bearer_token): + return + + needs_reinit = self.__api_client is None or self.__is_config_updated + if needs_reinit: + self.__credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger=self.__logger) + self.__vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), + logger=self.__logger) + self.__is_static_token = CredentialField.TOKEN in self.__credentials or CredentialField.API_KEY in self.__credentials + bearer_token = self.get_bearer_token(self.__credentials) + if needs_reinit: + self.initialize_api_client(self.__vault_url, bearer_token) + + def initialize_api_client(self, vault_url, bearer_token): + token_provider = lambda: self.__bearer_token if self.__bearer_token is not None else bearer_token # noqa: E731 + self.__api_client = Skyflow(base_url=vault_url, token=token_provider) def get_records_api(self): return self.__api_client.records @@ -64,14 +78,13 @@ def get_bearer_token(self, credentials): OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), OptionField.CTX: self.__config.get(OptionField.CTX) } - if "token_uri" in credentials and credentials.get("token_uri"): - options["token_uri"] = credentials.get("token_uri") + if CredentialField.TOKEN_URI_OPTION in credentials and credentials.get(CredentialField.TOKEN_URI_OPTION): + options[CredentialField.TOKEN_URI_OPTION] = credentials.get(CredentialField.TOKEN_URI_OPTION) - if self.__bearer_token is None or self.__is_config_updated: + if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token): if CredentialField.PATH in credentials: - path = credentials.get(CredentialField.PATH) self.__bearer_token, _ = generate_bearer_token( - path, + credentials.get(CredentialField.PATH), options, self.__logger ) @@ -87,10 +100,6 @@ def get_bearer_token(self, credentials): else: log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger) - if is_expired(self.__bearer_token): - self.__is_config_updated = True - raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - return self.__bearer_token def update_config(self, config): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 76dbfaeb..2ce0c104 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,7 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest -from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader, OptionField, ConfigField from skyflow.utils import get_credentials @@ -16,11 +16,11 @@ def __init__(self, vault_client): def invoke(self, request: InvokeConnectionRequest): log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) config = self.__vault_client.get_config() - connection_url = config.get("connection_url") + connection_url = config.get(OptionField.CONNECTION_URL) invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) - credentials = get_credentials(config.get("credentials"), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) + credentials = get_credentials(config.get(ConfigField.CREDENTIALS), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) bearer_token = self.__vault_client.get_bearer_token(credentials) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index c6ef2fb1..f12b6215 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,8 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, - FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField) +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField, Detect as DetectConstants) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -30,7 +30,7 @@ def __get_headers(self): } return headers - def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: + def __build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: deidentify_text_body = {} parsed_entity_types = request.entities @@ -43,7 +43,7 @@ def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[ return deidentify_text_body - def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: + def __build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: parsed_format = Format( redacted=request.redacted_entities, masked=request.masked_entities, @@ -57,13 +57,13 @@ def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[ def _get_file_extension(self, filename: str): return filename.split('.')[-1].lower() if '.' in filename else '' - def __poll_for_processed_file(self, run_id, max_wait_time=64): - max_wait_time = 64 if max_wait_time is None else max_wait_time + def __poll_for_processed_file(self, run_id, max_wait_time=None): + max_wait_time = DetectConstants.WAIT_TIME if max_wait_time is None else max_wait_time files_api = self.__vault_client.get_detect_file_api().with_raw_response current_wait_time = 1 # Start with 1 second try: while True: - response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data + response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data status = response.status if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: @@ -80,7 +80,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: - raise e + handle_exception(e, self.__vault_client.get_logger()) def __save_deidentify_file_response_output(self, response: DetectRunsResponse, output_directory: str, original_file_name: str, name_without_ext: str): if not response or not hasattr(response, DeidentifyField.OUTPUT) or not response.output or not output_directory: @@ -94,6 +94,7 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o base_original_filename = os.path.basename(original_file_name) base_name_without_ext = os.path.splitext(base_original_filename)[0] + real_output_dir = os.path.realpath(output_directory) for idx, output in enumerate(output_list): try: @@ -105,14 +106,25 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o continue decoded_data = base64.b64decode(processed_file) - + + # Sanitize extension from API response to prevent path traversal (CWE-22). + # Avoid os.path.basename here to keep basename mock-free in tests. + safe_ext = None + if processed_file_extension: + raw_ext = str(processed_file_extension).replace('\\', '/').split('/')[-1].lstrip('.') + safe_ext = ''.join(c for c in raw_ext if c.isalnum() or c in ('-', '_')) or 'bin' + if idx == 0 or processed_file_type == DeidentifyField.REDACTED_FILE: output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) - if processed_file_extension: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") + if safe_ext: + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext}") else: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") - + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext or 'bin'}") + + if not os.path.realpath(output_file_name).startswith(real_output_dir + os.sep): + log_error_log(SkyflowMessages.ErrorLogs.SAVING_DEIDENTIFY_FILE_FAILED.value, self.__vault_client.get_logger()) + continue + with open(output_file_name, 'wb') as f: f.write(decoded_data) except Exception as e: @@ -166,16 +178,16 @@ def output_to_dict_list(output): extension = first_output.get(DeidentifyField.EXTENSION, None) if base64_string is not None: - file_bytes = base64.b64decode(base64_string) - file_obj = io.BytesIO(file_bytes) - file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE + file_bytes = base64.b64decode(base64_string) + file_obj = io.BytesIO(file_bytes) + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get(DeidentifyField.TYPE, DetectStatus.UNKNOWN), + type=first_output.get(DeidentifyField.TYPE, None), extension=extension, word_count=word_count, char_count=char_count, @@ -195,6 +207,7 @@ def __get_token_format(self, request): DeidentifyField.DEFAULT: getattr(request.token_format, DeidentifyField.DEFAULT, None), DeidentifyField.ENTITY_UNQ_COUNTER: getattr(request.token_format, DeidentifyField.ENTITY_UNIQUE_COUNTER, None), DeidentifyField.ENTITY_ONLY: getattr(request.token_format, DeidentifyField.ENTITY_ONLY, None), + DeidentifyField.VAULT_TOKEN: getattr(request.token_format, DeidentifyField.VAULT_TOKEN, None) } def __get_transformations(self, request): @@ -217,7 +230,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - deidentify_text_body = self.___build_deidentify_text_body(request) + deidentify_text_body = self.__build_deidentify_text_body(request) try: log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) @@ -229,7 +242,7 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo restrict_regex=deidentify_text_body[DeidentifyField.RESTRICT_REGEX], token_type=deidentify_text_body[DeidentifyField.TOKEN_TYPE], transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) deidentify_text_response = parse_deidentify_text_response(api_response) log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -245,7 +258,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - reidentify_text_body = self.___build_reidentify_text_body(request) + reidentify_text_body = self.__build_reidentify_text_body(request) try: log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) @@ -253,7 +266,7 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo vault_id=self.__vault_client.get_vault_id(), text=reidentify_text_body[DeidentifyField.TEXT], format=reidentify_text_body[DeidentifyField.FORMAT], - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) reidentify_text_response = parse_reidentify_text_response(api_response) log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) @@ -265,14 +278,16 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo def __get_file_from_request(self, request: DeidentifyFileRequest): file_input = request.file - - # Check for file + if hasattr(file_input, FileUploadField.FILE) and file_input.file is not None: return file_input.file - - # Check for file_path if file is not provided + if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: - return open(file_input.file_path, 'rb') + with open(file_input.file_path, 'rb') as f: + content = f.read() + bio = io.BytesIO(content) + bio.name = file_input.file_path + return bio def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) @@ -297,12 +312,13 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio + bleep = request.bleep api_kwargs = { OptionField.VAULT_ID: self.__vault_client.get_vault_id(), DeidentifyField.FILE: req_file, @@ -313,11 +329,11 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION: getattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION, None), DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO, None), - DeidentifyField.BLEEP_GAIN: getattr(request, DeidentifyFileRequestField.BLEEP, None).gain if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_FREQUENCY: getattr(request, DeidentifyFileRequestField.BLEEP, None).frequency if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_START_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).start_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.BLEEP_STOP_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).stop_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.BLEEP_GAIN: bleep.gain if bleep is not None else None, + DeidentifyField.BLEEP_FREQUENCY: bleep.frequency if bleep is not None else None, + DeidentifyField.BLEEP_START_PADDING: bleep.start_padding if bleep is not None else None, + DeidentifyField.BLEEP_STOP_PADDING: bleep.stop_padding if bleep is not None else None, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension == FileExtension.PDF: @@ -331,8 +347,8 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyFileRequestField.MAX_RESOLUTION: getattr(request, DeidentifyFileRequestField.MAX_RESOLUTION, None), - DeidentifyFileRequestField.PIXEL_DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyFileRequestField.DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: @@ -348,7 +364,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyFileRequestField.MASKING_METHOD: getattr(request, DeidentifyFileRequestField.MASKING_METHOD, None), DeidentifyFileRequestField.OUTPUT_OCR_TEXT: getattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT, None), DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE, None), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: @@ -361,7 +377,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: @@ -374,7 +390,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: @@ -387,7 +403,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } elif file_extension in [FileExtension.JSON, FileExtension.XML]: @@ -401,7 +417,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } else: @@ -415,7 +431,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): DeidentifyField.ALLOW_REGEX: request.allow_regex_list, DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), - DeidentifyField.REQUEST_OPTIONS: self.__get_headers() + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) @@ -424,7 +440,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == DetectStatus.SUCCESS: + if request.output_directory and processed_response.status == DetectStatus.SUCCESS and file_name: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -449,10 +465,10 @@ def get_detect_run(self, request: GetDetectRunRequest): response = files_api.get_run( run_id, vault_id=self.__vault_client.get_vault_id(), - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) if response.data.status == DetectStatus.IN_PROGRESS: - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)) + parsed_response = DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 856a1961..7d51ee83 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -89,10 +89,7 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return None def __get_headers(self): - headers = { - SKY_META_DATA_HEADER: json.dumps(get_metrics()) - } - return headers + return {SKY_META_DATA_HEADER: json.dumps(get_metrics())} def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.VALIDATE_INSERT_REQUEST.value, self.__vault_client.get_logger()) @@ -106,11 +103,11 @@ def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.INSERT_TRIGGERED.value, self.__vault_client.get_logger()) if request.continue_on_error: api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(), - records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options=self.__get_headers()) + records=insert_body, continue_on_error=request.continue_on_error, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) else: api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), - request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) + request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options={'additional_headers': self.__get_headers()}) insert_response = parse_insert_response(api_response, request.continue_on_error) log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) @@ -138,7 +135,7 @@ def update(self, request: UpdateRequest): record=record, tokenization=request.return_tokens, byot=request.token_mode.value, - request_options = self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger()) update_response = parse_update_record_response(api_response) @@ -159,7 +156,7 @@ def delete(self, request: DeleteRequest): self.__vault_client.get_vault_id(), request.table, skyflow_ids=request.ids, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger()) delete_response = parse_delete_response(api_response) @@ -189,7 +186,7 @@ def get(self, request: GetRequest): download_url=request.download_url, column_name=request.column_name, column_values=request.column_values, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger()) get_response = parse_get_response(api_response) @@ -209,7 +206,7 @@ def query(self, request: QueryRequest): api_response = query_api.query_service_execute_query( self.__vault_client.get_vault_id(), query=request.query, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger()) query_response = parse_query_response(api_response) @@ -237,7 +234,7 @@ def detokenize(self, request: DetokenizeRequest): self.__vault_client.get_vault_id(), detokenization_parameters=tokens_list, continue_on_error = request.continue_on_error, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) detokenize_response = parse_detokenize_response(api_response) @@ -262,7 +259,7 @@ def tokenize(self, request: TokenizeRequest): api_response = tokens_api.record_service_tokenize( self.__vault_client.get_vault_id(), tokenization_parameters=records_list, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) tokenize_response = parse_tokenize_response(api_response) log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) @@ -285,7 +282,7 @@ def upload_file(self, request: FileUploadRequest): file=self.__get_file_for_file_upload(request), skyflow_id=request.skyflow_id, return_file_metadata= False, - request_options=self.__get_headers() + request_options={'additional_headers': self.__get_headers()} ) log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/data/_file_upload_request.py b/skyflow/vault/data/_file_upload_request.py index d1bd4a44..6a632b67 100644 --- a/skyflow/vault/data/_file_upload_request.py +++ b/skyflow/vault/data/_file_upload_request.py @@ -1,14 +1,28 @@ -from typing import BinaryIO +import warnings +from typing import BinaryIO, Optional + +from skyflow.utils import SkyflowMessages + class FileUploadRequest: def __init__(self, table: str, - skyflow_id: str, - column_name: str, - file_path: str= None, - base64: str= None, - file_object: BinaryIO= None, - file_name: str= None): + *args, + column_name: Optional[str] = None, + skyflow_id: Optional[str] = None, + file_path: Optional[str] = None, + base64: Optional[str] = None, + file_object: Optional[BinaryIO] = None, + file_name: Optional[str] = None): + if args: + warnings.warn( + SkyflowMessages.Warning.FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED.value, + DeprecationWarning, + stacklevel=2, + ) + # Old positional order was: (table, skyflow_id, column_name, ...) + skyflow_id = args[0] if args else skyflow_id + column_name = args[1] if len(args) > 1 else column_name self.table = table self.skyflow_id = skyflow_id self.column_name = column_name diff --git a/skyflow/vault/data/_get_response.py b/skyflow/vault/data/_get_response.py index cf1b0805..a1640254 100644 --- a/skyflow/vault/data/_get_response.py +++ b/skyflow/vault/data/_get_response.py @@ -1,6 +1,6 @@ class GetResponse: def __init__(self, data=None, errors = None): - self.data = data if data else [] + self.data = data if data is not None else [] self.errors = errors def __repr__(self): diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index b340e21c..e56f2113 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -1,22 +1,24 @@ import io +from typing import Optional from skyflow.vault.detect._file import File class DeidentifyFileResponse: def __init__( self, - file_base64: str = None, - file: io.BytesIO = None, - type: str = None, - extension: str = None, - word_count: int = None, - char_count: int = None, - size_in_kb: float = None, - duration_in_seconds: float = None, - page_count: int = None, - slide_count: int = None, - entities: list = None, # list of dicts with keys 'file' and 'extension' - run_id: str = None, - status: str = None, + file_base64: Optional[str] = None, + file: Optional[io.BytesIO] = None, + type: Optional[str] = None, + extension: Optional[str] = None, + word_count: Optional[int] = None, + char_count: Optional[int] = None, + size_in_kb: Optional[float] = None, + duration_in_seconds: Optional[float] = None, + page_count: Optional[int] = None, + slide_count: Optional[int] = None, + entities: Optional[list] = None, + run_id: Optional[str] = None, + status: Optional[str] = None, + errors: Optional[list] = None, ): self.file_base64 = file_base64 self.file = File(file) if file else None @@ -31,6 +33,7 @@ def __init__( self.entities = entities if entities is not None else [] self.run_id = run_id self.status = status + self.errors = errors def __repr__(self): return ( @@ -40,7 +43,7 @@ def __repr__(self): f"char_count={self.char_count!r}, size_in_kb={self.size_in_kb!r}, " f"duration_in_seconds={self.duration_in_seconds!r}, page_count={self.page_count!r}, " f"slide_count={self.slide_count!r}, entities={self.entities!r}, " - f"run_id={self.run_id!r}, status={self.status!r})" + f"run_id={self.run_id!r}, status={self.status!r}, errors={self.errors!r})" ) def __str__(self): diff --git a/skyflow/vault/detect/_deidentify_text_response.py b/skyflow/vault/detect/_deidentify_text_response.py index cdb6632e..227b43bc 100644 --- a/skyflow/vault/detect/_deidentify_text_response.py +++ b/skyflow/vault/detect/_deidentify_text_response.py @@ -1,19 +1,21 @@ -from typing import List +from typing import List, Optional from ._entity_info import EntityInfo class DeidentifyTextResponse: - def __init__(self, + def __init__(self, processed_text: str, - entities: List[EntityInfo], + entities: List[EntityInfo], word_count: int, - char_count: int): + char_count: int, + errors: Optional[list] = None): self.processed_text = processed_text self.entities = entities self.word_count = word_count self.char_count = char_count + self.errors = errors def __repr__(self): - return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count})" + return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count}, errors={self.errors})" def __str__(self): return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/detect/_reidentify_text_response.py b/skyflow/vault/detect/_reidentify_text_response.py index 50c3876d..73ad3f5d 100644 --- a/skyflow/vault/detect/_reidentify_text_response.py +++ b/skyflow/vault/detect/_reidentify_text_response.py @@ -1,9 +1,12 @@ +from typing import Optional + class ReidentifyTextResponse: - def __init__(self, processed_text: str): + def __init__(self, processed_text: str, errors: Optional[list] = None): self.processed_text = processed_text + self.errors = errors def __repr__(self) -> str: - return f"ReidentifyTextResponse(processed_text='{self.processed_text}')" + return f"ReidentifyTextResponse(processed_text='{self.processed_text}', errors={self.errors})" def __str__(self) -> str: return self.__repr__() \ No newline at end of file diff --git a/tests/client/test_skyflow.py b/tests/client/test_skyflow.py index 3e3681bb..1122448a 100644 --- a/tests/client/test_skyflow.py +++ b/tests/client/test_skyflow.py @@ -1,42 +1,43 @@ import unittest -from unittest.mock import patch +import warnings +from unittest.mock import patch, Mock from skyflow import LogLevel, Env from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow import Skyflow +from skyflow.vault.client.client import VaultClient +from skyflow.vault.data import FileUploadRequest VALID_VAULT_CONFIG = { "vault_id": "VAULT_ID", "cluster_id": "CLUSTER_ID", "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_VAULT_CONFIG = { "cluster_id": "CLUSTER_ID", # Missing vault_id "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } VALID_CONNECTION_CONFIG = { "connection_id": "CONNECTION_ID", "connection_url": "https://CONNECTION_URL", - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_CONNECTION_CONFIG = { "connection_url": "https://CONNECTION_URL", # Missing connection_id - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } -VALID_CREDENTIALS = { - "path": "/path/to/valid_credentials.json" -} +VALID_CREDENTIALS = {"path": "/path/to/valid_credentials.json"} -class TestSkyflow(unittest.TestCase): +class TestSkyflow(unittest.TestCase): def setUp(self): self.builder = Skyflow.builder() @@ -49,8 +50,10 @@ def test_add_already_exists_vault_config(self): builder = self.builder.add_vault_config(VALID_VAULT_CONFIG) with self.assertRaises(SkyflowError) as context: builder.add_vault_config(VALID_VAULT_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id"))) - + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id")), + ) def test_add_vault_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -61,11 +64,11 @@ def test_add_vault_config_invalid(self): def test_remove_vault_config_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - result = self.builder.remove_vault_config(VALID_VAULT_CONFIG['vault_id']) + result = self.builder.remove_vault_config(VALID_VAULT_CONFIG["vault_id"]) - self.assertNotIn(VALID_VAULT_CONFIG['vault_id'], self.builder._Builder__vault_configs) + self.assertNotIn(VALID_VAULT_CONFIG["vault_id"], self.builder._Builder__vault_configs) - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_vault_config_invalid(self, mock_log_error): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -73,8 +76,7 @@ def test_remove_vault_config_invalid(self, mock_log_error): self.builder.remove_vault_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_VAULT_ID.value) - - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_vault_config_valid(self, mock_validate): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -94,7 +96,7 @@ def test_get_vault(self): def test_get_vault_with_vault_id_none(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - vault = self.builder.get_vault_config(None) + vault = self.builder.get_vault_config(None) config = vault.get("vault_client").get_config() self.assertEqual(self.builder._Builder__vault_list[0], config) @@ -107,19 +109,23 @@ def test_get_vault_with_empty_vault_list_when_vault_id_is_none_raises_error(self def test_get_vault_with_invalid_vault_id_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_vault_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_vault_with_invalid_vault_id_and_non_empty_list_raises_error(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_vault_id') - - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id")) + self.builder.get_vault_config("invalid_vault_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id"), + ) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_build_calls_validate_vault_config(self, mock_validate_vault_config): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -143,7 +149,9 @@ def test_add_already_exists_connection_config(self): with self.assertRaises(SkyflowError) as context: builder.add_connection_config(VALID_CONNECTION_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id)) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id) + ) def test_add_connection_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -158,8 +166,7 @@ def test_remove_connection_config_valid(self): self.assertNotIn(VALID_CONNECTION_CONFIG.get("connection_id"), self.builder._Builder__connection_configs) - - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_connection_config_invalid(self, mock_log_error): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -167,7 +174,7 @@ def test_remove_connection_config_invalid(self, mock_log_error): self.builder.remove_connection_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONNECTION_ID.value) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_connection_config_valid(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -194,16 +201,21 @@ def test_get_connection_config_with_connection_id_none(self): def test_get_connection_with_empty_connection_list_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_connection_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_connection_with_invalid_connection_id_raises_error(self): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_connection_id') + self.builder.get_connection_config("invalid_connection_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_connection_id')) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_connection_id"), + ) def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(self): self.builder.build() @@ -212,13 +224,12 @@ def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(s self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS.value) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_build_calls_validate_connection_config(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() mock_validate.assert_called_once_with(self.builder._Builder__logger, VALID_CONNECTION_CONFIG) - def test_build_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG).add_connection_config(VALID_CONNECTION_CONFIG) client = self.builder.build() @@ -236,30 +247,31 @@ def test_invalid_credentials(self): self.assertEqual(VALID_CREDENTIALS, self.builder._Builder__skyflow_credentials) self.assertEqual(builder, self.builder) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_skyflow_client_add_remove_vault_config(self, mock_validate_vault_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['vault_id'] = "VAULT_ID" + new_config["vault_id"] = "VAULT_ID" skyflow_client.add_vault_config(new_config) - assert mock_validate_vault_config.call_count == 2 + self.assertEqual(mock_validate_vault_config.call_count, 2) - self.assertEqual("VAULT_ID", - skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id")) + self.assertEqual("VAULT_ID", skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id")) - skyflow_client.remove_vault_config(new_config['vault_id']) + skyflow_client.remove_vault_config(new_config["vault_id"]) with self.assertRaises(SkyflowError) as context: - skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id") + skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format( - new_config['vault_id'])) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(new_config["vault_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['env'] = Env.SANDBOX + new_config["env"] = Env.SANDBOX skyflow_client.update_vault_config(new_config) mock_update_config.assert_called_once() @@ -267,29 +279,33 @@ def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): self.assertEqual(VALID_VAULT_CONFIG.get("vault_id"), vault.get("vault_id")) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_skyflow_client_add_remove_connection_config(self, mock_validate_connection_config): skyflow_client = self.builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_id'] = "CONNECTION_ID" + new_config["connection_id"] = "CONNECTION_ID" skyflow_client.add_connection_config(new_config) - assert mock_validate_connection_config.call_count == 2 - self.assertEqual("CONNECTION_ID", skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id")) + self.assertEqual(mock_validate_connection_config.call_count, 2) + self.assertEqual( + "CONNECTION_ID", skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + ) skyflow_client.remove_connection_config("CONNECTION_ID") with self.assertRaises(SkyflowError) as context: - skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id") - - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config['connection_id'])) + skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config["connection_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_connection_config(self, mock_update_config): builder = self.builder skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_url'] = 'updated_url' + new_config["connection_url"] = "updated_url" skyflow_client.update_connection_config(new_config) mock_update_config.assert_called_once() @@ -305,28 +321,174 @@ def test_skyflow_add_and_update_skyflow_credentials(self): self.assertEqual(VALID_CREDENTIALS, builder._Builder__skyflow_credentials) new_credentials = VALID_CREDENTIALS.copy() - new_credentials['path'] = 'path/to/new_credentials' + new_credentials["path"] = "path/to/new_credentials" skyflow_client.update_skyflow_credentials(new_credentials) self.assertEqual(new_credentials, builder._Builder__skyflow_credentials) - def test_skyflow_add_and_update_log_level(self): builder = self.builder - skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() + skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() skyflow_client.set_log_level(LogLevel.INFO) self.assertEqual(LogLevel.INFO, builder._Builder__log_level) - skyflow_client.update_log_level(LogLevel.ERROR) - self.assertEqual(LogLevel.ERROR, builder._Builder__log_level) - - - @patch('skyflow.client.Skyflow.Builder.get_vault_config') + @patch("skyflow.client.Skyflow.Builder.get_vault_config") def test_skyflow_vault_and_connection_method(self, mock_get_vault_config): builder = self.builder - skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + skyflow_client = ( + builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + ) skyflow_client.vault() skyflow_client.connection() - mock_get_vault_config.assert_called_once() \ No newline at end of file + mock_get_vault_config.assert_called_once() + + def test_detect_returns_detect_controller(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + from skyflow.vault.controller import Detect + result = skyflow_client.detect() + self.assertIsInstance(result, Detect) + + def test_detect_with_explicit_vault_id(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + from skyflow.vault.controller import Detect + result = skyflow_client.detect(VALID_VAULT_CONFIG["vault_id"]) + self.assertIsInstance(result, Detect) + + def test_detect_with_invalid_vault_id_raises_error(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + with self.assertRaises(SkyflowError) as context: + skyflow_client.detect("invalid_vault_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id"), + ) + + @patch("skyflow.vault.client.client.VaultClient.update_config") + def test_update_vault_config_with_invalid_vault_id_raises_error(self, _mock): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + invalid_config = VALID_VAULT_CONFIG.copy() + invalid_config["vault_id"] = "non_existent_vault_id" + with self.assertRaises(SkyflowError) as context: + skyflow_client.update_vault_config(invalid_config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("non_existent_vault_id"), + ) + + @patch("skyflow.vault.client.client.VaultClient.update_config") + def test_update_connection_config_with_invalid_connection_id_raises_error(self, _mock): + skyflow_client = self.builder.add_connection_config(VALID_CONNECTION_CONFIG).build() + invalid_config = VALID_CONNECTION_CONFIG.copy() + invalid_config["connection_id"] = "non_existent_connection_id" + with self.assertRaises(SkyflowError) as context: + skyflow_client.update_connection_config(invalid_config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("non_existent_connection_id"), + ) + + +class TestVaultClient(unittest.TestCase): + def _make_client(self): + client = VaultClient({"vault_id": "test_vault"}) + client._VaultClient__api_client = Mock() + return client + + def test_get_detect_text_api_returns_strings(self): + client = self._make_client() + result = client.get_detect_text_api() + self.assertEqual(result, client._VaultClient__api_client.strings) + + def test_get_detect_file_api_returns_files(self): + client = self._make_client() + result = client.get_detect_file_api() + self.assertEqual(result, client._VaultClient__api_client.files) + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.is_expired", return_value=True) + def test_get_bearer_token_passes_token_uri_option(self, _mock_expired, mock_gen): + mock_gen.return_value = ("test_token", "bearer") + client = VaultClient({"vault_id": "test_vault"}) + credentials = { + "credentials_string": '{"clientID":"id","privateKey":"pk","keyID":"kid","tokenURI":"https://token.uri"}', + "token_uri": "https://custom-token-uri.com/token", + } + client.get_bearer_token(credentials) + options_passed = mock_gen.call_args[0][1] + self.assertIn("token_uri", options_passed) + self.assertEqual(options_passed["token_uri"], "https://custom-token-uri.com/token") + + +class TestUpdateLogLevelDeprecation(unittest.TestCase): + def _build_client(self): + return Skyflow.builder().add_vault_config(VALID_VAULT_CONFIG).build() + + def test_update_log_level_emits_deprecation_warning(self): + client = self._build_client() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client.update_log_level(LogLevel.INFO) + deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] + self.assertGreaterEqual(len(deprecation_warnings), 1) + self.assertTrue(any("set_log_level" in str(w.message) for w in deprecation_warnings)) + + def test_update_log_level_warning_points_at_caller(self): + client = self._build_client() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client.update_log_level(LogLevel.INFO) + self.assertEqual(caught[0].filename, __file__) + + def test_update_log_level_delegates_to_set_log_level(self): + client = self._build_client() + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + client.update_log_level(LogLevel.INFO) + self.assertEqual(client.get_log_level(), LogLevel.INFO) + + +class TestFileUploadRequestDeprecation(unittest.TestCase): + def test_keyword_args_no_warning(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + req = FileUploadRequest( + table="table", + column_name="col", + skyflow_id="sky123", + ) + self.assertEqual(len(caught), 0) + self.assertEqual(req.table, "table") + self.assertEqual(req.column_name, "col") + self.assertEqual(req.skyflow_id, "sky123") + + def test_old_positional_order_emits_deprecation_warning(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + req = FileUploadRequest("table", "sky123", "col") + self.assertEqual(len(caught), 1) + self.assertTrue(issubclass(caught[0].category, DeprecationWarning)) + self.assertIn("FileUploadRequest", str(caught[0].message)) + + def test_old_positional_order_remaps_args_correctly(self): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + req = FileUploadRequest("table", "sky123", "col") + self.assertEqual(req.skyflow_id, "sky123") + self.assertEqual(req.column_name, "col") + + def test_old_positional_order_warning_points_at_caller(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + FileUploadRequest("table", "sky123", "col") + self.assertEqual(caught[0].filename, __file__) + + def test_single_positional_arg_emits_warning_and_sets_skyflow_id(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + req = FileUploadRequest("table", "sky123") + self.assertEqual(len(caught), 1) + self.assertTrue(issubclass(caught[0].category, DeprecationWarning)) + self.assertEqual(req.skyflow_id, "sky123") + self.assertIsNone(req.column_name) diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index ca82527a..505a7261 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -5,35 +5,57 @@ from unittest.mock import patch import os from skyflow.error import SkyflowError -from skyflow.service_account import is_expired, generate_bearer_token, \ - generate_bearer_token_from_creds +from skyflow.service_account import is_expired, generate_bearer_token, generate_bearer_token_from_creds from skyflow.utils import SkyflowMessages -from skyflow.service_account._utils import get_service_account_token, get_signed_jwt, generate_signed_data_tokens, get_signed_data_token_response_object, generate_signed_data_tokens_from_creds +from skyflow.service_account._utils import ( + get_service_account_token, + get_signed_jwt, + generate_signed_data_tokens, + get_signed_data_token_response_object, + generate_signed_data_tokens_from_creds, + _validate_and_resolve_ctx, + _normalize_credentials, + get_signed_tokens, +) creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) VALID_CREDENTIALS_STRING = json.dumps(credentials) -CREDENTIALS_WITHOUT_CLIENT_ID = { - 'privateKey': 'private_key' -} +CREDENTIALS_WITHOUT_CLIENT_ID = {"privateKey": "private_key"} -CREDENTIALS_WITHOUT_KEY_ID = { - 'privateKey': 'private_key', - 'clientID': 'client_id' -} +CREDENTIALS_WITHOUT_KEY_ID = {"privateKey": "private_key", "clientID": "client_id"} -CREDENTIALS_WITHOUT_TOKEN_URI = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id' -} +CREDENTIALS_WITHOUT_TOKEN_URI = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id"} VALID_SERVICE_ACCOUNT_CREDS = credentials +# Snake-case version of the real credentials (keys remapped to snake_case) +SNAKE_CASE_CREDS = { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], +} + +SNAKE_CASE_CREDS_STRING = json.dumps( + { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], + } +) + + class TestServiceAccountUtils(unittest.TestCase): + # ── is_expired ──────────────────────────────────────────────────────────── + + def test_is_expired_none_token(self): + self.assertTrue(is_expired(None)) + def test_is_expired_empty_token(self): self.assertTrue(is_expired("")) @@ -44,7 +66,7 @@ def test_is_expired_non_expired_token(self): def test_is_expired_expired_token(self): past_time = time.time() - 1000 - token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") + token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) @patch("skyflow.utils.logger._log_helpers.log_error_log") @@ -53,6 +75,8 @@ def test_is_expired_general_exception(self, mock_jwt_decode, mock_log_error): token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) + # ── generate_bearer_token ───────────────────────────────────────────────── + @patch("builtins.open", side_effect=FileNotFoundError) def test_generate_bearer_token_invalid_file_path(self, mock_open): with self.assertRaises(SkyflowError) as context: @@ -72,6 +96,8 @@ def test_generate_bearer_token_valid_file_path(self, mock_generate_bearer_token) generate_bearer_token(creds_path) mock_generate_bearer_token.assert_called_once() + # ── generate_bearer_token_from_creds ────────────────────────────────────── + @patch("skyflow.service_account._utils.get_service_account_token") def test_generate_bearer_token_from_creds_with_valid_json_string(self, mock_generate_bearer_token): generate_bearer_token_from_creds(VALID_CREDENTIALS_STRING) @@ -82,10 +108,11 @@ def test_generate_bearer_token_from_creds_invalid_json(self): generate_bearer_token_from_creds("invalid_json") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + # ── get_service_account_token ───────────────────────────────────────────── + def test_get_service_account_token_missing_private_key(self): - incomplete_credentials = {} with self.assertRaises(SkyflowError) as context: - get_service_account_token(incomplete_credentials, {}, None) + get_service_account_token({}, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) def test_get_service_account_token_missing_client_id_key(self): @@ -107,43 +134,42 @@ def test_get_service_account_token_with_valid_credentials(self): access_token, _ = get_service_account_token(VALID_SERVICE_ACCOUNT_CREDS, {}, None) self.assertTrue(access_token) + def test_get_service_account_token_with_snake_case_creds(self): + access_token, _ = get_service_account_token(SNAKE_CASE_CREDS, {}, None) + self.assertTrue(access_token) - @patch("jwt.encode", side_effect=Exception) - def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): + def test_get_service_account_token_missing_private_key_snake(self): + creds = { + "client_id": "id", + "key_id": "kid", + "token_uri": "https://example.com", + } with self.assertRaises(SkyflowError) as context: - get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None) - self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value) - - def test_get_signed_data_token_response_object(self): - token = "sample_token" - signed_token = "signed_sample_token" - response = get_signed_data_token_response_object(signed_token, token) - self.assertEqual(response[0], token) - self.assertEqual(response[1], signed_token) - - def test_generate_signed_data_tokens_from_file_path(self): - creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") - options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'} - result = generate_signed_data_tokens(creds_path, options) - self.assertEqual(len(result), 2) + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) - def test_generate_signed_data_tokens_from_invalid_file_path(self): - options = {"data_tokens": ["token1", "token2"]} + def test_get_service_account_token_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens('credentials1.json', options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) - - def test_generate_signed_data_tokens_from_creds(self): - options = {"data_tokens": ["token1", "token2"]} - result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) - self.assertEqual(len(result), 2) + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) - def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): - options = {"data_tokens": ["token1", "token2"]} - credentials_string = '{' + def test_get_service_account_token_invalid_token_uri_in_options(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": "not-a-valid-url"} with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens_from_creds(credentials_string, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + get_service_account_token(creds, options, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) @patch("skyflow.service_account._utils.AuthClient") @patch("skyflow.service_account._utils.get_signed_jwt") @@ -152,13 +178,14 @@ def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_si "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } options = {"role_ids": ["role1", "role2"]} mock_get_signed_jwt.return_value = "signed" mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value - mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), {"access_token": "token", - "token_type": "bearer"}) + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) access_token, token_type = get_service_account_token(creds, options, None) self.assertEqual(access_token, "token") self.assertEqual(token_type, "bearer") @@ -173,16 +200,18 @@ def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } mock_get_signed_jwt.return_value = "signed" mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError + mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized") with self.assertRaises(SkyflowError) as context: get_service_account_token(creds, {}, None) - self.assertEqual(context.exception.message, - SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value + ) @patch("skyflow.service_account._utils.AuthClient") @patch("skyflow.service_account._utils.get_signed_jwt") @@ -191,7 +220,7 @@ def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } mock_get_signed_jwt.return_value = "signed" mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value @@ -200,16 +229,364 @@ def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, get_service_account_token(creds, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value) + # ── get_signed_jwt ──────────────────────────────────────────────────────── + + @patch("jwt.encode", side_effect=Exception) + def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): + with self.assertRaises(SkyflowError) as context: + get_signed_jwt({}, "client_id", "key_id", "token_uri", "private_key", None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.JWT_INVALID_FORMAT.value) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_string_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": "valid_ctx"}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], "valid_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_valid_dict_ctx(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": {"role": "admin"}}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertEqual(payload["ctx"], {"role": "admin"}) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_jwt_with_empty_string_ctx_not_added(self, mock_jwt_encode): + mock_jwt_encode.return_value = "mock_token" + get_signed_jwt({"ctx": ""}, "client_id", "key_id", "token_uri", "private_key", None) + payload = mock_jwt_encode.call_args.kwargs["payload"] + self.assertNotIn("ctx", payload) + + # ── get_signed_data_token_response_object ───────────────────────────────── + + def test_get_signed_data_token_response_object(self): + token = "sample_token" + signed_token = "signed_sample_token" + response = get_signed_data_token_response_object(signed_token, token) + self.assertIsInstance(response, tuple) + self.assertEqual(response[0], token) + self.assertEqual(response[1], signed_token) + + # ── get_signed_tokens ───────────────────────────────────────────────────── + @patch("jwt.encode", side_effect=Exception("jwt error")) def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): creds = { "privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", - "tokenURI": "https://valid-url.com" + "tokenURI": "https://valid-url.com", } options = {"data_tokens": ["token1"]} with self.assertRaises(SkyflowError) as context: - from skyflow.service_account._utils import get_signed_tokens get_signed_tokens(creds, options) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) \ No newline at end of file + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) + + def test_get_signed_tokens_returns_list_one_per_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + def test_get_signed_tokens_items_are_tuples_with_token_and_signed_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + for item in result: + self.assertIsInstance(item, tuple) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[1][0], "token2") + self.assertTrue(result[0][1].startswith("signed_token_")) + self.assertTrue(result[1][1].startswith("signed_token_")) + + def test_get_signed_tokens_returns_list_single_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + def test_get_signed_tokens_empty_data_tokens_returns_empty_list(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": []}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_string_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": "my_ctx"}) + call_args = mock_jwt_encode.call_args + claims = call_args[0][0] if call_args[0] else call_args.kwargs.get("args", [None])[0] + # jwt.encode(claims, key, algorithm=...) — first positional arg is claims + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], "my_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_dict_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + ctx_dict = {"role": "admin", "dept": "eng"} + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ctx_dict}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], ctx_dict) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_empty_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ""}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_none_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": None}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + def test_get_signed_tokens_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_missing_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_with_snake_case_creds(self): + result = get_signed_tokens(SNAKE_CASE_CREDS, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + # ── generate_signed_data_tokens (file path) ─────────────────────────────── + + def test_generate_signed_data_tokens_from_file_path(self): + options = {"data_tokens": ["token1", "token2"], "ctx": "ctx"} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_invalid_file_path(self): + options = {"data_tokens": ["token1", "token2"]} + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens("credentials1.json", options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) + + def test_generate_signed_data_tokens_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "department": "finance"}} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 1) + + # ── generate_signed_data_tokens_from_creds (string) ────────────────────── + + def test_generate_signed_data_tokens_from_creds(self): + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): + options = {"data_tokens": ["token1", "token2"]} + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds("{", options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + + def test_generate_signed_data_tokens_from_creds_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "level": 3}} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 1) + + # ── snake_case end-to-end ───────────────────────────────────────────────── + + def test_generate_signed_data_tokens_with_snake_creds_file(self): + """generate_signed_data_tokens reads the file (camelCase) but the normalize fn is a no-op for camelCase.""" + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_snake(self): + result = generate_signed_data_tokens_from_creds(SNAKE_CASE_CREDS_STRING, options={"data_tokens": ["t1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + # ── _normalize_credentials ──────────────────────────────────────────────── + + def test_normalize_credentials_snake_case(self): + snake = { + "private_key": "pk", + "client_id": "cid", + "key_id": "kid", + "token_uri": "https://uri", + "client_name": "name", + } + result = _normalize_credentials(snake) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertEqual(result["clientName"], "name") + self.assertNotIn("private_key", result) + self.assertNotIn("client_id", result) + self.assertNotIn("key_id", result) + self.assertNotIn("token_uri", result) + self.assertNotIn("client_name", result) + + def test_normalize_credentials_camel_case_unchanged(self): + camel = { + "privateKey": "pk", + "clientID": "cid", + "keyID": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(camel) + self.assertEqual(result, camel) + + def test_normalize_credentials_mixed_keys(self): + mixed = { + "private_key": "pk", + "clientID": "cid", + "key_id": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(mixed) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertNotIn("private_key", result) + self.assertNotIn("key_id", result) + + def test_normalize_credentials_unknown_key_passes_through(self): + creds = {"unknown_field": "value", "anotherField": "val2"} + result = _normalize_credentials(creds) + self.assertEqual(result["unknown_field"], "value") + self.assertEqual(result["anotherField"], "val2") + + def test_normalize_credentials_empty_dict(self): + self.assertEqual(_normalize_credentials({}), {}) + + # ── _validate_and_resolve_ctx ───────────────────────────────────────────── + + def test_validate_and_resolve_ctx_none(self): + self.assertIsNone(_validate_and_resolve_ctx(None)) + + def test_validate_and_resolve_ctx_empty_string(self): + self.assertIsNone(_validate_and_resolve_ctx("")) + self.assertIsNone(_validate_and_resolve_ctx(" ")) + + def test_validate_and_resolve_ctx_valid_string(self): + self.assertEqual(_validate_and_resolve_ctx("user_12345"), "user_12345") + + def test_validate_and_resolve_ctx_empty_dict(self): + self.assertIsNone(_validate_and_resolve_ctx({})) + + def test_validate_and_resolve_ctx_valid_dict(self): + ctx = {"role": "admin", "department": "finance"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_alphanumeric_keys(self): + ctx = {"role_1": "admin", "dept2": "finance", "ABC_123": "value"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_hyphen(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"valid_key": "value", "invalid-key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_space(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid key": "value"}) + + def test_validate_and_resolve_ctx_dict_with_invalid_key_dot(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx({"invalid.key": "value"}) + + def test_validate_and_resolve_ctx_valid_type_int(self): + self.assertEqual(_validate_and_resolve_ctx(42), 42) + + def test_validate_and_resolve_ctx_valid_type_float(self): + self.assertEqual(_validate_and_resolve_ctx(3.14), 3.14) + + def test_validate_and_resolve_ctx_valid_type_bool_true(self): + self.assertEqual(_validate_and_resolve_ctx(True), True) + + def test_validate_and_resolve_ctx_valid_type_bool_false(self): + self.assertEqual(_validate_and_resolve_ctx(False), False) + + def test_validate_and_resolve_ctx_invalid_type_list(self): + with self.assertRaises(SkyflowError): + _validate_and_resolve_ctx(["a", "b"]) + + def test_validate_and_resolve_ctx_dict_with_mixed_value_types(self): + ctx = {"role": "admin", "level": 3, "active": True, "timestamp": "2025-12-25T10:30:00Z"} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + def test_validate_and_resolve_ctx_dict_with_nested_objects(self): + ctx = {"role": "admin", "metadata": {"level": 2, "tags": ["a", "b"]}} + self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) + + # ── additional coverage gaps ────────────────────────────────────────────── + + @patch("skyflow.service_account._utils.jwt.decode", side_effect=jwt.ExpiredSignatureError) + def test_is_expired_expired_signature_error(self, mock_decode): + token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") + self.assertTrue(is_expired(token)) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_token_uri_option_override(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + override_uri = "https://override-url.com" + options = {"token_uri": override_uri} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + get_service_account_token(creds, options, None) + mock_get_signed_jwt.assert_called_once() + call_args = mock_get_signed_jwt.call_args + self.assertEqual(call_args[0][3], override_uri) + + @patch("json.load", side_effect=json.JSONDecodeError("bad json", "", 0)) + def test_generate_signed_data_tokens_from_file_invalid_json(self, mock_load): + invalid_path = os.path.join(os.path.dirname(__file__), "invalid_creds.json") + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(invalid_path, {"data_tokens": ["t1"]}) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.FILE_INVALID_JSON.value.format(invalid_path), + ) diff --git a/tests/utils/test__helpers.py b/tests/utils/test__helpers.py index 6758b62e..6016c798 100644 --- a/tests/utils/test__helpers.py +++ b/tests/utils/test__helpers.py @@ -39,9 +39,10 @@ def test_format_scope_special_characters(self): def test_is_valid_url_valid(self): self.assertTrue(is_valid_url("https://example.com")) - self.assertTrue(is_valid_url("http://example.com/path")) + self.assertTrue(is_valid_url("https://example.com/path")) def test_is_valid_url_invalid(self): + self.assertFalse(is_valid_url("http://example.com")) self.assertFalse(is_valid_url("ftp://example.com")) self.assertFalse(is_valid_url("example.com")) self.assertFalse(is_valid_url("invalid-url")) diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index b0466498..1363ad7d 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1,40 +1,65 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock, PropertyMock import os -from unittest.mock import MagicMock from urllib.parse import quote import tempfile, json from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.generated.rest import ErrorResponse -from skyflow.service_account import generate_bearer_token, generate_signed_data_tokens, \ - generate_signed_data_tokens_from_creds, generate_bearer_token_from_creds -from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ - parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ - parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ - handle_exception, validate_api_key, encode_column_values, parse_deidentify_text_response, \ - parse_reidentify_text_response, convert_detected_entity_to_entity_info -from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error +from skyflow.service_account import ( + generate_bearer_token, + generate_signed_data_tokens, + generate_signed_data_tokens_from_creds, + generate_bearer_token_from_creds, +) +from skyflow.utils import ( + get_credentials, + SkyflowMessages, + get_vault_url, + construct_invoke_connection_request, + parse_insert_response, + parse_update_record_response, + parse_delete_response, + parse_get_response, + parse_detokenize_response, + parse_tokenize_response, + parse_query_response, + parse_invoke_connection_response, + handle_exception, + validate_api_key, + encode_column_values, + parse_deidentify_text_response, + parse_reidentify_text_response, + convert_detected_entity_to_entity_info, +) +from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error, r_urlencode from skyflow.utils.enums import EnvUrls, Env, ContentType from skyflow.vault.connection import InvokeConnectionResponse from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse from skyflow.vault.tokens import DetokenizeResponse, TokenizeResponse creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) TEST_ERROR_MESSAGE = "Test error message." VALID_ENV_CREDENTIALS = credentials -class TestUtils(unittest.TestCase): +class TestUtils(unittest.TestCase): @patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": json.dumps(VALID_ENV_CREDENTIALS)}) def test_get_credentials_env_variable(self): credentials = get_credentials() - credentials_string = credentials.get('credentials_string') - self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n')) + credentials_string = credentials.get("credentials_string") + self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace("\n", "\\n")) + + @patch("skyflow.utils._utils.dotenv.find_dotenv", return_value=None) + @patch.dict(os.environ, {}, clear=True) + def test_get_credentials_no_credentials_raises(self, mock_find_dotenv): + with self.assertRaises(SkyflowError) as context: + get_credentials(config_level_creds=None, common_skyflow_creds=None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) def test_get_credentials_with_config_level_creds(self): test_creds = {"authToken": "test_token"} @@ -60,11 +85,13 @@ def test_get_vault_url_with_invalid_cluster_id(self): valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id)) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id) + ) def test_get_vault_url_with_invalid_env(self): valid_cluster_id = "cluster_id" - valid_env =EnvUrls.DEV + valid_env = EnvUrls.DEV valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) @@ -79,7 +106,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): "http_code": 400, "http_status": "Bad Request", "grpc_code": 3, - "details": ["detail1"] + "details": ["detail1"], } } @@ -90,13 +117,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): handle_json_error(mock_error, error_dict, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Dict error message", - 400, - request_id, - "Bad Request", - 3, - ["detail1"], - logger=mock_logger + "Dict error message", 400, request_id, "Bad Request", 3, ["detail1"], logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -109,7 +130,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ "http_code": 403, "http_status": "Forbidden", "grpc_code": 7, - "details": ["detail2"] + "details": ["detail2"], } } @@ -120,13 +141,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ handle_json_error(mock_error, mock_error_response, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "ErrorResponse message", - 403, - request_id, - "Forbidden", - 7, - ["detail2"], - logger=mock_logger + "ErrorResponse message", 403, request_id, "Forbidden", 7, ["detail2"], logger=mock_logger ) def test_parse_path_params(self): @@ -140,13 +155,56 @@ def test_to_lowercase_keys(self): expected_output = {"key1": "value1", "key2": "value2"} self.assertEqual(to_lowercase_keys(input_dict), expected_output) + def test_r_urlencode_with_list_input(self): + pairs = {} + r_urlencode([], pairs, ["a", "b"]) + self.assertIn("[0]", pairs) + self.assertIn("[1]", pairs) + self.assertEqual(pairs["[0]"], "a") + self.assertEqual(pairs["[1]"], "b") + + def test_r_urlencode_with_tuple_input(self): + pairs = {} + r_urlencode([], pairs, ("x", "y")) + self.assertIn("[0]", pairs) + self.assertEqual(pairs["[0]"], "x") + def test_get_metrics(self): metrics = get_metrics() - self.assertIn('sdk_name_version', metrics) - self.assertIn('sdk_client_device_model', metrics) - self.assertIn('sdk_client_os_details', metrics) - self.assertIn('sdk_runtime_details', metrics) + self.assertIn("sdk_name_version", metrics) + self.assertIn("sdk_client_device_model", metrics) + self.assertIn("sdk_client_os_details", metrics) + self.assertIn("sdk_runtime_details", metrics) + def test_get_metrics_platform_node_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + with patch("skyflow.utils._utils.platform") as mock_platform: + mock_platform.node.side_effect = OSError("no node") + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_device_model"], "") + utils_module._CACHED_METRICS.clear() + + def test_get_metrics_sys_attribute_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + + class _RaisingSys: + @property + def platform(self): + raise RuntimeError("no platform") + + @property + def version(self): + raise RuntimeError("no version") + + with patch("skyflow.utils._utils.sys", _RaisingSys()): + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_os_details"], "") + self.assertIn("sdk_runtime_details", metrics) + utils_module._CACHED_METRICS.clear() def test_construct_invoke_connection_request_valid(self): mock_connection_request = Mock() @@ -166,7 +224,7 @@ def test_construct_invoke_connection_request_valid(self): self.assertEqual(result.url, expected_url) self.assertEqual(result.method, "POST") - self.assertEqual(result.headers['Content-Type'], ContentType.JSON.value) + self.assertEqual(result.headers["Content-Type"], ContentType.JSON.value) self.assertEqual(result.body, json.dumps(mock_connection_request.body)) @@ -232,9 +290,7 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} - mock_connection_request.body = { - "name": (None, "John Doe") - } + mock_connection_request.body = {"name": (None, "John Doe")} mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -244,13 +300,27 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): self.assertIsInstance(result, PreparedRequest) + def test_parse_insert_response_with_tokens_continue_on_error(self): + api_response = Mock() + api_response.headers = {"x-request-id": "req-1"} + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1", "tokens": {"col1": "tok1"}}]}}, + ] + ) + result = parse_insert_response(api_response, continue_on_error=True) + self.assertEqual(result.inserted_fields[0]["col1"], "tok1") + self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1") + def test_parse_insert_response(self): api_response = Mock() api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - api_response.data = Mock(responses=[ - {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, - {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}} - ]) + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, + {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}}, + ] + ) result = parse_insert_response(api_response, continue_on_error=True) self.assertEqual(len(result.inserted_fields), 1) self.assertEqual(len(result.errors), 1) @@ -264,17 +334,19 @@ def test_parse_insert_response(self): def test_parse_insert_response_continue_on_error_false(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), - Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}) - ]) + mock_api_response.data = Mock( + records=[ + Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), + Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}), + ] + ) result = parse_insert_response(mock_api_response, continue_on_error=False) self.assertIsInstance(result, InsertResponse) expected_inserted_fields = [ {"skyflow_id": "id_1", "token1": "token_value1"}, - {"skyflow_id": "id_2", "token2": "token_value2"} + {"skyflow_id": "id_2", "token2": "token_value2"}, ] self.assertEqual(result.inserted_fields, expected_inserted_fields) @@ -285,8 +357,8 @@ def test_parse_update_record_response(self): api_response.skyflow_id = "id1" api_response.tokens = {"token1": "value1"} result = parse_update_record_response(api_response) - self.assertEqual(result.updated_field['skyflow_id'], "id1") - self.assertEqual(result.updated_field['token1'], "value1") + self.assertEqual(result.updated_field["skyflow_id"], "id1") + self.assertEqual(result.updated_field["token1"], "value1") def test_parse_delete_response_successful(self): mock_api_response = Mock() @@ -304,42 +376,39 @@ def test_parse_delete_response_successful(self): def test_parse_get_response_successful(self): mock_api_response = Mock() mock_api_response.records = [ - Mock(fields={'field1': 'value1', 'field2': 'value2'}), - Mock(fields={'field1': 'value3', 'field2': 'value4'}) + Mock(fields={"field1": "value1", "field2": "value2"}), + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_get_response(mock_api_response) self.assertIsInstance(result, GetResponse) - expected_data = [ - {'field1': 'value1', 'field2': 'value2'}, - {'field1': 'value3', 'field2': 'value4'} - ] + expected_data = [{"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"}] self.assertEqual(result.data, expected_data) - # self.assertEqual(result.errors, None) + self.assertIsNone(result.errors) def test_parse_detokenize_response_with_mixed_records(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(token="token1", value="value1", value_type="Type1", error=None), - Mock(token="token2", value=None, value_type=None, error="Some error"), - Mock(token="token3", value="value3", value_type="Type2", error=None), - ]) + mock_api_response.data = Mock( + records=[ + Mock(token="token1", value="value1", value_type="Type1", error=None), + Mock(token="token2", value=None, value_type=None, error="Some error"), + Mock(token="token3", value="value3", value_type="Type2", error=None), + ] + ) result = parse_detokenize_response(mock_api_response) self.assertIsInstance(result, DetokenizeResponse) expected_detokenized_fields = [ {"token": "token1", "value": "value1", "type": "Type1"}, - {"token": "token3", "value": "value3", "type": "Type2"} + {"token": "token3", "value": "value3", "type": "Type2"}, ] - expected_errors = [ - {"token": "token2", "error": "Some error", "request_id": "12345"} - ] + expected_errors = [{"token": "token2", "error": "Some error", "request_id": "12345"}] self.assertEqual(result.detokenized_fields, expected_detokenized_fields) self.assertEqual(result.errors, expected_errors) @@ -355,11 +424,7 @@ def test_parse_tokenize_response_with_valid_records(self): result = parse_tokenize_response(mock_api_response) self.assertIsInstance(result, TokenizeResponse) - expected_tokenized_fields = [ - {"token": "token1"}, - {"token": "token2"}, - {"token": "token3"} - ] + expected_tokenized_fields = [{"token": "token1"}, {"token": "token2"}, {"token": "token3"}] self.assertEqual(result.tokenized_fields, expected_tokenized_fields) @@ -367,7 +432,7 @@ def test_parse_query_response_with_valid_records(self): mock_api_response = Mock() mock_api_response.records = [ Mock(fields={"field1": "value1", "field2": "value2"}), - Mock(fields={"field1": "value3", "field2": "value4"}) + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_query_response(mock_api_response) @@ -376,7 +441,7 @@ def test_parse_query_response_with_valid_records(self): expected_fields = [ {"field1": "value1", "field2": "value2", "tokenized_data": {}}, - {"field1": "value3", "field2": "value4", "tokenized_data": {}} + {"field1": "value3", "field2": "value4", "tokenized_data": {}}, ] self.assertEqual(result.fields, expected_fields) @@ -384,7 +449,7 @@ def test_parse_query_response_with_valid_records(self): @patch("requests.Response") def test_parse_invoke_connection_response_successful(self, mock_response): mock_response.status_code = 200 - mock_response.content = json.dumps({"key": "value"}).encode('utf-8') + mock_response.content = json.dumps({"key": "value"}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} result = parse_invoke_connection_response(mock_response) @@ -398,7 +463,7 @@ def test_parse_invoke_connection_response_successful(self, mock_response): def test_parse_invoke_connection_response_json_decode_error(self, mock_response): """Test that non-JSON content in successful response is returned as string.""" mock_response.status_code = 200 - mock_response.content = "Non-JSON Content".encode('utf-8') + mock_response.content = "Non-JSON Content".encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status = Mock() @@ -412,7 +477,7 @@ def test_parse_invoke_connection_response_json_decode_error(self, mock_response) @patch("requests.Response") def test_parse_invoke_connection_response_http_error_with_json_error_message(self, mock_response): mock_response.status_code = 404 - mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode('utf-8') + mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("404 Error") @@ -423,10 +488,38 @@ def test_parse_invoke_connection_response_http_error_with_json_error_message(sel self.assertEqual(context.exception.message, "Not Found") self.assertEqual(context.exception.request_id, "1234") + @patch("requests.Response") + def test_parse_invoke_connection_response_with_error_from_client_header(self, mock_response): + from requests.models import HTTPError + + mock_response.status_code = 400 + mock_response.content = json.dumps( + { + "error": { + "message": "Client error", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": 3, + "details": None, + } + } + ).encode("utf-8") + mock_response.headers = { + "x-request-id": "rid-1", + "error-from-client": "true", + } + mock_response.raise_for_status.side_effect = HTTPError("400") + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + err = context.exception + self.assertEqual(err.message, "Client error") + self.assertIsNotNone(err.details) + self.assertTrue(any(d.get("error_from_client") is True for d in err.details)) + @patch("requests.Response") def test_parse_invoke_connection_response_http_error_without_json_error_message(self, mock_response): mock_response.status_code = 500 - mock_response.content = "Internal Server Error".encode('utf-8') + mock_response.content = "Internal Server Error".encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("500 Error") @@ -434,7 +527,7 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, "Internal Server Error") + self.assertEqual(context.exception.message, SkyflowMessages.Error.API_ERROR.value.format(500)) self.assertEqual(context.exception.http_code, 500) self.assertEqual(context.exception.request_id, "1234") @@ -442,31 +535,24 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( def test_handle_exception_json_error(self, mock_log_and_reject_error): mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'application/json' - } - mock_error.body = json.dumps({ - "error": { - "message": "JSON error occurred.", - "http_code": 400, - "http_status": "Bad Request", - "grpc_code": "8", - "details": "Detailed message" + mock_error.headers = {"x-request-id": "1234", "content-type": "application/json"} + mock_error.body = json.dumps( + { + "error": { + "message": "JSON error occurred.", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": "8", + "details": "Detailed message", + } } - }).encode('utf-8') + ).encode("utf-8") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "JSON error occurred.", - 400, - "1234", - "Bad Request", - "8", - "Detailed message", - logger=mock_logger + "JSON error occurred.", 400, "1234", "Bad Request", "8", "Detailed message", logger=mock_logger ) def test_validate_api_key_valid_key(self): @@ -502,12 +588,7 @@ def test_parse_deidentify_text_response(self): mock_entity.value = "sensitive_value" mock_entity.entity_type = "EMAIL" mock_entity.entity_scores = {"EMAIL": 0.95} - mock_entity.location = Mock( - start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 - ) + mock_entity.location = Mock(start_index=10, end_index=20, start_index_processed=15, end_index_processed=25) mock_api_response = Mock() mock_api_response.processed_text = "Sample processed text" @@ -564,10 +645,7 @@ def test__convert_detected_entity_to_entity_info(self): mock_detected_entity.entity_type = "EMAIL" mock_detected_entity.entity_scores = {"EMAIL": 0.95} mock_detected_entity.location = Mock( - start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 + start_index=10, end_index=20, start_index_processed=15, end_index_processed=25 ) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -588,12 +666,7 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): mock_detected_entity.value = None mock_detected_entity.entity_type = "UNKNOWN" mock_detected_entity.entity_scores = {} - mock_detected_entity.location = Mock( - start_index=0, - end_index=0, - start_index_processed=0, - end_index_processed=0 - ) + mock_detected_entity.location = Mock(start_index=0, end_index=0, start_index_processed=0, end_index_processed=0) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -606,21 +679,18 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) - @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_connect_error(self, mock_log_and_reject_error): """Test handling httpx.ConnectError.""" import httpx + mock_error = httpx.ConnectError("Connection refused") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - 'Connection refused', - SkyflowMessages.ErrorCodes.INVALID_INPUT.value, - None, - logger=mock_logger + "Connection refused", SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -632,10 +702,7 @@ def test_handle_exception_no_headers_attribute(self, mock_log_and_reject_error): handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Generic error", - SkyflowMessages.ErrorCodes.SERVER_ERROR.value, - None, - logger=mock_logger + "Generic error", SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -643,89 +710,67 @@ def test_handle_exception_no_body_attribute(self, mock_log_and_reject_error): """Test handling error without body attribute.""" mock_error = Mock() mock_error.headers = {"x-request-id": "12345"} - delattr(mock_error, 'body') + delattr(mock_error, "body") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once() - self.assertEqual( - mock_log_and_reject_error.call_args[0][1], - SkyflowMessages.ErrorCodes.SERVER_ERROR.value - ) + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_text_plain_error(self, mock_log_and_reject_error): """Test handling text/plain content type error.""" mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'text/plain' - } + mock_error.headers = {"x-request-id": "1234", "content-type": "text/plain"} mock_error.body = "Plain text error message" mock_error.status = 500 mock_logger = Mock() handle_exception(mock_error, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - "Plain text error message", - 500, - "1234", - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with("Plain text error message", 500, "1234", logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_error): """Test handling generic error with unknown content type.""" mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'application/xml' - } + mock_error.headers = {"x-request-id": "1234", "content-type": "application/xml"} mock_error.body = "XML error" mock_error.status = 503 mock_logger = Mock() handle_exception(mock_error, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - str(mock_error), - 503, - "1234", - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, "1234", logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_no_content_type(self, mock_log_and_reject_error): """Test handling error without content-type header.""" mock_error = Mock() - mock_error.headers = {'x-request-id': '1234'} + mock_error.headers = {"x-request-id": "1234"} mock_error.body = "Some error" mock_error.status = 500 mock_logger = Mock() handle_exception(mock_error, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - str(mock_error), - 500, - "1234", - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 500, "1234", logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): """Test handling JSON error when data is a JSON string.""" - error_json_string = json.dumps({ - "error": { - "message": "String JSON error", - "http_code": 422, - "http_status": "Unprocessable Entity", - "grpc_code": 3, - "details": ["validation failed"] + error_json_string = json.dumps( + { + "error": { + "message": "String JSON error", + "http_code": 422, + "http_status": "Unprocessable Entity", + "grpc_code": 3, + "details": ["validation failed"], + } } - }) + ) mock_error = Mock() mock_logger = Mock() @@ -734,13 +779,7 @@ def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): handle_json_error(mock_error, error_json_string, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "String JSON error", - 422, - request_id, - "Unprocessable Entity", - 3, - ["validation failed"], - logger=mock_logger + "String JSON error", 422, request_id, "Unprocessable Entity", 3, ["validation failed"], logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -756,17 +795,12 @@ def test_handle_json_error_with_invalid_json(self, mock_log_and_reject_error): # Should call with INVALID_JSON_RESPONSE error mock_log_and_reject_error.assert_called_once() - self.assertEqual( - mock_log_and_reject_error.call_args[0][0], - SkyflowMessages.Error.INVALID_JSON_RESPONSE.value - ) + self.assertEqual(mock_log_and_reject_error.call_args[0][0], SkyflowMessages.Error.INVALID_JSON_RESPONSE.value) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_json_error_missing_error_field(self, mock_log_and_reject_error): """Test handling JSON error with missing error field.""" - error_dict = { - "message": "Error without error wrapper" - } + error_dict = {"message": "Error without error wrapper"} mock_error = Mock() mock_logger = Mock() @@ -793,14 +827,10 @@ def test_handle_text_error_with_status(self, mock_log_and_reject_error): error_data = "Resource not found" from skyflow.utils._utils import handle_text_error + handle_text_error(mock_error, error_data, request_id, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - "Resource not found", - 404, - request_id, - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with("Resource not found", 404, request_id, logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_generic_error_with_status(self, mock_log_and_reject_error): @@ -811,14 +841,10 @@ def test_handle_generic_error_with_status(self, mock_log_and_reject_error): status = 503 from skyflow.utils._utils import handle_generic_error_with_status + handle_generic_error_with_status(mock_error, request_id, status, mock_logger) - mock_log_and_reject_error.assert_called_once_with( - str(mock_error), - 503, - request_id, - logger=mock_logger - ) + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, request_id, logger=mock_logger) @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_with_none_error(self, mock_log_and_reject_error): @@ -831,10 +857,10 @@ def test_handle_exception_with_none_error(self, mock_log_and_reject_error): SkyflowMessages.Error.GENERIC_API_ERROR.value, SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, - logger=mock_logger + logger=mock_logger, ) - #failed + # failed @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_error): """Test handling empty string error.""" @@ -847,22 +873,54 @@ def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_erro mock_log_and_reject_error.assert_called_once() # Should use str(error) or default message - self.assertEqual( - mock_log_and_reject_error.call_args[0][1], - SkyflowMessages.ErrorCodes.SERVER_ERROR.value - ) + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) @patch("skyflow.utils._utils.log_and_reject_error") - def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): - """Test handling JSON error when data is bytes.""" + def test_handle_json_error_with_responses_key(self, mock_log_and_reject_error): + """Test handle_json_error when body has 'responses' key (batch/continue_on_error path).""" error_dict = { - "error": { - "message": "Bytes error", - "http_code": 401, - "http_status": "Unauthorized" - } + "responses": [ + {"Status": 400, "Body": {"error": "record not found"}}, + {"Status": 400, "Body": {"error": "invalid field"}}, + ] } - error_bytes = json.dumps(error_dict).encode('utf-8') + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-responses" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertIn("record not found", args[0]) + self.assertIn("invalid field", args[0]) + self.assertEqual(args[1], 400) + self.assertIsNone(args[3]) # http_status + self.assertIsNone(args[4]) # grpc_code + self.assertEqual(args[5], []) # details + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_responses_no_error_messages(self, mock_log_and_reject_error): + """Test handle_json_error with responses key but no error body — falls back to default message.""" + error_dict = { + "responses": [ + {"Status": 200, "Body": {"records": [{"skyflow_id": "abc"}]}}, + ] + } + mock_error = Mock() + request_id = "test-request-id-responses-empty" + + handle_json_error(mock_error, error_dict, request_id, None) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): + """Test handling JSON error when data is bytes.""" + error_dict = {"error": {"message": "Bytes error", "http_code": 401, "http_status": "Unauthorized"}} + error_bytes = json.dumps(error_dict).encode("utf-8") mock_error = Mock() mock_logger = Mock() @@ -871,13 +929,7 @@ def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): handle_json_error(mock_error, error_bytes, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Bytes error", - 401, - request_id, - "Unauthorized", - None, - [], - logger=mock_logger + "Bytes error", 401, request_id, "Unauthorized", None, [], logger=mock_logger ) # Add these new test methods to the TestUtils class: @@ -897,7 +949,7 @@ def test_construct_invoke_connection_request_with_no_headers(self): self.assertIsInstance(result, PreparedRequest) # Headers should be None when not provided - self.assertIsNone(result.headers.get('Content-Type')) + self.assertIsNone(result.headers.get("Content-Type")) def test_construct_invoke_connection_request_with_xml_content_type(self): """Test construct_invoke_connection_request with XML content type.""" @@ -913,10 +965,10 @@ def test_construct_invoke_connection_request_with_xml_content_type(self): result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) self.assertIsInstance(result, PreparedRequest) - self.assertEqual(result.headers['content-type'], 'application/xml') + self.assertEqual(result.headers["content-type"], "application/xml") # Body should be converted to XML - self.assertIn('', result.body) - self.assertIn('value', result.body) + self.assertIn("", result.body) + self.assertIn("value", result.body) def test_construct_invoke_connection_request_with_html_content_type(self): """Test construct_invoke_connection_request with HTML content type.""" @@ -932,7 +984,7 @@ def test_construct_invoke_connection_request_with_html_content_type(self): result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) self.assertIsInstance(result, PreparedRequest) - self.assertEqual(result.headers['content-type'], 'text/html') + self.assertEqual(result.headers["content-type"], "text/html") # Body should be JSON string for HTML self.assertEqual(result.body, json.dumps({"message": "Hello"})) @@ -951,8 +1003,8 @@ def test_construct_invoke_connection_request_multipart_removes_content_type(self self.assertIsInstance(result, PreparedRequest) # Content-Type should be auto-generated by requests library - self.assertIn('multipart/form-data', result.headers.get('Content-Type', '')) - self.assertIn('boundary=', result.headers.get('Content-Type', '')) + self.assertIn("multipart/form-data", result.headers.get("Content-Type", "")) + self.assertIn("boundary=", result.headers.get("Content-Type", "")) def test_construct_invoke_connection_request_with_no_body(self): """Test construct_invoke_connection_request when body is None.""" @@ -1119,10 +1171,7 @@ def test_parse_invoke_connection_response_xml_content(self, mock_response): """Test parsing XML response content.""" mock_response.status_code = 200 mock_response.content = b"success" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/xml" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/xml"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1137,10 +1186,7 @@ def test_parse_invoke_connection_response_url_encoded_content(self, mock_respons """Test parsing URL encoded response content.""" mock_response.status_code = 200 mock_response.content = b"card_number=4111111111111111&cvv=123" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/x-www-form-urlencoded" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/x-www-form-urlencoded"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1155,10 +1201,7 @@ def test_parse_invoke_connection_response_html_content(self, mock_response): """Test parsing HTML response content.""" mock_response.status_code = 200 mock_response.content = b"Success" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "text/html" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1173,17 +1216,14 @@ def test_parse_invoke_connection_response_html_error(self, mock_response): """Test parsing HTML error response.""" html_error = "

Error 500

" mock_response.status_code = 500 - mock_response.content = html_error.encode('utf-8') - mock_response.headers = { - "x-request-id": "1234", - "content-type": "text/html" - } + mock_response.content = html_error.encode("utf-8") + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} mock_response.raise_for_status = Mock(side_effect=HTTPError("500 Error")) with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, html_error) + self.assertEqual(context.exception.message, SkyflowMessages.Error.API_ERROR.value.format(500)) self.assertEqual(context.exception.http_code, 500) self.assertEqual(context.exception.request_id, "1234") @@ -1192,10 +1232,7 @@ def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, """Test that JSON decode error falls back to returning string content.""" mock_response.status_code = 200 mock_response.content = b"Not valid JSON but still success" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/json" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/json"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1209,7 +1246,7 @@ def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, def test_parse_invoke_connection_response_no_content_type_with_json(self, mock_response): """Test parsing response with no content-type but valid JSON.""" mock_response.status_code = 200 - mock_response.content = json.dumps({"success": True}).encode('utf-8') + mock_response.content = json.dumps({"success": True}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status = Mock() @@ -1240,10 +1277,7 @@ def test_parse_invoke_connection_response_bytes_content(self, mock_response): """Test parsing response with bytes content.""" mock_response.status_code = 200 mock_response.content = b"Binary data response" - mock_response.headers = { - "x-request-id": "1234", - "content-type": "application/octet-stream" - } + mock_response.headers = {"x-request-id": "1234", "content-type": "application/octet-stream"} mock_response.raise_for_status = Mock() result = parse_invoke_connection_response(mock_response) @@ -1269,7 +1303,7 @@ def __repr__(self): connection_url = "https://example.com/endpoint" - with patch('json.dumps', side_effect=TypeError("Object is not JSON serializable")): + with patch("json.dumps", side_effect=TypeError("Object is not JSON serializable")): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) @@ -1287,7 +1321,7 @@ def test_construct_invoke_connection_request_headers_generic_exception(self): connection_url = "https://example.com/endpoint" - with patch('skyflow.utils._utils.to_lowercase_keys', side_effect=Exception("Generic error")): + with patch("skyflow.utils._utils.to_lowercase_keys", side_effect=Exception("Generic error")): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) @@ -1305,7 +1339,7 @@ def test_construct_invoke_connection_request_body_processing_exception(self): connection_url = "https://example.com/endpoint" - with patch('skyflow.utils._utils.get_data_from_content_type', side_effect=Exception("Body processing error")): + with patch("skyflow.utils._utils.get_data_from_content_type", side_effect=Exception("Body processing error")): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) @@ -1344,7 +1378,7 @@ def test_construct_invoke_connection_request_invalid_url_exception(self): connection_url = "https://example.com/endpoint" - with patch('requests.Request') as mock_request_class: + with patch("requests.Request") as mock_request_class: mock_request_instance = Mock() mock_request_instance.prepare.side_effect = Exception("Invalid URL structure") mock_request_class.return_value = mock_request_instance @@ -1352,10 +1386,7 @@ def test_construct_invoke_connection_request_invalid_url_exception(self): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) - self.assertEqual( - context.exception.message, - SkyflowMessages.Error.INVALID_URL.value.format(connection_url) - ) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) def test_construct_invoke_connection_request_prepare_exception(self): @@ -1369,7 +1400,7 @@ def test_construct_invoke_connection_request_prepare_exception(self): connection_url = "https://example.com/endpoint" - with patch('requests.Request') as mock_request_class: + with patch("requests.Request") as mock_request_class: mock_request_instance = Mock() mock_request_instance.prepare.side_effect = Exception("Prepare failed") mock_request_class.return_value = mock_request_instance @@ -1377,10 +1408,7 @@ def test_construct_invoke_connection_request_prepare_exception(self): with self.assertRaises(SkyflowError) as context: construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) - self.assertEqual( - context.exception.message, - SkyflowMessages.Error.INVALID_URL.value.format(connection_url) - ) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) def test_construct_invoke_connection_request_body_not_dict_raises_error(self): @@ -1400,7 +1428,7 @@ def test_construct_invoke_connection_request_body_not_dict_raises_error(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - @patch('skyflow.utils._utils.validate_invoke_connection_params') + @patch("skyflow.utils._utils.validate_invoke_connection_params") def test_construct_invoke_connection_request_validation_exception(self, mock_validate): """Test that validation exceptions are properly propagated.""" mock_connection_request = Mock() @@ -1419,15 +1447,16 @@ def test_construct_invoke_connection_request_validation_exception(self, mock_val self.assertEqual(context.exception.message, "Validation failed") self.assertEqual(context.exception.http_code, 400) + def test_generate_bearer_token_invalid_token_uri_type(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 # invalid type + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": 12345, # invalid type } - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1435,13 +1464,8 @@ def test_generate_bearer_token_invalid_token_uri_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_bearer_token_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1450,13 +1474,13 @@ def test_generate_bearer_token_invalid_token_uri_url(self): def test_generate_bearer_token_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"token_uri": "https://another-valid-url.com"} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() # Patch AuthClient and jwt.encode to avoid real HTTP and signing @@ -1464,32 +1488,22 @@ def test_generate_bearer_token_options_override_token_uri(self): mock_get_signed_jwt.return_value = "signed" with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value - mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), - {"access_token": "token", - "token_type": "bearer"}) + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) generate_bearer_token(tmp.name, options) args, kwargs = mock_get_signed_jwt.call_args self.assertEqual(args[3], options["token_uri"]) def test_generate_bearer_token_from_creds_invalid_token_uri_type(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: generate_bearer_token_from_creds(creds_str) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: generate_bearer_token_from_creds(creds_str) @@ -1497,10 +1511,10 @@ def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): def test_generate_bearer_token_from_creds_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"token_uri": "https://another-valid-url.com"} creds_str = json.dumps(creds) @@ -1508,22 +1522,17 @@ def test_generate_bearer_token_from_creds_options_override_token_uri(self): mock_get_signed_jwt.return_value = "signed" with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value - mock_auth_api.authentication_service_get_auth_token.return_value = type("obj", (), - {"access_token": "token", - "token_type": "bearer"}) + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) generate_bearer_token_from_creds(creds_str, options) args, kwargs = mock_get_signed_jwt.call_args self.assertEqual(args[3], options["token_uri"]) def test_generate_signed_data_tokens_invalid_token_uri_type(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} options = {"data_tokens": ["token1"]} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1531,14 +1540,9 @@ def test_generate_signed_data_tokens_invalid_token_uri_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_signed_data_tokens_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} options = {"data_tokens": ["token1"]} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with self.assertRaises(SkyflowError) as context: @@ -1546,12 +1550,7 @@ def test_generate_signed_data_tokens_invalid_token_uri_url(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 12345 - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} options = {"data_tokens": ["token1"]} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: @@ -1559,12 +1558,7 @@ def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): - creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'not_a_url' - } + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} options = {"data_tokens": ["token1"]} creds_str = json.dumps(creds) with self.assertRaises(SkyflowError) as context: @@ -1573,34 +1567,36 @@ def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): def test_generate_signed_data_tokens_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: json.dump(creds, tmp) tmp.flush() with patch("jwt.encode") as mock_jwt_encode: mock_jwt_encode.return_value = "signed" result = generate_signed_data_tokens(tmp.name, options) - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], "token1") - self.assertEqual(result[1], "signed_token_signed") + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self): creds = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id', - 'tokenURI': 'https://valid-url.com' + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", } options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} creds_str = json.dumps(creds) with patch("jwt.encode") as mock_jwt_encode: mock_jwt_encode.return_value = "signed" result = generate_signed_data_tokens_from_creds(creds_str, options) - self.assertIsInstance(result, tuple) - self.assertEqual(result[0], "token1") - self.assertEqual(result[1], "signed_token_signed") + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 8de9b219..ec4d5bec 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -12,13 +12,15 @@ validate_insert_request, validate_delete_request, validate_query_request, validate_get_detect_run_request, validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request, validate_invoke_connection_params, - validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request + validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request, + validate_file_upload_request ) from skyflow.utils import SkyflowMessages from skyflow.utils.enums import DetectEntities, RedactionType from skyflow.vault.data import GetRequest, UpdateRequest from skyflow.vault.detect import DeidentifyTextRequest, Transformations, DateTransformation, ReidentifyTextRequest, \ - FileInput, DeidentifyFileRequest + FileInput, DeidentifyFileRequest, Bleep +from skyflow.vault.data._file_upload_request import FileUploadRequest from skyflow.vault.tokens import DetokenizeRequest from skyflow.vault.connection._invoke_connection_request import InvokeConnectionRequest @@ -217,6 +219,18 @@ def test_validate_update_vault_config_invalid_cluster_id(self): validate_update_vault_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format("vault123")) + def test_validate_update_vault_config_missing_credentials(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123") + ) + def test_validate_connection_config_valid(self): config = { "connection_id": "conn123", @@ -250,6 +264,18 @@ def test_validate_connection_config_empty_connection_id(self): validate_connection_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value) + def test_validate_connection_config_missing_credentials(self): + config = { + "connection_id": "conn123", + "connection_url": "https://example.com", + } + with self.assertRaises(SkyflowError) as context: + validate_connection_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", "conn123") + ) + def test_validate_update_connection_config_valid(self): config = { "connection_id": "conn123", @@ -1163,3 +1189,279 @@ def test_validate_update_vault_config_with_invalid_token_uri_url(self): with self.assertRaises(SkyflowError) as context: validate_update_vault_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + # --- validate_file_from_request --- + + def test_validate_file_from_request_none_input(self): + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_INPUT.value) + + def test_validate_file_from_request_file_without_name_attr(self): + file_obj = MagicMock(spec=[]) # no attributes at all + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_file_with_empty_name(self): + file_obj = MagicMock() + file_obj.name = " " # whitespace-only name + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_extension_only_name(self): + file_obj = MagicMock() + # A trailing-slash path gives os.path.basename() == "", so splitext returns ("", "") + file_obj.name = "/some/directory/" + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_from_request_empty_string_file_path(self): + file_input = MagicMock() + file_input.file = None + file_input.file_path = "" # empty string — has_file_path=True, so goes to elif branch + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) + + # --- validate_deidentify_file_request bleep sub-fields --- + + def test_validate_deidentify_file_request_invalid_bleep_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, bleep="not_a_bleep") + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_TYPE.value) + + def test_validate_deidentify_file_request_invalid_bleep_gain(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(gain="loud") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_GAIN.value) + + def test_validate_deidentify_file_request_invalid_bleep_frequency(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(frequency="high") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value) + + def test_validate_deidentify_file_request_invalid_bleep_start_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(start_padding="early") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value) + + def test_validate_deidentify_file_request_invalid_bleep_stop_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(stop_padding="late") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value) + + # --- validate_deidentify_file_request output_directory --- + + def test_validate_deidentify_file_request_invalid_output_directory_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=123) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value) + + def test_validate_deidentify_file_request_output_directory_not_found(self): + file_input = FileInput(file_path=self.temp_file_path) + nonexistent = "/tmp/skyflow_nonexistent_dir_12345" + request = DeidentifyFileRequest(file=file_input, output_directory=nonexistent) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(nonexistent) + ) + + def test_validate_deidentify_file_request_valid_output_directory(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=self.temp_dir_path) + validate_deidentify_file_request(self.logger, request) + + # --- validate_file_upload_request --- + + def test_validate_file_upload_request_none(self): + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_table(self): + request = MagicMock() + request.table = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_empty_table(self): + request = MagicMock() + request.table = " " + request.column_name = "file_col" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(None)) + ) + + def test_validate_file_upload_request_empty_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = "" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type("")) + ) + + def test_validate_file_upload_request_empty_skyflow_id(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + skyflow_id=" ", + file_path=self.temp_file_path + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD") + ) + + def test_validate_file_upload_request_invalid_file_object_seek(self): + file_obj = MagicMock() + file_obj.seek.side_effect = OSError("seek failed") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=file_obj + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_OBJECT.value) + + def test_validate_file_upload_request_valid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path=self.temp_file_path + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_invalid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path="/nonexistent/path/file.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_PATH.value) + + def test_validate_file_upload_request_valid_base64(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64=encoded, + file_name="sample.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_base64_without_file_name(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64=encoded + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_upload_request_invalid_base64_string(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64="not-valid-base64!!!", + file_name="sample.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BASE64_STRING.value) + + def test_validate_file_upload_request_valid_file_object(self): + with open(self.temp_file_path, "rb") as f: + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=f + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_missing_file_source(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + # --- validate_deidentify_text_request transformations --- + + def test_validate_deidentify_text_request_invalid_transformations(self): + request = DeidentifyTextRequest( + text="test text", + transformations="invalid_type" + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value) + + # --- validate_reidentify_text_request masked_entities --- + + def test_validate_reidentify_text_request_invalid_masked_entities(self): + request = ReidentifyTextRequest( + text="test text", + masked_entities="invalid_type" + ) + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value) diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 6fa31e67..75826128 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -15,11 +15,19 @@ } CREDENTIALS_WITH_API_KEY = {"api_key": "dummy_api_key"} +CREDENTIALS_WITH_TOKEN = {"token": "dummy_static_token"} +CREDENTIALS_WITH_PATH = {"path": "/some/path/credentials.json"} +CREDENTIALS_WITH_STRING = {"credentials_string": '{"clientID": "x"}'} + class TestVaultClient(unittest.TestCase): def setUp(self): self.vault_client = VaultClient(CONFIG) + # ------------------------------------------------------------------ # + # Basic setters / getters # + # ------------------------------------------------------------------ # + def test_set_common_skyflow_credentials(self): credentials = {"api_key": "dummy_api_key"} self.vault_client.set_common_skyflow_credentials(credentials) @@ -31,173 +39,289 @@ def test_set_logger(self): self.assertEqual(self.vault_client.get_log_level(), "INFO") self.assertEqual(self.vault_client.get_logger(), mock_logger) + def test_get_vault_id(self): + self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + + def test_get_config(self): + self.assertEqual(self.vault_client.get_config(), CONFIG) + + def test_get_common_skyflow_credentials(self): + credentials = {"api_key": "dummy_api_key"} + self.vault_client.set_common_skyflow_credentials(credentials) + self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + + def test_get_log_level(self): + self.vault_client.set_logger("DEBUG", MagicMock()) + self.assertEqual(self.vault_client.get_log_level(), "DEBUG") + + def test_get_logger(self): + mock_logger = MagicMock() + self.vault_client.set_logger("INFO", mock_logger) + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + # ------------------------------------------------------------------ # + # initialize_client_configuration — first call (slow path) # + # ------------------------------------------------------------------ # + @patch("skyflow.vault.client.client.get_credentials") @patch("skyflow.vault.client.client.get_vault_url") @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") - def test_initialize_client_configuration(self, mock_init_api_client, mock_get_vault_url, mock_get_credentials): - mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY) + def test_initialize_client_configuration_first_call( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY mock_get_vault_url.return_value = "https://test-vault-url.com" self.vault_client.initialize_client_configuration() - mock_get_credentials.assert_called_once_with(CONFIG["credentials"], None, logger=None) - mock_get_vault_url.assert_called_once_with(CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None) + mock_get_credentials.assert_called_once_with( + CONFIG["credentials"], None, logger=None + ) + mock_get_vault_url.assert_called_once_with( + CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None + ) mock_init_api_client.assert_called_once() - @patch("skyflow.vault.client.client.Skyflow") - def test_initialize_api_client(self, mock_api_client): - self.vault_client.initialize_api_client("https://test-vault-url.com", "dummy_token") - mock_api_client.assert_called_once_with(base_url="https://test-vault-url.com", token="dummy_token") + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (static token) # + # ------------------------------------------------------------------ # - def test_get_records_api(self): - self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.records = MagicMock() - records_api = self.vault_client.get_records_api() - self.assertIsNotNone(records_api) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_api_key( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with api_key, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + # Side-effect simulates initialize_api_client actually setting __api_client + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) - def test_get_tokens_api(self): + self.vault_client.initialize_client_configuration() # first call — slow path + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() # second call — fast path + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_static_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with a static token, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_TOKEN + mock_get_vault_url.return_value = "https://test-vault-url.com" + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (service account) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_valid_sa_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, mock_is_expired + ): + """Service account with a still-valid token skips get_bearer_token entirely.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Seed the cached bearer token as if first call already ran self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.tokens = MagicMock() - tokens_api = self.vault_client.get_tokens_api() - self.assertIsNotNone(tokens_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "cached_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_query_api(self): + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — token expiry (no client reinit) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_sa_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_expired_token_no_reinit( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, + mock_is_expired, mock_generate_bearer_token + ): + """Expired service account token is regenerated in-place; httpx client is NOT recreated.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Client already initialized — simulate warm state with an expired token self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.query = MagicMock() - query_api = self.vault_client.get_query_api() - self.assertIsNotNone(query_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "expired_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_vault_id(self): - self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + self.vault_client.initialize_client_configuration() - @patch("skyflow.vault.client.client.generate_bearer_token") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token, - mock_generate_bearer_token_from_creds): - token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) - self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"]) - - def test_update_config(self): - new_config = {"credentials": "new_credentials"} - self.vault_client.update_config(new_config) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) - self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") + # Token was regenerated + mock_generate_bearer_token.assert_called_once() + self.assertEqual( + self.vault_client._VaultClient__bearer_token, "new_sa_token" + ) + # httpx client was NOT recreated + mock_init_api_client.assert_not_called() - def test_get_config(self): - self.assertEqual(self.vault_client.get_config(), CONFIG) + # ------------------------------------------------------------------ # + # initialize_client_configuration — config update forces reinit # + # ------------------------------------------------------------------ # - def test_get_common_skyflow_credentials(self): - credentials = {"api_key": "dummy_api_key"} - self.vault_client.set_common_skyflow_credentials(credentials) - self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_reinit_after_update_config( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """update_config() marks the client stale; next call must recreate it.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" - def test_get_log_level(self): - log_level = "DEBUG" - self.vault_client.set_logger(log_level, MagicMock()) - self.assertEqual(self.vault_client.get_log_level(), log_level) + # Simulate already-initialized client + self.vault_client._VaultClient__api_client = MagicMock() + self.vault_client._VaultClient__is_static_token = True - def test_get_logger(self): - mock_logger = MagicMock() - self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) + self.vault_client.update_config({"cluster_id": "new_cluster"}) + self.vault_client.initialize_client_configuration() - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_expired_token_raises_error(self, mock_generate_bearer_token, mock_is_expired): - """Test that expired token raises SkyflowError.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.return_value = ("expired_token", None) - mock_is_expired.return_value = True + mock_get_credentials.assert_called_once() + mock_get_vault_url.assert_called_once() + mock_init_api_client.assert_called_once() - with self.assertRaises(SkyflowError) as context: - self.vault_client.get_bearer_token(credentials) + # ------------------------------------------------------------------ # + # initialize_api_client — lambda token provider # + # ------------------------------------------------------------------ # - self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) - self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_passes_callable_token(self, mock_skyflow): + """initialize_api_client must pass a callable (lambda) as token, not a string.""" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - def test_get_bearer_token_expired_token_from_creds_string_raises_error(self, mock_generate_bearer_token_from_creds, mock_is_expired): - """Test that expired token from credentials string raises SkyflowError.""" - credentials = {"credentials_string": '{"key": "value"}'} - mock_generate_bearer_token_from_creds.return_value = ("expired_token", None) - mock_is_expired.return_value = True + args, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["base_url"], "https://test-vault-url.com") + self.assertTrue(callable(kwargs["token"]), "token must be a callable (lambda)") - with self.assertRaises(SkyflowError) as context: - self.vault_client.get_bearer_token(credentials) + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_returns_cached_bearer_token(self, mock_skyflow): + """Lambda returns __bearer_token when it is set (interceptor behaviour).""" + self.vault_client._VaultClient__bearer_token = "refreshed_token" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") - self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) - self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - self.assertTrue(self.vault_client._VaultClient__is_config_updated) + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "refreshed_token") - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_reuses_valid_token(self, mock_generate_bearer_token, mock_is_expired): - """Test that valid bearer token is reused.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.return_value = ("valid_token", None) - mock_is_expired.return_value = False - - token1 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token1, "valid_token") - mock_generate_bearer_token.assert_called_once() + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_falls_back_to_initial_token(self, mock_skyflow): + """Lambda falls back to the initial token when __bearer_token is None.""" + self.vault_client._VaultClient__bearer_token = None + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "initial_token") + + # ------------------------------------------------------------------ # + # get_bearer_token # + # ------------------------------------------------------------------ # + + def test_get_bearer_token_with_api_key(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) + self.assertEqual(result, "dummy_api_key") + + def test_get_bearer_token_with_static_token(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_TOKEN) + self.assertEqual(result, "dummy_static_token") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("sa_token", None)) + def test_get_bearer_token_generates_from_path_on_first_call(self, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token") + self.assertEqual(self.vault_client._VaultClient__bearer_token, "sa_token") + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds", return_value=("sa_token_str", None)) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_generates_from_credentials_string(self, mock_log, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_STRING) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token_str") - token2 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token2, "valid_token") - mock_generate_bearer_token.assert_called_once() + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_regenerates_on_expiry(self, mock_log, mock_is_expired, mock_generate): + """Expired token is regenerated silently — no exception raised.""" + self.vault_client._VaultClient__bearer_token = "expired_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "new_token") - @patch("skyflow.vault.client.client.is_expired") @patch("skyflow.vault.client.client.generate_bearer_token") - def test_get_bearer_token_regenerates_after_config_update(self, mock_generate_bearer_token, mock_is_expired): - """Test that bearer token is regenerated after config update.""" - credentials = {"path": "/path/to/credentials.json"} - mock_generate_bearer_token.side_effect = [("first_token", None), ("second_token", None)] - mock_is_expired.return_value = False + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_reuses_valid_cached_token(self, mock_log, mock_is_expired, mock_generate): + """Valid cached token is reused without calling generate_bearer_token.""" + self.vault_client._VaultClient__bearer_token = "valid_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_not_called() + self.assertEqual(result, "valid_token") + + # ------------------------------------------------------------------ # + # update_config # + # ------------------------------------------------------------------ # + + def test_update_config_sets_flag(self): + self.vault_client.update_config({"credentials": "new_credentials"}) + self.assertTrue(self.vault_client._VaultClient__is_config_updated) + self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") + + # ------------------------------------------------------------------ # + # API accessor stubs # + # ------------------------------------------------------------------ # + + def test_get_records_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_records_api()) - token1 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token1, "first_token") + def test_get_tokens_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_tokens_api()) - self.vault_client.update_config({"new_key": "new_value"}) + def test_get_query_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_query_api()) - token2 = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token2, "second_token") - self.assertEqual(mock_generate_bearer_token.call_count, 2) - @patch("skyflow.vault.client.client.is_expired") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") - @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_credentials_string(self, mock_log_info, mock_generate_bearer_token_from_creds, mock_is_expired): - """Test get_bearer_token with credentials_string.""" - credentials = {"credentials_string": '{"clientID": "test", "clientName": "test"}'} - mock_generate_bearer_token_from_creds.return_value = ("token_from_creds", None) - mock_is_expired.return_value = False - - token = self.vault_client.get_bearer_token(credentials) - - self.assertEqual(token, "token_from_creds") - mock_generate_bearer_token_from_creds.assert_called_once() - mock_log_info.assert_called_with( - SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, - None - ) - def test_get_bearer_token_with_token(self): - credentials = {"token": "dummy_token"} - token = self.vault_client.get_bearer_token(credentials) - self.assertEqual(token, "dummy_token") - - def test_get_bearer_token_with_token_uri_in_credentials(self): - credentials = { - "path": "dummy_path", - "token_uri": "https://valid-url.com" - } - with patch("skyflow.vault.client.client.generate_bearer_token") as mock_generate_bearer_token, \ - patch("skyflow.vault.client.client.is_expired", return_value=False): - mock_generate_bearer_token.return_value = ("bearer_token", "bearer") - token = self.vault_client.get_bearer_token(credentials) - mock_generate_bearer_token.assert_called_once() - args, kwargs = mock_generate_bearer_token.call_args - self.assertIn("token_uri", args[1]) - self.assertEqual(args[1]["token_uri"], "https://valid-url.com") - self.assertEqual(token, "bearer_token") +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 35a13716..f073264c 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -1,9 +1,11 @@ +import json import unittest from unittest.mock import Mock, patch, MagicMock import requests from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages, parse_invoke_connection_response -from skyflow.utils.enums import RequestMethod +from skyflow.utils._utils import get_data_from_content_type, construct_invoke_connection_request +from skyflow.utils.enums import RequestMethod, ContentType from skyflow.utils._version import SDK_VERSION from skyflow.vault.connection import InvokeConnectionRequest from skyflow.vault.controller import Connection @@ -146,8 +148,9 @@ def test_invoke_request_error(self, mock_send, mock_get_credentials): with self.assertRaises(SkyflowError) as context: self.connection.invoke(request) - - self.assertEqual(context.exception.message, ERROR_RESPONSE_CONTENT) + + expected_message = SkyflowMessages.Error.API_ERROR.value.format(FAILURE_STATUS_CODE) + self.assertEqual(context.exception.message, expected_message) self.assertEqual(context.exception.http_code, FAILURE_STATUS_CODE) self.assertEqual(context.exception.request_id, "test-request-id") @@ -290,5 +293,383 @@ def test_invoke_construct_request_called(self, mock_construct, mock_get_credenti ) +class TestGetDataFromContentType(unittest.TestCase): + """Tests for get_data_from_content_type covering all supported content types.""" + + DATA = {'key': 'value', 'num': 42} + + # ── JSON ────────────────────────────────────────────────────────────────── + def test_json_content_type_returns_json_string(self): + data, files = get_data_from_content_type(self.DATA, ContentType.JSON.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + # ── URL-encoded ─────────────────────────────────────────────────────────── + def test_urlencoded_content_type_returns_encoded_string(self): + data, files = get_data_from_content_type({'k': 'v'}, ContentType.URLENCODED.value) + self.assertIn('k=v', data) + self.assertEqual(files, {}) + + def test_urlencoded_nested_dict(self): + payload = {'a': {'b': 'c'}} + data, files = get_data_from_content_type(payload, ContentType.URLENCODED.value) + self.assertIsInstance(data, str) + self.assertIn('c', data) + self.assertEqual(files, {}) + + # ── Form-data ───────────────────────────────────────────────────────────── + def test_formdata_content_type_returns_files_dict(self): + data, files = get_data_from_content_type({'f1': 'v1', 'f2': 'v2'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertEqual(files, {'f1': (None, 'v1'), 'f2': (None, 'v2')}) + + def test_formdata_converts_values_to_str(self): + data, files = get_data_from_content_type({'num': 99}, ContentType.FORMDATA.value) + self.assertEqual(files['num'], (None, '99')) + + def test_formdata_single_key(self): + data, files = get_data_from_content_type({'only': 'one'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertIn('only', files) + + # ── XML ─────────────────────────────────────────────────────────────────── + def test_xml_text_xml_content_type_wraps_in_root(self): + data, files = get_data_from_content_type({'key': 'value'}, 'text/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertIn('', data) + self.assertEqual(files, {}) + + def test_xml_application_xml_content_type(self): + data, files = get_data_from_content_type({'key': 'value'}, 'application/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_content_type_enum_value(self): + data, files = get_data_from_content_type({'key': 'value'}, ContentType.XML.value) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw_string', 'text/xml') + self.assertEqual(data, 'raw_string') + self.assertEqual(files, {}) + + # ── HTML ────────────────────────────────────────────────────────────────── + def test_html_content_type_dict_returns_json_string(self): + data, files = get_data_from_content_type(self.DATA, ContentType.HTML.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_text_html_content_type(self): + data, files = get_data_from_content_type(self.DATA, 'text/html') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw', ContentType.HTML.value) + self.assertEqual(data, 'raw') + self.assertEqual(files, {}) + + # ── None / unknown ──────────────────────────────────────────────────────── + def test_none_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, None) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, 'application/octet-stream') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_non_dict_returns_str(self): + data, files = get_data_from_content_type('blob', 'application/octet-stream') + self.assertEqual(data, 'blob') + self.assertEqual(files, {}) + + +class TestParseInvokeConnectionResponse(unittest.TestCase): + """Tests for parse_invoke_connection_response covering all success and error paths.""" + + def _make_response(self, status_code, content, headers=None, raise_http_error=False): + mock_resp = Mock(spec=requests.Response) + mock_resp.status_code = status_code + if isinstance(content, str): + mock_resp.content = content.encode('utf-8') + else: + mock_resp.content = content + mock_resp.headers = headers or {} + if raise_http_error: + mock_resp.raise_for_status.side_effect = requests.HTTPError() + else: + mock_resp.raise_for_status.return_value = None + return mock_resp + + # ── Success paths ───────────────────────────────────────────────────────── + def test_success_json_content_type_parses_body(self): + resp = self._make_response( + 200, + '{"result": "ok"}', + {'content-type': 'application/json', 'x-request-id': 'req-1'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'result': 'ok'}) + self.assertEqual(result.metadata.get('request_id'), 'req-1') + self.assertIsNone(result.errors) + + def test_success_plain_text_content_type_returns_string(self): + resp = self._make_response( + 200, + 'plain text response', + {'content-type': 'text/plain'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'plain text response') + + def test_success_no_content_type_tries_json_parse(self): + resp = self._make_response(200, '{"a": 1}', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'a': 1}) + + def test_success_no_content_type_invalid_json_returns_string(self): + resp = self._make_response(200, 'not json', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'not json') + + def test_success_no_x_request_id_metadata_is_empty(self): + resp = self._make_response(200, '{}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.metadata, {}) + + def test_success_invalid_json_with_json_content_type_returns_raw_string(self): + resp = self._make_response( + 200, + 'not-json', + {'content-type': 'application/json'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'not-json') + + def test_success_bytes_content_decoded(self): + resp = self._make_response(200, b'{"x": 1}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'x': 1}) + + # ── Error paths — standard Skyflow format ──────────────────────────────── + def test_error_standard_skyflow_format_extracts_message(self): + body = json.dumps({'error': {'message': 'bad input', 'http_code': 400, 'http_status': 'BAD_REQUEST', 'grpc_code': 3, 'details': []}}) + resp = self._make_response(400, body, {'x-request-id': 'r1'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + e = ctx.exception + self.assertEqual(e.message, 'bad input') + self.assertEqual(e.http_code, 400) + self.assertEqual(e.request_id, 'r1') + self.assertEqual(e.http_status, 'BAD_REQUEST') + self.assertEqual(e.grpc_code, 3) + + def test_error_standard_format_falls_back_to_http_code_when_missing(self): + body = json.dumps({'error': {'message': 'oops'}}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertEqual(ctx.exception.http_code, 500) + + def test_error_standard_format_falls_back_to_sdk_message_when_missing(self): + body = json.dumps({'error': {}}) + resp = self._make_response(503, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — string error value ───────────────────────────────────── + def test_error_string_error_value_used_as_message(self): + body = json.dumps({'error': 'gateway timed out'}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertEqual(ctx.exception.message, 'gateway timed out') + + def test_error_empty_string_error_value_falls_back_to_sdk_message(self): + body = json.dumps({'error': ''}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-standard JSON ────────────────────────────────────── + def test_error_no_error_key_uses_sdk_message(self): + body = json.dumps({'message': 'something went wrong'}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_non_dict_json_body_uses_sdk_message(self): + body = json.dumps(['list', 'not', 'dict']) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_numeric_error_value_uses_sdk_message(self): + body = json.dumps({'error': 12345}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-JSON / empty body ────────────────────────────────── + def test_error_empty_body_uses_sdk_message(self): + resp = self._make_response(502, '', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + self.assertEqual(ctx.exception.http_code, 502) + + def test_error_html_body_uses_sdk_message(self): + resp = self._make_response(502, 'Bad Gateway', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + def test_error_plain_text_body_uses_sdk_message(self): + resp = self._make_response(503, 'Service Unavailable', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── error-from-client header ────────────────────────────────────────────── + def test_error_from_client_true_appended_to_details(self): + body = json.dumps({'error': {'message': 'client error', 'http_code': 400, 'details': []}}) + resp = self._make_response(400, body, {'error-from-client': 'true', 'x-request-id': 'r2'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is True for d in ctx.exception.details)) + + def test_error_from_client_false_appended_to_details(self): + body = json.dumps({'error': {'message': 'server error', 'http_code': 500}}) + resp = self._make_response(500, body, {'error-from-client': 'false'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is False for d in ctx.exception.details)) + + def test_error_from_client_initialises_details_when_none(self): + body = json.dumps({'error': {'message': 'err', 'http_code': 400}}) + resp = self._make_response(400, body, {'error-from-client': 'true'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertIsNotNone(ctx.exception.details) + self.assertTrue(len(ctx.exception.details) > 0) + + +class TestConstructInvokeConnectionRequest(unittest.TestCase): + """Tests for construct_invoke_connection_request covering method, body, headers, path/query params.""" + + BASE_URL = 'https://example.com/api' + LOGGER = Mock() + + def _make_request(self, method=RequestMethod.POST, body=None, headers=None, + path_params=None, query_params=None): + return InvokeConnectionRequest( + method=method, + body=body, + headers=headers, + path_params=path_params or {}, + query_params=query_params or {} + ) + + def test_post_with_json_body_prepares_request(self): + req = self._make_request(body={'k': 'v'}, headers={'Content-Type': 'application/json'}) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIn('k', prepared.body) + + def test_get_with_no_body(self): + req = self._make_request(method=RequestMethod.GET) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'GET') + + def test_urlencoded_body_is_form_encoded(self): + req = self._make_request( + body={'field': 'val'}, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('field=val', prepared.body) + + def test_formdata_body_produces_multipart_request(self): + req = self._make_request( + body={'file_field': 'data'}, + headers={'Content-Type': 'multipart/form-data'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIsNotNone(prepared.body) + + def test_xml_body_contains_xml_tags(self): + req = self._make_request( + body={'item': 'data'}, + headers={'Content-Type': 'text/xml'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('', prepared.body) + + def test_path_params_substituted_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + path_params={'id': '123'} + ) + url_with_placeholder = 'https://example.com/api/{id}/resource' + prepared = construct_invoke_connection_request(req, url_with_placeholder, self.LOGGER) + self.assertIn('123', prepared.url) + self.assertNotIn('{id}', prepared.url) + + def test_query_params_appear_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + query_params={'page': '1', 'limit': '10'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('page=1', prepared.url) + self.assertIn('limit=10', prepared.url) + + def test_invalid_headers_raises_skyflow_error(self): + req = InvokeConnectionRequest(method=RequestMethod.POST, headers='bad-headers') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + + def test_invalid_body_raises_skyflow_error(self): + req = InvokeConnectionRequest( + method=RequestMethod.POST, + body='not-a-dict', + headers={'Content-Type': 'application/json'} + ) + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + + def test_invalid_method_raises_skyflow_error(self): + req = InvokeConnectionRequest(method='INVALID_METHOD') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_METHOD.value) + + def test_trailing_slash_stripped_from_url(self): + req = self._make_request(method=RequestMethod.GET) + prepared = construct_invoke_connection_request(req, self.BASE_URL + '/', self.LOGGER) + self.assertNotIn('//', prepared.url.replace('https://', '')) + + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index c2f9a861..b86087f5 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch, MagicMock import base64 import os +import tempfile from skyflow.error import SkyflowError from skyflow.generated.rest import WordCharacterCount from skyflow.utils import SkyflowMessages @@ -513,16 +514,12 @@ def test_get_detect_run_in_progress_status(self, mock_validate): self.vault_client.get_detect_file_api.return_value = files_api - # Execute - with patch.object(self.detect, "_Detect__parse_deidentify_file_response") as mock_parse: - result = self.detect.get_detect_run(req) + # Execute — IN_PROGRESS is returned directly without going through the parser + result = self.detect.get_detect_run(req) - # Verify IN_PROGRESS handling - mock_parse.assert_called_once() - args = mock_parse.call_args[0][0] - self.assertIsInstance(args, DeidentifyFileResponse) - self.assertEqual(args.status, 'IN_PROGRESS') - self.assertEqual(args.run_id, run_id) + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, 'IN_PROGRESS') + self.assertEqual(result.run_id, run_id) def test_get_transformations_with_shift_dates(self): @@ -711,3 +708,98 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) + + def test_poll_for_processed_file_exception(self): + files_api = Mock() + files_api.with_raw_response = files_api + files_api.get_run.side_effect = Exception("poll error") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect._Detect__poll_for_processed_file("runid", max_wait_time=5) + + def test_save_output_directory_not_exists(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=False): + self.detect._Detect__save_deidentify_file_response_output( + response, "/nonexistent_dir", "file.txt", "file" + ) + + def test_save_output_second_non_redacted_item(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output1 = Mock() + output1.processedFile = base64.b64encode(b"data1").decode() + output1.processedFileType = "redacted_file" + output1.processedFileExtension = "txt" + output2 = Mock() + output2.processedFile = base64.b64encode(b"data2").decode() + output2.processedFileType = "entities" + output2.processedFileExtension = "json" + response = Mock() + response.output = [output1, output2] + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "original.txt", "original" + ) + + def test_save_output_path_traversal_blocked(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + call_count = [0] + + def fake_realpath(p): + call_count[0] += 1 + if call_count[0] == 1: + return "/safe_dir" + return "/outside/path" + + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=True), \ + patch("skyflow.vault.controller._detect.os.path.realpath", side_effect=fake_realpath): + self.detect._Detect__save_deidentify_file_response_output( + response, "/safe_dir", "file.txt", "file" + ) + + def test_save_output_write_exception(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.base64.b64decode", + side_effect=Exception("decode error")), \ + self.assertRaises(Exception): + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "file.txt", "file" + ) + + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") + @patch("skyflow.vault.controller._detect.base64") + def test_deidentify_file_api_error_inside_try(self, mock_base64, mock_validate): + file_content = b"test content" + file_obj = Mock() + file_obj.read.return_value = file_content + file_obj.name = "test.txt" + mock_base64.b64encode.return_value.decode.return_value = "encoded" + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) + req.entities = [] + req.token_format = None + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = None + req.wait_time = None + files_api = Mock() + files_api.with_raw_response = files_api + files_api.deidentify_text.side_effect = Exception("API error inside try") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect.deidentify_file(req) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 4e1a0dda..993cd72a 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -722,6 +722,26 @@ def test_upload_file_with_missing_file_source(self, mock_validate): self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_without_skyflow_id_successful(self, mock_validate): + """Test upload_file succeeds when skyflow_id is None (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/test.txt", + ) + mocked_open = mock_open_func(read_data=b"test file content") + mock_api_response = Mock() + mock_api_response.data = Mock(skyflow_id="generated-id-123") + records_api = self.vault_client.get_records_api.return_value + records_api.with_raw_response.upload_file_v_2.return_value = mock_api_response + with patch('builtins.open', mocked_open): + result = self.vault.upload_file(request) + mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + self.assertIsNone(request.skyflow_id) + self.assertEqual(result.skyflow_id, "generated-id-123") + self.assertIsNone(result.errors) + class TestFileUploadValidation(unittest.TestCase): def setUp(self): self.logger = Mock() @@ -874,3 +894,38 @@ def test_validate_missing_file_source(self): with self.assertRaises(SkyflowError) as error: validate_file_upload_request(self.logger, request) self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + def test_validate_none_skyflow_id_is_allowed(self): + """Test that skyflow_id=None passes validation (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + base64="dGVzdCBmaWxlIGNvbnRlbnQ=", + file_name="test.txt" + ) + self.assertIsNone(request.skyflow_id) + validate_file_upload_request(self.logger, request) + + @patch('os.path.exists') + @patch('os.path.isfile') + def test_validate_file_path_without_skyflow_id(self, mock_isfile, mock_exists): + """Test validation succeeds with file_path and no skyflow_id.""" + mock_exists.return_value = True + mock_isfile.return_value = True + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/file.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_object_without_skyflow_id(self): + """Test validation succeeds with file_object and no skyflow_id.""" + mock_file = Mock() + mock_file.seek = Mock() + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_object=mock_file + ) + validate_file_upload_request(self.logger, request) From 63c49d22cc023c2a8b88d901b2d57620cd7e2fc2 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Wed, 20 May 2026 09:49:24 +0000 Subject: [PATCH 14/23] [AUTOMATED] Private Release 2.0.2.dev0+e564e92 --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8f76225e..db10d438 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.2' +current_version = '2.0.2.dev0+e564e92' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index bc50f210..2ef0d400 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.2' +SDK_VERSION = '2.0.2.dev0+e564e92' From c12dfd52e27320c68939c420873f04b070de00dc Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Wed, 20 May 2026 15:30:01 +0530 Subject: [PATCH 15/23] SK-2833: update change log file (#246) --- CHANGELOG.md | 214 +-------------------------------------------------- 1 file changed, 2 insertions(+), 212 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f63ab2d7..d58ff590 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,215 +1,5 @@ # Changelog -All notable changes to this project will be documented in this file. +All notable changes to this project will be documented as part of the release notes. -## [2.0.2] - 2026-05-06 -### Added -- Dict context support for Conditional Data Access. - -## [2.0.1] - 2026-04-29 -### Fixed -- Fern client re-initialisation on token refresh. - -## [2.0.0] - 2025-11-11 -### Added -- Multi-vault and multi-connection support via fluent builder (`Skyflow.builder()`). -- New typed request and response classes for all vault operations (`InsertRequest`, `GetRequest`, `UpdateRequest`, `DeleteRequest`, `QueryRequest`, `DetokenizeRequest`, `TokenizeRequest`, `FileUploadRequest`). -- Detect API: `deidentify_text`, `reidentify_text`, `deidentify_file`, and `get_detect_run`. -- File upload support via `vault().upload_file()`. -- Flexible credential types: API key, static bearer token, service account credentials string, credentials file path, and `SKYFLOW_CREDENTIALS` environment variable. -- `SkyflowError` now includes `http_code`, `grpc_code`, `http_status`, `request_id`, and `details` fields. -- `set_log_level()` on the client for runtime log level changes. - -### Changed -- Complete rewrite of the SDK public API. See [docs/migrate_to_v2.md](docs/migrate_to_v2.md) for migration instructions. - -## [1.16.0] - 2025-09-23 -### Fixed -- Remote disconnect error in vault operations. - -## [1.15.8] - 2025-09-30 -### Fixed -- Retry logic when `continue_on_error` is set to `true` in insert. - -## [1.15.7] - 2025-09-23 -### Fixed -- Retry handling for errors in insert method. - -## [1.15.6] - 2025-09-22 -### Fixed -- Added retry logic for transient errors. - -## [1.15.5] - 2025-09-18 -### Fixed -- Remote disconnected errors in vault operations. - -## [1.15.4] - 2025-09-12 -### Fixed -- Retry on exception during vault requests. - -## [1.15.3] - 2025-09-12 -### Fixed -- Retry on exception during vault requests. - -## [1.15.2] - 2025-09-12 -### Fixed -- Retry on connection error in insert method. - -## [1.15.1] - 2023-12-07 -## Fixed -- Not receiving tokens when calling Get with options tokens as true. - -## [1.15.0] - 2023-10-30 -## Added -- options tokens support for Get method. - -## [1.14.0] - 2023-09-29 -## Added -- Support for different BYOT modes in Insert method. - -## [1.13.1] - 2023-09-14 -### Changed -- Add `request_index` in responses for insert method. - -## [1.13.0] - 2023-09-04 -### Added -- Added new Query method. - -## [1.12.0] - 2023-09-01 -### Added -- Support for Bulk request with Continue on Error in Detokenize Method -- Support for Continue on Error in Insert Method - -## [1.11.0] - 2023-08-25 -### Added -- Support for BYOT in Insert method. - -## [1.10.1] - 2023-07-28 -### Fixed -- Fixed delete method - -## [1.10.0] - 2023-07-21 -### Added -- Added delete method - -## [1.9.2] - 2023-06-22 -### Fixed -- Multiple record error in get method - -## [1.9.1] - 2023-06-07 -### Fixed -- Fixed bug in metrics - -## [1.9.0] - 2023-06-07 -### Added -- Added redaction type in detokenize - -## [1.8.1] - 2023-03-17 -### Removed -- removed grace period logic in bearer token generation - -## [1.8.0] - 2023-01-10 -### Added -- update and get methods. - -## [1.7.0] - 2022-12-07 -### Added -- `upsert` support for insert method. - -## [1.6.2] - 2022-06-28 - -### Added -- Copyright header to all files -- Security email in README - -## [1.6.1] - 2022-05-17 - -### Fixed - -- Insert with multiple records returning invalid output - -## [1.6.0] - 2022-04-12 - -### Added - -- support for application/x-www-form-urlencoded and multipart/form-data content-type's in connections. - -## [1.5.1] - 2022-03-29 - -### Added - -- Validation to token obtained from `tokenProvider` - -### Fixed - -- Request headers not getting overridden due to case sensitivity - -## [1.5.0] - 2022-03-22 - -### Changed - -- `getById` changed to `get_by_id` -- `invokeConnection`changed to `invoke_connection` -- `generateBearerToken` changed to `generate_bearer_token` -- `generateBearerTokenDromCreds` changed to `generate_bearer_token_from_creds` -- `isExpired` changed to `is_expired` -- `setLogLevel` changed to `set_log_level` - -### Removed - -- `isValid` function -- `GenerateToken` function - -## [1.4.0] - 2022-03-15 - -### Changed - -- deprecated `isValid` in favour of `isExpired` - -## [1.3.0] - 2022-02-24 - -### Added - -- Request ID in error logs and error responses for API Errors -- Caching to accessToken token -- `isValid` method for validating Service Account bearer token - -## [1.2.1] - 2022-01-18 - -### Fixed - -- `generateBearerTokenFromCreds` raising error "invalid credentials" on correct credentials - -## [1.2.0] - 2022-01-04 - -### Added - -- Logging functionality -- `setLogLevel` function for setting the package-level LogLevel -- `generateBearerTokenFromCreds` function which takes credentials as string - -### Changed - -- Renamed and deprecated `GenerateToken` in favor of `generateBearerToken` -- Make `vaultID` and `vaultURL` optional in `Client` constructor - -## [1.1.0] - 2021-11-10 - -### Added - -- `insert` vault API -- `detokenize` vault API -- `getById` vault API -- `invokeConnection` - -## [1.0.1] - 2021-10-26 - -### Changed - -- Package description - -## [1.0.0] - 2021-10-19 - -### Added - -- Service Account Token generation +See [Github](https://github.com/skyflowapi/skyflow-python/releases) or [PyPI](https://pypi.org/project/skyflow/#history) for more details on each released version. From 4f58c35c2cb21920c636c6171dcfa60243eeea12 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Thu, 21 May 2026 11:34:19 +0530 Subject: [PATCH 16/23] SK-2838: use SDK logger for deprecation warnings instead of Python warnings module (#248) --- README.md | 29 +++++++--- samples/vault_api/upload_file.py | 42 +++++++++----- skyflow/client/skyflow.py | 14 ++--- skyflow/utils/_skyflow_messages.py | 7 ++- skyflow/utils/logger/__init__.py | 2 +- skyflow/utils/logger/_log_helpers.py | 12 ++++ skyflow/vault/data/_file_upload_request.py | 9 +-- tests/client/test_skyflow.py | 64 +++++++++------------- 8 files changed, 100 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 23326cca..dbe86ee7 100644 --- a/README.md +++ b/README.md @@ -410,7 +410,9 @@ Refer to [Query your data](https://docs.skyflow.com/query-data/) and [Execute Qu ### Upload File -Upload files to a Skyflow vault using the `upload_file` method. Create a file upload request with the `FileUploadRequest` class, which accepts parameters such as the table name, column name, and Skyflow ID. +Upload files to a Skyflow vault using the `upload_file` method. Create a file upload request with the `FileUploadRequest` class. + +**Upload a file to an existing record:** ```python from skyflow.vault.data import FileUploadRequest @@ -418,13 +420,26 @@ from skyflow.vault.data import FileUploadRequest # Open the file in binary read mode with open('path/to/file.pdf', 'rb') as file_obj: upload_request = FileUploadRequest( - table='documents', # Table name - column_name='attachment', # Column name to store file - skyflow_id='', # Skyflow ID of the record - file_object=file_obj # Pass file object + table='', + column_name='', + skyflow_id='', + file_object=file_obj + ) + + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload:', response) +``` + +**Upload a file and create a new record (omit `skyflow_id`):** + +```python +with open('path/to/file.pdf', 'rb') as file_obj: + upload_request = FileUploadRequest( + table='documents', + column_name='attachment', + file_object=file_obj ) - - # Perform File Upload + response = skyflow_client.vault('').upload_file(upload_request) print('File upload:', response) ``` diff --git a/samples/vault_api/upload_file.py b/samples/vault_api/upload_file.py index df3e8cd0..7c762b4b 100644 --- a/samples/vault_api/upload_file.py +++ b/samples/vault_api/upload_file.py @@ -6,12 +6,16 @@ """ * Skyflow File Upload Example - * + * * This example demonstrates how to: * 1. Configure Skyflow client credentials * 2. Set up vault configuration - * 3. Create a file upload request - * 4. Handle response and errors + * 3. Upload a file to an existing record (with skyflow_id) + * 4. Upload a file and create a new record (without skyflow_id) + * 5. Handle response and errors + * + * Note: All FileUploadRequest parameters must be + * passed as keyword arguments. """ def perform_file_upload(): @@ -35,8 +39,8 @@ def perform_file_upload(): # Step 2: Configure Vault primary_vault_config = { - 'vault_id': '', - 'cluster_id': '', + 'vault_id': '', + 'cluster_id': '', 'env': Env.PROD, 'credentials': credentials } @@ -50,20 +54,28 @@ def perform_file_upload(): .build() ) - # Step 4: Prepare File Upload Data + # Step 4a: Upload a file to an existing record with open('', 'rb') as file_obj: - file_upload_request = FileUploadRequest( - table='', # Table to upload file to - column_name='', # Column to upload file into - file_object=file_obj, # Pass file object - skyflow_id='' # Record ID to associate the file with + upload_request = FileUploadRequest( + table='', + column_name='', + skyflow_id='', + file_object=file_obj ) - # Step 5: Perform File Upload - response = skyflow_client.vault('').upload_file(file_upload_request) + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload to existing record:', response) - # Handle Successful Response - print('File upload successful: ', response) + # Step 4b: Upload a file and create a new record (omit skyflow_id) + with open('', 'rb') as file_obj: + upload_request = FileUploadRequest( + table='', + column_name='', + file_object=file_obj + ) + + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload with new record:', response) except SkyflowError as error: print('Skyflow Specific Error: ', { diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 2255ee50..ebd5ef7d 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -1,10 +1,8 @@ -import warnings from collections import OrderedDict -from typing_extensions import deprecated from skyflow import LogLevel from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_info, Logger +from skyflow.utils.logger import log_info, log_warn, set_active_log_level, Logger from skyflow.utils.constants import OptionField from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \ validate_update_connection_config, validate_credentials, validate_log_level @@ -61,13 +59,9 @@ def set_log_level(self, log_level): self.__builder._Builder__set_log_level(log_level) return self - @deprecated("[DEPRECATED] Use set_log_level() instead.") def update_log_level(self, log_level): - warnings.warn( - SkyflowMessages.Warning.UPDATE_LOG_LEVEL_DEPRECATED.value, - DeprecationWarning, - stacklevel=2, - ) + """.. deprecated:: Use set_log_level() instead. Will be removed in a future release.""" + log_warn(SkyflowMessages.Warning.UPDATE_LOG_LEVEL_DEPRECATED.value) return self.set_log_level(log_level) def get_log_level(self): @@ -227,6 +221,7 @@ def __set_log_level(self, log_level): validate_log_level(self.__logger, log_level) self.__log_level = log_level self.__logger.set_log_level(log_level) + set_active_log_level(log_level) self.__update_vault_client_logger(log_level, self.__logger) log_info(SkyflowMessages.Info.LOGGER_SETUP_DONE.value, self.__logger) log_info(SkyflowMessages.Info.CURRENT_LOG_LEVEL.value.format(self.__log_level), self.__logger) @@ -243,6 +238,7 @@ def __add_skyflow_credentials(self, credentials): def build(self): validate_log_level(self.__logger, self.__log_level) self.__logger.set_log_level(self.__log_level) + set_active_log_level(self.__log_level) for config in self.__vault_list: self.__add_vault_config(config) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 01e15579..8e65ebac 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -4,6 +4,7 @@ error_prefix = f"Skyflow Python SDK {SDK_VERSION}" INFO = "INFO" +WARN = "WARN" ERROR = "ERROR" class SkyflowMessages: @@ -417,11 +418,11 @@ class HttpStatus(Enum): class Warning(Enum): UPDATE_LOG_LEVEL_DEPRECATED = ( - "[DEPRECATED] Skyflow.update_log_level() is deprecated. " - "Use Skyflow.set_log_level() instead — identical behavior." + f"{WARN}: [{error_prefix}] Skyflow.update_log_level() is deprecated. " + "Use Skyflow.set_log_level() instead." ) FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED = ( - "[DEPRECATED] FileUploadRequest: argument order changed. " + f"{WARN}: [{error_prefix}] FileUploadRequest: argument order changed. " "Old positional order: (table, skyflow_id, column_name). " "New order: FileUploadRequest(table, column_name=..., skyflow_id=...)." ) diff --git a/skyflow/utils/logger/__init__.py b/skyflow/utils/logger/__init__.py index 2993b8fc..bce55608 100644 --- a/skyflow/utils/logger/__init__.py +++ b/skyflow/utils/logger/__init__.py @@ -1,2 +1,2 @@ from ._logger import Logger -from ._log_helpers import log_error, log_info, log_error_log \ No newline at end of file +from ._log_helpers import log_error, log_info, log_warn, log_error_log, set_active_log_level \ No newline at end of file diff --git a/skyflow/utils/logger/_log_helpers.py b/skyflow/utils/logger/_log_helpers.py index 3fff980b..1343b55f 100644 --- a/skyflow/utils/logger/_log_helpers.py +++ b/skyflow/utils/logger/_log_helpers.py @@ -2,6 +2,13 @@ from . import Logger from ..constants import ResponseField +_active_log_level = LogLevel.ERROR + + +def set_active_log_level(level): + global _active_log_level + _active_log_level = level + def log_info(message, logger = None): if not logger: @@ -9,6 +16,11 @@ def log_info(message, logger = None): logger.info(message) +def log_warn(message, logger=None): + if not logger: + logger = Logger(_active_log_level) + logger.warn(message) + def log_error_log(message, logger=None): if not logger: logger = Logger(LogLevel.ERROR) diff --git a/skyflow/vault/data/_file_upload_request.py b/skyflow/vault/data/_file_upload_request.py index 6a632b67..c5c08b51 100644 --- a/skyflow/vault/data/_file_upload_request.py +++ b/skyflow/vault/data/_file_upload_request.py @@ -1,7 +1,7 @@ -import warnings from typing import BinaryIO, Optional from skyflow.utils import SkyflowMessages +from skyflow.utils.logger import log_warn class FileUploadRequest: @@ -15,12 +15,7 @@ def __init__(self, file_object: Optional[BinaryIO] = None, file_name: Optional[str] = None): if args: - warnings.warn( - SkyflowMessages.Warning.FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED.value, - DeprecationWarning, - stacklevel=2, - ) - # Old positional order was: (table, skyflow_id, column_name, ...) + log_warn(SkyflowMessages.Warning.FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED.value) skyflow_id = args[0] if args else skyflow_id column_name = args[1] if len(args) > 1 else column_name self.table = table diff --git a/tests/client/test_skyflow.py b/tests/client/test_skyflow.py index 1122448a..5b7ea675 100644 --- a/tests/client/test_skyflow.py +++ b/tests/client/test_skyflow.py @@ -1,5 +1,4 @@ import unittest -import warnings from unittest.mock import patch, Mock from skyflow import LogLevel, Env @@ -427,68 +426,59 @@ def _build_client(self): def test_update_log_level_emits_deprecation_warning(self): client = self._build_client() - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") + with patch('skyflow.client.skyflow.log_warn') as mock_warn: client.update_log_level(LogLevel.INFO) - deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] - self.assertGreaterEqual(len(deprecation_warnings), 1) - self.assertTrue(any("set_log_level" in str(w.message) for w in deprecation_warnings)) - - def test_update_log_level_warning_points_at_caller(self): - client = self._build_client() - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - client.update_log_level(LogLevel.INFO) - self.assertEqual(caught[0].filename, __file__) + mock_warn.assert_called_once() + self.assertIn("set_log_level", mock_warn.call_args[0][0]) def test_update_log_level_delegates_to_set_log_level(self): client = self._build_client() - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - client.update_log_level(LogLevel.INFO) + client.update_log_level(LogLevel.INFO) self.assertEqual(client.get_log_level(), LogLevel.INFO) class TestFileUploadRequestDeprecation(unittest.TestCase): def test_keyword_args_no_warning(self): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: req = FileUploadRequest( table="table", column_name="col", skyflow_id="sky123", ) - self.assertEqual(len(caught), 0) + mock_warn.assert_not_called() self.assertEqual(req.table, "table") self.assertEqual(req.column_name, "col") self.assertEqual(req.skyflow_id, "sky123") + def test_only_table_positional_no_warning(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest("table", column_name="col", skyflow_id="sky123") + mock_warn.assert_not_called() + self.assertEqual(req.table, "table") + def test_old_positional_order_emits_deprecation_warning(self): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: req = FileUploadRequest("table", "sky123", "col") - self.assertEqual(len(caught), 1) - self.assertTrue(issubclass(caught[0].category, DeprecationWarning)) - self.assertIn("FileUploadRequest", str(caught[0].message)) + mock_warn.assert_called_once() + self.assertIn("FileUploadRequest", mock_warn.call_args[0][0]) def test_old_positional_order_remaps_args_correctly(self): - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - req = FileUploadRequest("table", "sky123", "col") + req = FileUploadRequest("table", "sky123", "col") self.assertEqual(req.skyflow_id, "sky123") self.assertEqual(req.column_name, "col") - def test_old_positional_order_warning_points_at_caller(self): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - FileUploadRequest("table", "sky123", "col") - self.assertEqual(caught[0].filename, __file__) - def test_single_positional_arg_emits_warning_and_sets_skyflow_id(self): - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: req = FileUploadRequest("table", "sky123") - self.assertEqual(len(caught), 1) - self.assertTrue(issubclass(caught[0].category, DeprecationWarning)) + mock_warn.assert_called_once() self.assertEqual(req.skyflow_id, "sky123") self.assertIsNone(req.column_name) + + def test_optional_fields_default_to_none(self): + req = FileUploadRequest(table="table") + self.assertIsNone(req.skyflow_id) + self.assertIsNone(req.column_name) + self.assertIsNone(req.file_path) + self.assertIsNone(req.base64) + self.assertIsNone(req.file_object) + self.assertIsNone(req.file_name) From 4ea39dd4890647b040bfef6e6748f10f8a9a6ecb Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Thu, 21 May 2026 06:04:37 +0000 Subject: [PATCH 17/23] [AUTOMATED] Private Release 2.0.2.dev0+4f58c35 --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index db10d438..2ac4083b 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.2.dev0+e564e92' +current_version = '2.0.2.dev0+4f58c35' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 2ef0d400..8b03eae7 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.2.dev0+e564e92' +SDK_VERSION = '2.0.2.dev0+4f58c35' From 961116e2e69798465b8e9fc979399607f1b540dc Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Thu, 21 May 2026 18:23:03 +0530 Subject: [PATCH 18/23] SK-2838: fix redaction type in detokenize interface (#249) --- README.md | 4 +-- samples/vault_api/detokenize_records.py | 4 +-- skyflow/utils/_skyflow_messages.py | 9 ++++--- skyflow/utils/constants.py | 1 + skyflow/utils/validations/_validations.py | 14 ++++++++-- skyflow/vault/controller/_vault.py | 2 +- tests/utils/validations/test__validations.py | 27 +++++++++++++++++++- 7 files changed, 50 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index dbe86ee7..cc2a78a4 100644 --- a/README.md +++ b/README.md @@ -239,8 +239,8 @@ from skyflow.utils.enums import RedactionType detokenize_request = DetokenizeRequest( data=[ - {'token': 'token1', 'redaction': RedactionType.PLAIN_TEXT}, - {'token': 'token2', 'redaction': RedactionType.PLAIN_TEXT} + {'token': 'token1', 'redaction_type': RedactionType.PLAIN_TEXT}, + {'token': 'token2', 'redaction_type': RedactionType.PLAIN_TEXT} ], continue_on_error=True ) diff --git a/samples/vault_api/detokenize_records.py b/samples/vault_api/detokenize_records.py index e93d5a18..d0d10e0c 100644 --- a/samples/vault_api/detokenize_records.py +++ b/samples/vault_api/detokenize_records.py @@ -55,11 +55,11 @@ def perform_detokenization(): detokenize_data = [ { 'token': '', # Token to be detokenized - 'redaction': RedactionType.REDACTED + 'redaction_type': RedactionType.REDACTED }, { 'token': '', # Token to be detokenized - 'redaction': RedactionType.MASKED + 'redaction_type': RedactionType.MASKED } ] diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 8e65ebac..232bd8b0 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -122,7 +122,7 @@ class Error(Enum): INVOKE_CONNECTION_FAILED = f"{error_prefix} Invoke Connection operation failed." INVALID_IDS_TYPE = f"{error_prefix} Validation error. 'ids' has a value of type {{}}. Specify 'ids' as list." - INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction' has a value of type {{}}. Specify 'redaction' as type Skyflow.RedactionType." + INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction_type' has a value of type {{}}. Specify 'redaction_type' as type Skyflow.RedactionType." INVALID_COLUMN_NAME = f"{error_prefix} Validation error. column_name has a value of type {{}}. Specify 'column' as a string." INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. column_values key has a value of type {{}}. Specify column_values key as list." INVALID_COLUMN_VALUES = f"{error_prefix} Validation error. column_values key is an empty list. Specify at least one column value when column_name is passed." @@ -131,7 +131,7 @@ class Error(Enum): INVALID_OFF_SET_VALUE = f"{error_prefix} Validation error. offset key has a value of type {{}}. Specify offset key as integer." INVALID_LIMIT_VALUE = f"{error_prefix} Validation error. limit key has a value of type {{}}. Specify limit key as integer." INVALID_DOWNLOAD_URL_VALUE = f"{error_prefix} Validation error. download_url key has a value of type {{}}. Specify download_url key as boolean." - REDACTION_WITH_TOKENS_NOT_SUPPORTED = f"{error_prefix} Validation error. 'redaction' can't be used when tokens are specified. Remove 'redaction' from payload if tokens are specified." + REDACTION_WITH_TOKENS_NOT_SUPPORTED = f"{error_prefix} Validation error. 'redaction_type' can't be used when tokens are specified. Remove 'redaction_type' from payload if tokens are specified." TOKENS_GET_COLUMN_NOT_SUPPORTED = f"{error_prefix} Validation error. Column name and/or column values can't be used when tokens are specified. Remove unique column values or tokens from the payload." BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED = f"{error_prefix} Validation error. Both Skyflow IDs and column details can't be specified. Either specify Skyflow IDs or unique column details." INVALID_ORDER_BY_VALUE = f"{error_prefix} Validation error. order_by key has a value of type {{}}. Specify order_by key as Skyflow.OrderBy" @@ -139,7 +139,7 @@ class Error(Enum): UPDATE_FIELD_KEY_ERROR = f"{error_prefix} Validation error. Fields are empty in an update payload. Specify at least one field." INVALID_FIELDS_TYPE = f"{error_prefix} Validation error. The 'data' key has a value of type {{}}. Specify 'data' as a dictionary." IDS_KEY_ERROR = f"{error_prefix} Validation error. 'ids' key is missing from the payload. Specify an 'ids' key." - INVALID_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. The 'data' field is invalid. Specify 'data' as a list of dictionaries containing 'token' and 'redaction'." + INVALID_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. The 'data' field is invalid. Specify 'data' as a list of dictionaries containing 'token' and 'redaction_type'." INVALID_DATA_FOR_DETOKENIZE = f"{error_prefix}" EMPTY_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. Tokens are empty in detokenize payload. Specify at lease one token" INVALID_TOKEN_TYPE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens should be of type string." @@ -417,6 +417,9 @@ class HttpStatus(Enum): BAD_REQUEST = "Bad Request" class Warning(Enum): + DETOKENIZE_REDACTION_KEY_DEPRECATED = ( + f"{WARN}: [{error_prefix}] 'redaction' key in detokenize data is deprecated and will be removed in a future version. Use 'redaction_type' instead." + ) UPDATE_LOG_LEVEL_DEPRECATED = ( f"{WARN}: [{error_prefix}] Skyflow.update_log_level() is deprecated. " "Use Skyflow.set_log_level() instead." diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 17ba96e2..05d28380 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -174,6 +174,7 @@ class RequestParameter: VALUE = 'value' COLUMN_GROUP = 'column_group' REDACTION = 'redaction' + REDACTION_TYPE = 'redaction_type' class FileUploadField: diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 6cc2c811..42abe188 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -11,7 +11,7 @@ FileUploadField, DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField, Detect ) -from skyflow.utils.logger import log_info, log_error_log +from skyflow.utils.logger import log_info, log_warn, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest from skyflow.vault.detect._file_input import FileInput @@ -713,7 +713,17 @@ def validate_detokenize_request(logger, request): invalid_input_error_code) token = item.get(ResponseField.TOKEN) - redaction = item.get(RequestParameter.REDACTION, None) + + has_redaction = RequestParameter.REDACTION in item + has_redaction_type = RequestParameter.REDACTION_TYPE in item + + if has_redaction: + log_warn(SkyflowMessages.Warning.DETOKENIZE_REDACTION_KEY_DEPRECATED.value, logger) + + if has_redaction_type: + redaction = item.get(RequestParameter.REDACTION_TYPE) + else: + redaction = item.get(RequestParameter.REDACTION, None) if not isinstance(token, str) or not token: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format(RequestOperation.DETOKENIZE), diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 7d51ee83..fd085e35 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -223,7 +223,7 @@ def detokenize(self, request: DetokenizeRequest): tokens_list = [ V1DetokenizeRecordRequest( token=item.get(ResponseField.TOKEN), - redaction=item.get(RequestParameter.REDACTION, RedactionType.DEFAULT) + redaction=item.get(RequestParameter.REDACTION_TYPE) or item.get(RequestParameter.REDACTION, RedactionType.DEFAULT) ) for item in request.data ] diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index ec4d5bec..c5ad6b79 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -1057,11 +1057,36 @@ def test_validate_detokenize_request_invalid_continue_on_error_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value) def test_validate_detokenize_request_invalid_redaction_type(self): - request = DetokenizeRequest(data=[{"token": "token123", "redaction": "invalid"}], continue_on_error=False) + request = DetokenizeRequest(data=[{"token": "token123", "redaction_type": "invalid"}], continue_on_error=False) with self.assertRaises(SkyflowError) as context: validate_detokenize_request(self.logger, request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) + def test_validate_detokenize_request_deprecated_redaction_key_emits_warn(self): + from unittest.mock import patch + request = DetokenizeRequest(data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT}], continue_on_error=False) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_called_once() + self.assertIn("redaction_type", mock_warn.call_args[0][0]) + + def test_validate_detokenize_request_both_keys_prioritizes_redaction_type_and_warns(self): + from unittest.mock import patch + request = DetokenizeRequest( + data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT, "redaction_type": RedactionType.MASKED}], + continue_on_error=False + ) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_called_once() + + def test_validate_detokenize_request_redaction_type_only_no_warn(self): + from unittest.mock import patch + request = DetokenizeRequest(data=[{"token": "token123", "redaction_type": RedactionType.PLAIN_TEXT}], continue_on_error=False) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_not_called() + def test_validate_deidentify_file_request_wait_time_negative(self): file_input = FileInput(file_path=self.temp_file_path) From 5345434918b2159e8bca285073f9173ded59c78b Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Thu, 21 May 2026 12:53:24 +0000 Subject: [PATCH 19/23] [AUTOMATED] Private Release 2.0.2.dev0+961116e --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 2ac4083b..d356b664 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.2.dev0+4f58c35' +current_version = '2.0.2.dev0+961116e' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 8b03eae7..35b3f3a2 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.2.dev0+4f58c35' +SDK_VERSION = '2.0.2.dev0+961116e' From d821d2c7a35633c1e141d93947de49207c68b669 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow <156889717+saileshwar-skyflow@users.noreply.github.com> Date: Sun, 24 May 2026 14:42:54 +0530 Subject: [PATCH 20/23] SK-2842: improve coverage and revert details in skyflow error class (#252) --- skyflow/error/_skyflow_error.py | 2 +- tests/vault/connection/__init__.py | 0 tests/vault/connection/test_responses.py | 26 +++ .../vault/controller/test__audit_binlookup.py | 27 +++ tests/vault/controller/test__detect.py | 42 +++++ tests/vault/controller/test__vault.py | 50 +++++ tests/vault/data/__init__.py | 0 tests/vault/data/test_responses.py | 108 +++++++++++ tests/vault/detect/__init__.py | 0 tests/vault/detect/test_models.py | 177 ++++++++++++++++++ tests/vault/tokens/__init__.py | 0 tests/vault/tokens/test_responses.py | 38 ++++ 12 files changed, 469 insertions(+), 1 deletion(-) create mode 100644 tests/vault/connection/__init__.py create mode 100644 tests/vault/connection/test_responses.py create mode 100644 tests/vault/controller/test__audit_binlookup.py create mode 100644 tests/vault/data/__init__.py create mode 100644 tests/vault/data/test_responses.py create mode 100644 tests/vault/detect/__init__.py create mode 100644 tests/vault/detect/test_models.py create mode 100644 tests/vault/tokens/__init__.py create mode 100644 tests/vault/tokens/test_responses.py diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index bf472177..6c6cf463 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -12,6 +12,6 @@ def __init__(self, self.http_code = http_code self.grpc_code = grpc_code self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value - self.details = details if details else None + self.details = details if details else [] self.request_id = request_id super().__init__(message) \ No newline at end of file diff --git a/tests/vault/connection/__init__.py b/tests/vault/connection/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/connection/test_responses.py b/tests/vault/connection/test_responses.py new file mode 100644 index 00000000..72bd0c56 --- /dev/null +++ b/tests/vault/connection/test_responses.py @@ -0,0 +1,26 @@ +import unittest +from skyflow.vault.connection._invoke_connection_response import InvokeConnectionResponse + + +class TestInvokeConnectionResponse(unittest.TestCase): + def test_repr(self): + r = InvokeConnectionResponse(data={"key": "val"}, metadata={"m": 1}, errors=None) + self.assertIn("ConnectionResponse", repr(r)) + + def test_str(self): + r = InvokeConnectionResponse(data={"key": "val"}) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = InvokeConnectionResponse() + self.assertIsNone(r.data) + self.assertEqual(r.metadata, {}) + self.assertIsNone(r.errors) + + def test_metadata_defaults_to_empty_dict_when_falsy(self): + r = InvokeConnectionResponse(metadata=None) + self.assertEqual(r.metadata, {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__audit_binlookup.py b/tests/vault/controller/test__audit_binlookup.py new file mode 100644 index 00000000..978eb032 --- /dev/null +++ b/tests/vault/controller/test__audit_binlookup.py @@ -0,0 +1,27 @@ +import unittest +from skyflow.vault.controller._audit import Audit +from skyflow.vault.controller._bin_look_up import BinLookUp + + +class TestAudit(unittest.TestCase): + def test_instantiation(self): + a = Audit() + self.assertIsNotNone(a) + + def test_list_returns_none(self): + a = Audit() + self.assertIsNone(a.list()) + + +class TestBinLookUp(unittest.TestCase): + def test_instantiation(self): + b = BinLookUp() + self.assertIsNotNone(b) + + def test_get_returns_none(self): + b = BinLookUp() + self.assertIsNone(b.get()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index b86087f5..f0f2aa87 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -781,6 +781,48 @@ def test_save_output_write_exception(self): response, tmp_dir, "file.txt", "file" ) + def test_save_output_no_file_extension_uses_original_name(self): + """Branches 113->117 and 119->124: processed_file_extension is falsy — safe_ext stays None.""" + with tempfile.TemporaryDirectory() as tmp_dir: + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = None + output.processed_file_extension = None + response = Mock() + response.output = [output] + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "original.txt", "original" + ) + + @patch("skyflow.vault.controller._detect.time.sleep", return_value=None) + def test_poll_unknown_status_then_success(self, mock_sleep): + """Branch 80->65: status is unknown, loops back, then returns SUCCESS.""" + files_api = Mock() + files_api.with_raw_response = files_api + self.vault_client.get_detect_file_api.return_value = files_api + + call_count = {"n": 0} + + def side_effect(*args, **kwargs): + call_count["n"] += 1 + r = Mock() + if call_count["n"] == 1: + r.status = "UNKNOWN_STATUS" + else: + r.status = "SUCCESS" + return Mock(data=r) + + files_api.get_run.side_effect = side_effect + result = self.detect._Detect__poll_for_processed_file("runid", max_wait_time=10) + self.assertEqual(result.status, "SUCCESS") + + def test_get_file_from_request_no_file_no_path_returns_none(self): + """Branch 285->exit: file_input has neither file nor file_path set.""" + req = DeidentifyFileRequest(file=FileInput(file=None, file_path=None)) + result = self.detect._Detect__get_file_from_request(req) + self.assertIsNone(result) + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") @patch("skyflow.vault.controller._detect.base64") def test_deidentify_file_api_error_inside_try(self, mock_base64, mock_validate): diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 993cd72a..5acdf779 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -742,6 +742,56 @@ def test_upload_file_without_skyflow_id_successful(self, mock_validate): self.assertEqual(result.skyflow_id, "generated-id-123") self.assertIsNone(result.errors) + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + @patch("skyflow.vault.controller._vault.open", mock_open(read_data=b"file_content"), create=True) + def test_upload_file_file_path_with_existing_file_name(self, mock_validate): + """Branch 73->76: file_name already set when file_path is present — skips basename call.""" + request = FileUploadRequest( + table=TABLE_NAME, + column_name="file_col", + file_path="/path/to/file.txt", + file_name="already_set.txt" + ) + mock_api = self.vault_client.get_records_api.return_value.with_raw_response + mock_response = Mock() + mock_response.data.skyflow_id = "sky123" + mock_api.upload_file_v_2.return_value = mock_response + + self.vault.upload_file(request) + mock_api.upload_file_v_2.assert_called_once() + + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_file_object_without_name_attr(self, mock_validate): + """Branch 84->89: file_object has no 'name' attr — __get_file_for_file_upload returns None.""" + file_obj = Mock(spec=[]) + request = FileUploadRequest( + table=TABLE_NAME, + column_name="file_col", + file_object=file_obj + ) + mock_api = self.vault_client.get_records_api.return_value.with_raw_response + mock_response = Mock() + mock_response.data.skyflow_id = "sky123" + mock_api.upload_file_v_2.return_value = mock_response + + self.vault.upload_file(request) + mock_api.upload_file_v_2.assert_called_once() + + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_no_file_source_returns_none_file(self, mock_validate): + """Branch 84->89 (elif False): all file sources None — __get_file_for_file_upload returns None.""" + request = FileUploadRequest( + table=TABLE_NAME, + column_name="file_col" + ) + mock_api = self.vault_client.get_records_api.return_value.with_raw_response + mock_response = Mock() + mock_response.data.skyflow_id = "sky123" + mock_api.upload_file_v_2.return_value = mock_response + + self.vault.upload_file(request) + mock_api.upload_file_v_2.assert_called_once() + class TestFileUploadValidation(unittest.TestCase): def setUp(self): self.logger = Mock() diff --git a/tests/vault/data/__init__.py b/tests/vault/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/data/test_responses.py b/tests/vault/data/test_responses.py new file mode 100644 index 00000000..ea9f2be1 --- /dev/null +++ b/tests/vault/data/test_responses.py @@ -0,0 +1,108 @@ +import unittest +from skyflow.vault.data._delete_response import DeleteResponse +from skyflow.vault.data._file_upload_response import FileUploadResponse +from skyflow.vault.data._get_response import GetResponse +from skyflow.vault.data._insert_response import InsertResponse +from skyflow.vault.data._query_response import QueryResponse +from skyflow.vault.data._update_response import UpdateResponse +from skyflow.vault.data._upload_file_request import UploadFileRequest + + +class TestDeleteResponse(unittest.TestCase): + def test_repr(self): + r = DeleteResponse(deleted_ids=["id1"], errors=None) + self.assertIn("DeleteResponse", repr(r)) + self.assertIn("id1", repr(r)) + + def test_str(self): + r = DeleteResponse(deleted_ids=["id1"], errors=None) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = DeleteResponse() + self.assertIsNone(r.deleted_ids) + self.assertIsNone(r.errors) + + +class TestFileUploadResponse(unittest.TestCase): + def test_repr(self): + r = FileUploadResponse(skyflow_id="sky123", errors=None) + self.assertIn("FileUploadResponse", repr(r)) + self.assertIn("sky123", repr(r)) + + def test_str(self): + r = FileUploadResponse(skyflow_id="sky123", errors=None) + self.assertEqual(str(r), repr(r)) + + +class TestGetResponse(unittest.TestCase): + def test_repr(self): + r = GetResponse(data=[{"field": "val"}], errors=None) + self.assertIn("GetResponse", repr(r)) + + def test_str(self): + r = GetResponse(data=[{"field": "val"}], errors=None) + self.assertEqual(str(r), repr(r)) + + def test_none_data_defaults_to_empty_list(self): + r = GetResponse(data=None) + self.assertEqual(r.data, []) + + def test_empty_data_not_replaced(self): + r = GetResponse(data={}) + self.assertEqual(r.data, {}) + + +class TestInsertResponse(unittest.TestCase): + def test_repr(self): + r = InsertResponse(inserted_fields=[{"skyflow_id": "id1"}], errors=None) + self.assertIn("InsertResponse", repr(r)) + + def test_str(self): + r = InsertResponse(inserted_fields=[{"skyflow_id": "id1"}]) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = InsertResponse() + self.assertIsNone(r.inserted_fields) + self.assertIsNone(r.errors) + + +class TestQueryResponse(unittest.TestCase): + def test_repr(self): + r = QueryResponse() + self.assertIn("QueryResponse", repr(r)) + + def test_str(self): + r = QueryResponse() + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = QueryResponse() + self.assertEqual(r.fields, []) + self.assertIsNone(r.errors) + + +class TestUpdateResponse(unittest.TestCase): + def test_repr(self): + r = UpdateResponse(updated_field={"skyflow_id": "id1"}, errors=None) + self.assertIn("UpdateResponse", repr(r)) + + def test_str(self): + r = UpdateResponse(updated_field={"skyflow_id": "id1"}) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = UpdateResponse() + self.assertIsNone(r.updated_field) + self.assertIsNone(r.errors) + + +class TestUploadFileRequest(unittest.TestCase): + def test_instantiation(self): + r = UploadFileRequest() + self.assertIsNotNone(r) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/detect/__init__.py b/tests/vault/detect/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/detect/test_models.py b/tests/vault/detect/test_models.py new file mode 100644 index 00000000..bec65297 --- /dev/null +++ b/tests/vault/detect/test_models.py @@ -0,0 +1,177 @@ +import unittest +import io +from skyflow.vault.detect._deidentify_text_response import DeidentifyTextResponse +from skyflow.vault.detect._reidentify_text_response import ReidentifyTextResponse +from skyflow.vault.detect._entity_info import EntityInfo +from skyflow.vault.detect._file_input import FileInput +from skyflow.vault.detect._text_index import TextIndex +from skyflow.vault.detect._date_transformation import DateTransformation +from skyflow.vault.detect._transformations import Transformations +from skyflow.vault.detect._file import File +from skyflow.utils.enums import DetectEntities + + +class TestTextIndex(unittest.TestCase): + def test_repr(self): + t = TextIndex(start=0, end=4) + self.assertIn("TextIndex", repr(t)) + self.assertIn("0", repr(t)) + + def test_str(self): + t = TextIndex(start=0, end=4) + self.assertEqual(str(t), repr(t)) + + def test_attributes(self): + t = TextIndex(start=5, end=10) + self.assertEqual(t.start, 5) + self.assertEqual(t.end, 10) + + +class TestEntityInfo(unittest.TestCase): + def setUp(self): + self.text_index = TextIndex(0, 4) + self.processed_index = TextIndex(0, 8) + + def test_repr(self): + e = EntityInfo( + token="TOKEN_1", value="John", + text_index=self.text_index, + processed_index=self.processed_index, + entity="NAME", scores={"confidence": 0.9} + ) + self.assertIn("EntityInfo", repr(e)) + self.assertIn("John", repr(e)) + + def test_str(self): + e = EntityInfo( + token="TOKEN_1", value="John", + text_index=self.text_index, + processed_index=self.processed_index, + entity="NAME", scores={} + ) + self.assertEqual(str(e), repr(e)) + + def test_attributes(self): + e = EntityInfo( + token="T", value="v", + text_index=self.text_index, + processed_index=self.processed_index, + entity="EMAIL", scores={"s": 1.0} + ) + self.assertEqual(e.token, "T") + self.assertEqual(e.entity, "EMAIL") + + +class TestDeidentifyTextResponse(unittest.TestCase): + def test_repr(self): + r = DeidentifyTextResponse( + processed_text="[TOKEN_1]", entities=[], word_count=1, char_count=9 + ) + self.assertIn("DeidentifyTextResponse", repr(r)) + + def test_str(self): + r = DeidentifyTextResponse( + processed_text="[TOKEN_1]", entities=[], word_count=1, char_count=9 + ) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = DeidentifyTextResponse( + processed_text="text", entities=[], word_count=1, char_count=4 + ) + self.assertIsNone(r.errors) + + +class TestReidentifyTextResponse(unittest.TestCase): + def test_repr(self): + r = ReidentifyTextResponse(processed_text="John lives in NYC") + self.assertIn("ReidentifyTextResponse", repr(r)) + + def test_str(self): + r = ReidentifyTextResponse(processed_text="John") + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = ReidentifyTextResponse(processed_text="text") + self.assertIsNone(r.errors) + + +class TestFileInput(unittest.TestCase): + def test_repr_with_file(self): + bio = io.BytesIO(b"data") + bio.name = "test.txt" + fi = FileInput(file=bio) + self.assertIn("FileInput", repr(fi)) + + def test_str(self): + fi = FileInput(file_path="/some/path.pdf") + self.assertEqual(str(fi), repr(fi)) + + def test_repr_no_file(self): + fi = FileInput() + self.assertIn("FileInput", repr(fi)) + self.assertIsNone(fi.file) + self.assertIsNone(fi.file_path) + + +class TestDateTransformation(unittest.TestCase): + def test_instantiation(self): + dt = DateTransformation( + max_days=30, min_days=1, + entities=[DetectEntities.DATE] + ) + self.assertEqual(dt.max, 30) + self.assertEqual(dt.min, 1) + self.assertEqual(dt.entities, [DetectEntities.DATE]) + + +class TestTransformations(unittest.TestCase): + def test_instantiation(self): + dt = DateTransformation(max_days=30, min_days=1, entities=[DetectEntities.DATE]) + t = Transformations(shift_dates=dt) + self.assertEqual(t.shift_dates, dt) + + +class TestFile(unittest.TestCase): + def test_properties_with_file(self): + bio = io.BytesIO(b"hello") + bio.name = "test.txt" + f = File(file=bio) + self.assertEqual(f.name, "test.txt") + self.assertEqual(f.size, 5) + self.assertIsNotNone(f.type) + self.assertIsNotNone(f.last_modified) + + def test_properties_without_file(self): + f = File() + self.assertIsNone(f.name) + self.assertIsNone(f.size) + self.assertIsNone(f.type) + self.assertIsNone(f.last_modified) + + def test_seek_without_file(self): + f = File() + result = f.seek(0) + self.assertIsNone(result) + + def test_read_without_file(self): + f = File() + result = f.read() + self.assertIsNone(result) + + def test_seek_with_file(self): + bio = io.BytesIO(b"hello") + bio.name = "t.txt" + f = File(file=bio) + f.seek(0) + self.assertEqual(f.read(), b"hello") + + def test_repr(self): + bio = io.BytesIO(b"hi") + bio.name = "t.txt" + f = File(file=bio) + self.assertIn("File", repr(f)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/tokens/__init__.py b/tests/vault/tokens/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/tokens/test_responses.py b/tests/vault/tokens/test_responses.py new file mode 100644 index 00000000..62f217de --- /dev/null +++ b/tests/vault/tokens/test_responses.py @@ -0,0 +1,38 @@ +import unittest +from skyflow.vault.tokens._detokenize_response import DetokenizeResponse +from skyflow.vault.tokens._tokenize_response import TokenizeResponse + + +class TestDetokenizeResponse(unittest.TestCase): + def test_repr(self): + r = DetokenizeResponse(detokenized_fields=[{"token": "t1", "value": "v1"}], errors=None) + self.assertIn("DetokenizeResponse", repr(r)) + self.assertIn("t1", repr(r)) + + def test_str(self): + r = DetokenizeResponse(detokenized_fields=[{"token": "t1"}]) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = DetokenizeResponse() + self.assertIsNone(r.detokenized_fields) + self.assertIsNone(r.errors) + + +class TestTokenizeResponse(unittest.TestCase): + def test_repr(self): + r = TokenizeResponse(tokenized_fields=[{"value": "val", "token": "tok"}], errors=None) + self.assertIn("TokenizeResponse", repr(r)) + + def test_str(self): + r = TokenizeResponse(tokenized_fields=[{"token": "tok"}]) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = TokenizeResponse() + self.assertIsNone(r.tokenized_fields) + self.assertIsNone(r.errors) + + +if __name__ == "__main__": + unittest.main() From d879372011ebb6fd514e232486ea57ca2df4bc50 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Sun, 24 May 2026 09:13:11 +0000 Subject: [PATCH 21/23] [AUTOMATED] Private Release 2.0.2.dev0+d821d2c --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index d356b664..c670c393 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.2.dev0+961116e' +current_version = '2.0.2.dev0+d821d2c' setup( name='skyflow', diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 35b3f3a2..8b16bd7c 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.2.dev0+961116e' +SDK_VERSION = '2.0.2.dev0+d821d2c' From 75bbb85affac8b1872d7c6d07e5c0e467b252337 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 25 May 2026 09:56:06 +0530 Subject: [PATCH 22/23] SK-2842: configure to show README on PyPI --- setup.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c670c393..a6ca3133 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,9 @@ raise RuntimeError("skyflow requires Python 3.8+") current_version = '2.0.2.dev0+d821d2c' +with open('README.md', 'r', encoding='utf-8') as f: + long_description = f.read() + setup( name='skyflow', version=current_version, @@ -18,7 +21,8 @@ url='https://github.com/skyflowapi/skyflow-python/', license='LICENSE', description='Skyflow SDK for the Python programming language', - long_description=open('README.rst').read(), + long_description=long_description, + long_description_content_type='text/markdown', install_requires=[ 'python_dateutil >= 2.5.3', 'setuptools >= 75.3.3', From d57377c4d6aca20304b286050114a2d01113e31a Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 25 May 2026 04:56:53 +0000 Subject: [PATCH 23/23] [AUTOMATED] Private Release 2.0.2.dev0+e0253e9 --- setup.py | 2 +- skyflow/utils/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a6ca3133..d4ace25b 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.2.dev0+d821d2c' +current_version = '2.0.2.dev0+e0253e9' with open('README.md', 'r', encoding='utf-8') as f: long_description = f.read() diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 8b16bd7c..949d3423 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.2.dev0+d821d2c' +SDK_VERSION = '2.0.2.dev0+e0253e9'