diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f564510..d58ff590 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,162 +1,5 @@ # Changelog -All notable changes to this project will be documented in this file. +All notable changes to this project will be documented as part of the release notes. -## [1.15.1] - 2023-12-07 -## Fixed -- Not receiving tokens when calling Get with options tokens as true. - -## [1.15.0] - 2023-10-30 -## Added -- options tokens support for Get method. - -## [1.14.0] - 2023-09-29 -## Added -- Support for different BYOT modes in Insert method. - -## [1.13.1] - 2023-09-14 -### Changed -- Add `request_index` in responses for insert method. - -## [1.13.0] - 2023-09-04 -### Added -- Added new Query method. - -## [1.12.0] - 2023-09-01 -### Added -- Support for Bulk request with Continue on Error in Detokenize Method -- Support for Continue on Error in Insert Method - -## [1.11.0] - 2023-08-25 -### Added -- Support for BYOT in Insert method. - -## [1.10.1] - 2023-07-28 -### Fixed -- Fixed delete method - -## [1.10.0] - 2023-07-21 -### Added -- Added delete method - -## [1.9.2] - 2023-06-22 -### Fixed -- Multiple record error in get method - -## [1.9.1] - 2023-06-07 -### Fixed -- Fixed bug in metrics - -## [1.9.0] - 2023-06-07 -### Added -- Added redaction type in detokenize - -## [1.8.1] - 2023-03-17 -### Removed -- removed grace period logic in bearer token generation - -## [1.8.0] - 2023-01-10 -### Added -- update and get methods. - -## [1.7.0] - 2022-12-07 -### Added -- `upsert` support for insert method. - -## [1.6.2] - 2022-06-28 - -### Added -- Copyright header to all files -- Security email in README - -## [1.6.1] - 2022-05-17 - -### Fixed - -- Insert with multiple records returning invalid output - -## [1.6.0] - 2022-04-12 - -### Added - -- support for application/x-www-form-urlencoded and multipart/form-data content-type's in connections. - -## [1.5.1] - 2022-03-29 - -### Added - -- Validation to token obtained from `tokenProvider` - -### Fixed - -- Request headers not getting overridden due to case sensitivity - -## [1.5.0] - 2022-03-22 - -### Changed - -- `getById` changed to `get_by_id` -- `invokeConnection`changed to `invoke_connection` -- `generateBearerToken` changed to `generate_bearer_token` -- `generateBearerTokenDromCreds` changed to `generate_bearer_token_from_creds` -- `isExpired` changed to `is_expired` -- `setLogLevel` changed to `set_log_level` - -### Removed - -- `isValid` function -- `GenerateToken` function - -## [1.4.0] - 2022-03-15 - -### Changed - -- deprecated `isValid` in favour of `isExpired` - -## [1.3.0] - 2022-02-24 - -### Added - -- Request ID in error logs and error responses for API Errors -- Caching to accessToken token -- `isValid` method for validating Service Account bearer token - -## [1.2.1] - 2022-01-18 - -### Fixed - -- `generateBearerTokenFromCreds` raising error "invalid credentials" on correct credentials - -## [1.2.0] - 2022-01-04 - -### Added - -- Logging functionality -- `setLogLevel` function for setting the package-level LogLevel -- `generateBearerTokenFromCreds` function which takes credentials as string - -### Changed - -- Renamed and deprecated `GenerateToken` in favor of `generateBearerToken` -- Make `vaultID` and `vaultURL` optional in `Client` constructor - -## [1.1.0] - 2021-11-10 - -### Added - -- `insert` vault API -- `detokenize` vault API -- `getById` vault API -- `invokeConnection` - -## [1.0.1] - 2021-10-26 - -### Changed - -- Package description - -## [1.0.0] - 2021-10-19 - -### Added - -- Service Account Token generation +See [Github](https://github.com/skyflowapi/skyflow-python/releases) or [PyPI](https://pypi.org/project/skyflow/#history) for more details on each released version. diff --git a/README.md b/README.md index 6a980d3f..cc2a78a4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ # Skyflow Python SDK +> **This is the current, recommended version of the Skyflow SDK.** V2.1.0 brings flexible auth, multi-vault support, native data types, and rich error diagnostics. +> +> Migrating from v1? See the **[Migration Guide](https://github.com/skyflowapi/skyflow-python/blob/main/docs/migrate_to_v2.md)** for step-by-step instructions. V1 is in maintenance mode and will reach End of Life on October 31, 2026. + The Skyflow Python SDK is designed to help with integrating Skyflow into a Python backend. ## Table of Contents @@ -235,8 +239,8 @@ from skyflow.utils.enums import RedactionType detokenize_request = DetokenizeRequest( data=[ - {'token': 'token1', 'redaction': RedactionType.PLAIN_TEXT}, - {'token': 'token2', 'redaction': RedactionType.PLAIN_TEXT} + {'token': 'token1', 'redaction_type': RedactionType.PLAIN_TEXT}, + {'token': 'token2', 'redaction_type': RedactionType.PLAIN_TEXT} ], continue_on_error=True ) @@ -406,7 +410,9 @@ Refer to [Query your data](https://docs.skyflow.com/query-data/) and [Execute Qu ### Upload File -Upload files to a Skyflow vault using the `upload_file` method. Create a file upload request with the `FileUploadRequest` class, which accepts parameters such as the table name, column name, and Skyflow ID. +Upload files to a Skyflow vault using the `upload_file` method. Create a file upload request with the `FileUploadRequest` class. + +**Upload a file to an existing record:** ```python from skyflow.vault.data import FileUploadRequest @@ -414,13 +420,26 @@ from skyflow.vault.data import FileUploadRequest # Open the file in binary read mode with open('path/to/file.pdf', 'rb') as file_obj: upload_request = FileUploadRequest( - table='documents', # Table name - column_name='attachment', # Column name to store file - skyflow_id='', # Skyflow ID of the record - file_object=file_obj # Pass file object + table='', + column_name='', + skyflow_id='', + file_object=file_obj ) - - # Perform File Upload + + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload:', response) +``` + +**Upload a file and create a new record (omit `skyflow_id`):** + +```python +with open('path/to/file.pdf', 'rb') as file_obj: + upload_request = FileUploadRequest( + table='documents', + column_name='attachment', + file_object=file_obj + ) + response = skyflow_client.vault('').upload_file(upload_request) print('File upload:', response) ``` diff --git a/ruff.toml b/ruff.toml index b6795704..aea6cce7 100644 --- a/ruff.toml +++ b/ruff.toml @@ -8,12 +8,13 @@ exclude = [ "venv", "build", "dist", - "tests" + "tests", + "samples" ] line-length = 120 [lint] -select = ["N"] +select = ["N", "PLR2004"] [lint.pep8-naming] diff --git a/samples/detect_api/deidentify_file.py b/samples/detect_api/deidentify_file.py index 99b4b26e..88f012c9 100644 --- a/samples/detect_api/deidentify_file.py +++ b/samples/detect_api/deidentify_file.py @@ -1,7 +1,14 @@ from skyflow.error import SkyflowError from skyflow import Env, Skyflow, LogLevel from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions -from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep, FileInput +from skyflow.vault.detect import ( + DeidentifyFileRequest, + TokenFormat, + Transformations, + DateTransformation, + Bleep, + FileInput, +) """ * Skyflow Deidentify File Example @@ -11,6 +18,7 @@ * spreadsheets, presentations, structured text. """ + def perform_file_deidentification(): try: # Step 1: Configure Credentials @@ -23,7 +31,7 @@ def perform_file_deidentification(): 'vault_id': '', # Replace with your vault ID 'cluster_id': '', # Replace with your cluster ID 'env': Env.PROD, # Deployment environment - 'credentials': credentials + 'credentials': credentials, } # Step 3: Configure & Initialize Skyflow Client @@ -36,70 +44,66 @@ def perform_file_deidentification(): # Step 4: Create File Object file_path = '' # Replace with your file path - file = open(file_path, 'rb') - # Step 5: Configure Deidentify File Request with all options - deidentify_request = DeidentifyFileRequest( - file=FileInput(file), # File to de-identify (can also provide a file path) - entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect - allow_regex_list=[''], # Optional: Patterns to allow - restrict_regex_list=[''], # Optional: Patterns to restrict - - # Token format configuration - token_format=TokenFormat( - vault_token=[DetectEntities.SSN], # Use vault tokens for these entities - ), - - # Optional: Custom transformations - # transformations=Transformations( - # shift_dates=DateTransformation( - # max_days=30, - # min_days=10, - # entities=[DetectEntities.DOB] - # ) - # ), - - # Output configuration - output_directory='', # Where to save processed file - wait_time=15, # Max wait time in seconds (max 64) - - # Image-specific options - output_processed_image=True, # Include processed image in output - output_ocr_text=True, # Include OCR text in response - masking_method=MaskingMethod.BLACKBOX, # Masking method for images - - # PDF-specific options - pixel_density=15, # Pixel density for PDF processing - max_resolution=2000, # Max resolution for PDF - # Audio-specific options - output_processed_audio=True, # Include processed audio - output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type - - # Audio bleep configuration - - # bleep=Bleep( - # gain=5, # Loudness in dB - # frequency=1000, # Pitch in Hz - # start_padding=0.1, # Padding at start (seconds) - # stop_padding=0.2 # Padding at end (seconds) - # ) - ) - - # Step 6: Call deidentifyFile API - response = skyflow_client.detect().deidentify_file(deidentify_request) + # Step 5: Configure Deidentify File Request and call API + with open(file_path, 'rb') as file: + deidentify_request = DeidentifyFileRequest( + file=FileInput(file), # File to de-identify (can also provide a file path) + entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect + allow_regex_list=[''], # Optional: Patterns to allow + restrict_regex_list=[''], # Optional: Patterns to restrict + # Token format configuration + token_format=TokenFormat( + vault_token=[DetectEntities.SSN], # Use vault tokens for these entities + ), + # Optional: Custom transformations + # transformations=Transformations( + # shift_dates=DateTransformation( + # max_days=30, + # min_days=10, + # entities=[DetectEntities.DOB] + # ) + # ), + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) + # Image-specific options + output_processed_image=True, # Include processed image in output + output_ocr_text=True, # Include OCR text in response + masking_method=MaskingMethod.BLACKBOX, # Masking method for images + # PDF-specific options + pixel_density=15, # Pixel density for PDF processing + max_resolution=2000, # Max resolution for PDF + # Audio-specific options + output_processed_audio=True, # Include processed audio + output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type + # Audio bleep configuration + # bleep=Bleep( + # gain=5, # Loudness in dB + # frequency=1000, # Pitch in Hz + # start_padding=0.1, # Padding at start (seconds) + # stop_padding=0.2 # Padding at end (seconds) + # ) + ) + + # Step 6: Call deidentifyFile API + response = skyflow_client.detect().deidentify_file(deidentify_request) # Handle Successful Response - print("\nDeidentify File Response:", response) + print('\nDeidentify File Response:', response) except SkyflowError as error: # Handle Skyflow-specific errors - print('\nSkyflow Error:', { - 'http_code': error.http_code, - 'grpc_code': error.grpc_code, - 'http_status': error.http_status, - 'message': error.message, - 'details': error.details - }) + print( + '\nSkyflow Error:', + { + 'http_code': error.http_code, + 'grpc_code': error.grpc_code, + 'http_status': error.http_status, + 'message': error.message, + 'details': error.details, + }, + ) except Exception as error: # Handle unexpected errors print('Unexpected Error:', error) diff --git a/samples/detect_api/deidentify_file_async.py b/samples/detect_api/deidentify_file_async.py index 579dab2e..23d2f40f 100644 --- a/samples/detect_api/deidentify_file_async.py +++ b/samples/detect_api/deidentify_file_async.py @@ -1,7 +1,14 @@ from skyflow.error import SkyflowError from skyflow import Env, Skyflow, LogLevel from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions -from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep, FileInput +from skyflow.vault.detect import ( + DeidentifyFileRequest, + TokenFormat, + Transformations, + DateTransformation, + Bleep, + FileInput, +) from concurrent.futures import ThreadPoolExecutor """ @@ -25,7 +32,7 @@ def perform_file_deidentification_async(): 'vault_id': '', # Replace with your vault ID 'cluster_id': '', # Replace with your cluster ID 'env': Env.PROD, # Deployment environment - 'credentials': credentials + 'credentials': credentials, } # Step 3: Configure & Initialize Skyflow Client @@ -38,18 +45,15 @@ def perform_file_deidentification_async(): # Step 4: Create File Object file_path = '' # Replace with your file path - deidentify_request = DeidentifyFileRequest( file=FileInput(file_path=file_path), # File to de-identify # entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect allow_regex_list=[''], # Optional: Patterns to allow restrict_regex_list=[''], # Optional: Patterns to restrict - # Token format configuration token_format=TokenFormat( vault_token=[DetectEntities.SSN], # Use vault tokens for these entities ), - # Optional: Custom transformations # transformations=Transformations( # shift_dates=DateTransformation( @@ -58,26 +62,20 @@ def perform_file_deidentification_async(): # entities=[DetectEntities.DOB] # ) # ), - - # Output configuration - output_directory='', # Where to save processed file - wait_time=15, # Max wait time in seconds (max 64) - + # Output configuration + output_directory='', # Where to save processed file + wait_time=15, # Max wait time in seconds (max 64) # Image-specific options output_processed_image=True, # Include processed image in output output_ocr_text=True, # Include OCR text in response masking_method=MaskingMethod.BLACKBOX, # Masking method for images - # PDF-specific options pixel_density=15, # Pixel density for PDF processing max_resolution=2000, # Max resolution for PDF - # Audio-specific options output_processed_audio=True, # Include processed audio output_transcription=DetectOutputTranscriptions.PLAINTEXT_TRANSCRIPTION, # Transcription type - # Audio bleep configuration - # bleep=Bleep( # gain=5, # Loudness in dB # frequency=1000, # Pitch in Hz @@ -85,35 +83,36 @@ def perform_file_deidentification_async(): # stop_padding=0.2 # Padding at end (seconds) # ) ) - + # Create a thread pool executor executor = ThreadPoolExecutor(max_workers=1) - - future = executor.submit( - lambda: skyflow_client.detect().deidentify_file(deidentify_request) - ) - + + future = executor.submit(lambda: skyflow_client.detect().deidentify_file(deidentify_request)) + def handle_response(future): exception = future.exception() if exception is not None: if isinstance(exception, SkyflowError): # Handle Skyflow-specific errors - print('\nSkyflow Error:', { - 'http_code': exception.http_code, - 'grpc_code': exception.grpc_code, - 'http_status': exception.http_status, - 'message': exception.message, - 'details': exception.details - }) + print( + '\nSkyflow Error:', + { + 'http_code': exception.http_code, + 'grpc_code': exception.grpc_code, + 'http_status': exception.http_status, + 'message': exception.message, + 'details': exception.details, + }, + ) else: # Handle unexpected errors print('Unexpected Error:', exception) return - + # Handle Successful Response result = future.result() - print("\nDeidentify File Response:", result) - + print('\nDeidentify File Response:', result) + future.add_done_callback(handle_response) executor.shutdown(wait=True) @@ -121,4 +120,3 @@ def handle_response(future): except Exception as error: # Handle unexpected errors print('Unexpected Error:', error) - diff --git a/samples/service_account/signed_token_generation_example.py b/samples/service_account/signed_token_generation_example.py index 6ede1746..7ae175cd 100644 --- a/samples/service_account/signed_token_generation_example.py +++ b/samples/service_account/signed_token_generation_example.py @@ -1,12 +1,10 @@ import json from skyflow.service_account import ( - is_expired, generate_signed_data_tokens, generate_signed_data_tokens_from_creds, ) -file_path = 'CREDENTIALS_FILE_PATH' -bearer_token = '' +file_path = '' skyflow_credentials = { 'clientID': '', @@ -19,15 +17,18 @@ # Approach 1: Signed data tokens with string context +# Returns: [('', ''), ...] def get_signed_tokens_with_string_context(): options = { 'ctx': 'user_12345', - 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'data_tokens': ['', ''], 'time_to_live': 90, # in seconds } try: - data_token, signed_data_token = generate_signed_data_tokens(file_path, options) - return data_token, signed_data_token + results = generate_signed_data_tokens(file_path, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results except Exception as e: print(f'Error: {str(e)}') @@ -42,12 +43,14 @@ def get_signed_tokens_with_object_context(): 'department': 'research', 'user_id': 'user_67890', }, - 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'data_tokens': ['', ''], 'time_to_live': 90, } try: - data_token, signed_data_token = generate_signed_data_tokens(file_path, options) - return data_token, signed_data_token + results = generate_signed_data_tokens(file_path, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results except Exception as e: print(f'Error: {str(e)}') @@ -56,16 +59,21 @@ def get_signed_tokens_with_object_context(): def get_signed_tokens_from_credentials_string(): options = { 'ctx': 'user_12345', - 'data_tokens': ['DATA_TOKEN1', 'DATA_TOKEN2'], + 'data_tokens': ['', ''], 'time_to_live': 90, } try: - data_token, signed_data_token = generate_signed_data_tokens_from_creds(credentials_string, options) - return data_token, signed_data_token + results = generate_signed_data_tokens_from_creds(credentials_string, options) + for data_token, signed_data_token in results: + print(f' Token: {data_token}, Signed Token: {signed_data_token}') + return results except Exception as e: print(f'Error: {str(e)}') -print("String context:", get_signed_tokens_with_string_context()) -print("Object context:", get_signed_tokens_with_object_context()) -print("Creds string:", get_signed_tokens_from_credentials_string()) +print('String context:') +get_signed_tokens_with_string_context() +print('Object context:') +get_signed_tokens_with_object_context() +print('Creds string:') +get_signed_tokens_from_credentials_string() diff --git a/samples/service_account/token_generation_example.py b/samples/service_account/token_generation_example.py index 34db4c37..32fa022b 100644 --- a/samples/service_account/token_generation_example.py +++ b/samples/service_account/token_generation_example.py @@ -5,7 +5,7 @@ is_expired, ) -file_path = 'CREDENTIALS_FILE_PATH' +file_path = '' bearer_token = '' # To generate Bearer Token from credentials string. @@ -46,10 +46,9 @@ def get_bearer_token_from_credentials_string(): bearer_token = token return bearer_token except Exception as e: - print(f"Error generating token from credentials string: {str(e)}") - + print(f'Error generating token from credentials string: {str(e)}') print(get_bearer_token_from_file_path()) -print(get_bearer_token_from_credentials_string()) \ No newline at end of file +print(get_bearer_token_from_credentials_string()) diff --git a/samples/vault_api/credentials_options.py b/samples/vault_api/credentials_options.py index db792042..2155f99d 100644 --- a/samples/vault_api/credentials_options.py +++ b/samples/vault_api/credentials_options.py @@ -13,6 +13,7 @@ 4. Handle response and errors """ + def perform_secure_data_deletion(): try: # Step 1: Configure Bearer Token Credentials @@ -31,10 +32,10 @@ def perform_secure_data_deletion(): } secondary_vault_config = { - 'vault_id': 'YOUR_SECONDARY_VAULT_ID', # Secondary vault - 'cluster_id': 'YOUR_SECONDARY_CLUSTER_ID', # Cluster ID from your vault URL + 'vault_id': '', # Secondary vault + 'cluster_id': '', # Cluster ID from your vault URL 'env': Env.PROD, # Deployment environment - 'credentials': credentials + 'credentials': credentials, } # Step 3: Configure & Initialize Skyflow Client @@ -51,13 +52,10 @@ def perform_secure_data_deletion(): primary_table_name = '' # Replace with actual table name - primary_delete_request = DeleteRequest( - table=primary_table_name, - ids=primary_delete_ids - ) + primary_delete_request = DeleteRequest(table=primary_table_name, ids=primary_delete_ids) # Perform Delete Operation for Primary Vault - primary_delete_response = skyflow_client.vault('').delete(primary_delete_request) + primary_delete_response = skyflow_client.vault('').delete(primary_delete_request) # Handle Successful Response print('Primary Vault Deletion Successful:', primary_delete_response) @@ -67,10 +65,7 @@ def perform_secure_data_deletion(): secondary_table_name = '' # Replace with actual table name - secondary_delete_request = DeleteRequest( - table=secondary_table_name, - ids=secondary_delete_ids - ) + secondary_delete_request = DeleteRequest(table=secondary_table_name, ids=secondary_delete_ids) # Perform Delete Operation for Secondary Vault secondary_delete_response = skyflow_client.vault('').delete(secondary_delete_request) @@ -78,17 +73,12 @@ def perform_secure_data_deletion(): # Handle Successful Response print('Secondary Vault Deletion Successful:', secondary_delete_response) - except SkyflowError as error: # Comprehensive Error Handling - print('Skyflow Specific Error: ', { - 'code': error.http_code, - 'message': error.message, - 'details': error.details - }) + print('Skyflow Specific Error: ', {'code': error.http_code, 'message': error.message, 'details': error.details}) except Exception as error: print('Unexpected Error:', error) # Invoke the secure data deletion function -perform_secure_data_deletion() \ No newline at end of file +perform_secure_data_deletion() diff --git a/samples/vault_api/detokenize_records.py b/samples/vault_api/detokenize_records.py index e93d5a18..d0d10e0c 100644 --- a/samples/vault_api/detokenize_records.py +++ b/samples/vault_api/detokenize_records.py @@ -55,11 +55,11 @@ def perform_detokenization(): detokenize_data = [ { 'token': '', # Token to be detokenized - 'redaction': RedactionType.REDACTED + 'redaction_type': RedactionType.REDACTED }, { 'token': '', # Token to be detokenized - 'redaction': RedactionType.MASKED + 'redaction_type': RedactionType.MASKED } ] diff --git a/samples/vault_api/get_records.py b/samples/vault_api/get_records.py index b2fd445f..9e4d031a 100644 --- a/samples/vault_api/get_records.py +++ b/samples/vault_api/get_records.py @@ -4,6 +4,7 @@ from skyflow import Skyflow, LogLevel from skyflow.vault.data import GetRequest + def perform_secure_data_retrieval(): try: # Step 1: Configure Credentials @@ -28,7 +29,7 @@ def perform_secure_data_retrieval(): 'vault_id': '', # primary vault 'cluster_id': '', # Cluster ID from your vault URL 'env': Env.PROD, # Deployment environment (PROD by default) - 'credentials': credentials # Authentication method + 'credentials': credentials, # Authentication method } # Step 3: Configure & Initialize Skyflow Client @@ -42,10 +43,10 @@ def perform_secure_data_retrieval(): # Step 4: Prepare Retrieval Data - get_ids = ['', 'SKYFLOW_ID2'] + get_ids = ['', ''] get_request = GetRequest( - table='', # Replace with your actual table name + table='', # Replace with your actual table name ids=get_ids, ) @@ -57,15 +58,11 @@ def perform_secure_data_retrieval(): except SkyflowError as error: # Comprehensive Error Handling - print('Skyflow Specific Error: ', { - 'code': error.http_code, - 'message': error.message, - 'details': error.details - }) + print('Skyflow Specific Error: ', {'code': error.http_code, 'message': error.message, 'details': error.details}) except Exception as error: print('Unexpected Error:', error) # Invoke the secure data retrieval function -perform_secure_data_retrieval() \ No newline at end of file +perform_secure_data_retrieval() diff --git a/samples/vault_api/upload_file.py b/samples/vault_api/upload_file.py index df3e8cd0..7c762b4b 100644 --- a/samples/vault_api/upload_file.py +++ b/samples/vault_api/upload_file.py @@ -6,12 +6,16 @@ """ * Skyflow File Upload Example - * + * * This example demonstrates how to: * 1. Configure Skyflow client credentials * 2. Set up vault configuration - * 3. Create a file upload request - * 4. Handle response and errors + * 3. Upload a file to an existing record (with skyflow_id) + * 4. Upload a file and create a new record (without skyflow_id) + * 5. Handle response and errors + * + * Note: All FileUploadRequest parameters must be + * passed as keyword arguments. """ def perform_file_upload(): @@ -35,8 +39,8 @@ def perform_file_upload(): # Step 2: Configure Vault primary_vault_config = { - 'vault_id': '', - 'cluster_id': '', + 'vault_id': '', + 'cluster_id': '', 'env': Env.PROD, 'credentials': credentials } @@ -50,20 +54,28 @@ def perform_file_upload(): .build() ) - # Step 4: Prepare File Upload Data + # Step 4a: Upload a file to an existing record with open('', 'rb') as file_obj: - file_upload_request = FileUploadRequest( - table='', # Table to upload file to - column_name='', # Column to upload file into - file_object=file_obj, # Pass file object - skyflow_id='' # Record ID to associate the file with + upload_request = FileUploadRequest( + table='', + column_name='', + skyflow_id='', + file_object=file_obj ) - # Step 5: Perform File Upload - response = skyflow_client.vault('').upload_file(file_upload_request) + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload to existing record:', response) - # Handle Successful Response - print('File upload successful: ', response) + # Step 4b: Upload a file and create a new record (omit skyflow_id) + with open('', 'rb') as file_obj: + upload_request = FileUploadRequest( + table='', + column_name='', + file_object=file_obj + ) + + response = skyflow_client.vault('').upload_file(upload_request) + print('File upload with new record:', response) except SkyflowError as error: print('Skyflow Specific Error: ', { diff --git a/setup.py b/setup.py index 140dc870..d4ace25b 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,10 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.2' +current_version = '2.0.2.dev0+e0253e9' + +with open('README.md', 'r', encoding='utf-8') as f: + long_description = f.read() setup( name='skyflow', @@ -18,11 +21,12 @@ url='https://github.com/skyflowapi/skyflow-python/', license='LICENSE', description='Skyflow SDK for the Python programming language', - long_description=open('README.rst').read(), + long_description=long_description, + long_description_content_type='text/markdown', install_requires=[ 'python_dateutil >= 2.5.3', - 'setuptools >= 21.0.0', - 'urllib3 >= 1.25.3, < 2.1.0', + 'setuptools >= 75.3.3', + 'urllib3 >= 1.25.3, <= 2.6.3', 'pydantic >= 2', 'typing-extensions >= 4.7.1', 'DateTime~=5.5', diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 9f0d9dbf..ebd5ef7d 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -2,7 +2,8 @@ from skyflow import LogLevel from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_info, Logger +from skyflow.utils.logger import log_info, log_warn, set_active_log_level, Logger +from skyflow.utils.constants import OptionField from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \ validate_update_connection_config, validate_credentials, validate_log_level from skyflow.vault.client.client import VaultClient @@ -30,7 +31,7 @@ def update_vault_config(self,config): self.__builder.update_vault_config(config) def get_vault_config(self, vault_id): - return self.__builder.get_vault_config(vault_id).get("vault_client").get_config() + return self.__builder.get_vault_config(vault_id).get(OptionField.VAULT_CLIENT).get_config() def add_connection_config(self, config): self.__builder._Builder__add_connection_config(config) @@ -45,7 +46,7 @@ def update_connection_config(self, config): return self def get_connection_config(self, connection_id): - return self.__builder.get_connection_config(connection_id).get("vault_client").get_config() + return self.__builder.get_connection_config(connection_id).get(OptionField.VAULT_CLIENT).get_config() def add_skyflow_credentials(self, credentials): self.__builder._Builder__add_skyflow_credentials(credentials) @@ -58,23 +59,25 @@ def set_log_level(self, log_level): self.__builder._Builder__set_log_level(log_level) return self + def update_log_level(self, log_level): + """.. deprecated:: Use set_log_level() instead. Will be removed in a future release.""" + log_warn(SkyflowMessages.Warning.UPDATE_LOG_LEVEL_DEPRECATED.value) + return self.set_log_level(log_level) + def get_log_level(self): return self.__builder._Builder__log_level - def update_log_level(self, log_level): - self.__builder._Builder__set_log_level(log_level) - def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("vault_controller") + return vault_config.get(OptionField.VAULT_CONTROLLER) def connection(self, connection_id = None) -> Connection: connection_config = self.__builder.get_connection_config(connection_id) - return connection_config.get("controller") + return connection_config.get(OptionField.CONTROLLER) def detect(self, vault_id = None) -> Detect: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("detect_controller") + return vault_config.get(OptionField.DETECT_CONTROLLER) class Builder: def __init__(self): @@ -87,13 +90,13 @@ def __init__(self): self.__logger = Logger(LogLevel.ERROR) def add_vault_config(self, config): - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) if not isinstance(vault_id, str) or not vault_id: raise SkyflowError( SkyflowMessages.Error.INVALID_VAULT_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if vault_id in [vault.get("vault_id") for vault in self.__vault_list]: + if vault_id in [vault.get(OptionField.VAULT_ID) for vault in self.__vault_list]: log_info(SkyflowMessages.Info.VAULT_CONFIG_EXISTS.value.format(vault_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(vault_id), @@ -112,9 +115,11 @@ def remove_vault_config(self, vault_id): def update_vault_config(self, config): validate_update_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) + if vault_id not in self.__vault_configs: + raise SkyflowError(SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(vault_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value) vault_config = self.__vault_configs[vault_id] - vault_config.get("vault_client").update_config(config) + vault_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_vault_config(self, vault_id): if vault_id is None: @@ -129,13 +134,13 @@ def get_vault_config(self, vault_id): def add_connection_config(self, config): - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) if not isinstance(connection_id, str) or not connection_id: raise SkyflowError( SkyflowMessages.Error.INVALID_CONNECTION_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if connection_id in [connection.get("connection_id") for connection in self.__connection_list]: + if connection_id in [connection.get(OptionField.CONNECTION_ID) for connection in self.__connection_list]: log_info(SkyflowMessages.Info.CONNECTION_CONFIG_EXISTS.value.format(connection_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id), @@ -153,9 +158,11 @@ def remove_connection_config(self, connection_id): def update_connection_config(self, config): validate_update_connection_config(self.__logger, config) - connection_id = config['connection_id'] + connection_id = config[OptionField.CONNECTION_ID] + if connection_id not in self.__connection_configs: + raise SkyflowError(SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(connection_id), SkyflowMessages.ErrorCodes.INVALID_INPUT.value) connection_config = self.__connection_configs[connection_id] - connection_config.get("vault_client").update_config(config) + connection_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_connection_config(self, connection_id): if connection_id is None: @@ -183,37 +190,38 @@ def get_logger(self): def __add_vault_config(self, config): validate_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_client = VaultClient(config) self.__vault_configs[vault_id] = { - "vault_client": vault_client, - "vault_controller": Vault(vault_client), - "detect_controller": Detect(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.VAULT_CONTROLLER: Vault(vault_client), + OptionField.DETECT_CONTROLLER: Detect(vault_client) } - log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) - log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) + log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) + log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) def __add_connection_config(self, config): validate_connection_config(self.__logger, config) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) vault_client = VaultClient(config) self.__connection_configs[connection_id] = { - "vault_client": vault_client, - "controller": Connection(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.CONTROLLER: Connection(vault_client) } - log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get("connection_id")), self.__logger) + log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.CONNECTION_ID)), self.__logger) def __update_vault_client_logger(self, log_level, logger): for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_logger(log_level,logger) + vault_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_logger(log_level,logger) + connection_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) def __set_log_level(self, log_level): validate_log_level(self.__logger, log_level) self.__log_level = log_level self.__logger.set_log_level(log_level) + set_active_log_level(log_level) self.__update_vault_client_logger(log_level, self.__logger) log_info(SkyflowMessages.Info.LOGGER_SETUP_DONE.value, self.__logger) log_info(SkyflowMessages.Info.CURRENT_LOG_LEVEL.value.format(self.__log_level), self.__logger) @@ -223,13 +231,14 @@ def __add_skyflow_credentials(self, credentials): self.__skyflow_credentials = credentials validate_credentials(self.__logger, credentials) for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_common_skyflow_credentials(credentials) + vault_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(credentials) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials) + connection_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(self.__skyflow_credentials) def build(self): validate_log_level(self.__logger, self.__log_level) self.__logger.set_log_level(self.__log_level) + set_active_log_level(self.__log_level) for config in self.__vault_list: self.__add_vault_config(config) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index fca43935..cda064e2 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -1,5 +1,4 @@ from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_error class SkyflowError(Exception): def __init__(self, @@ -8,11 +7,11 @@ def __init__(self, request_id = None, grpc_code = None, http_status = None, - details = []): + details = None): self.message = message self.http_code = http_code self.grpc_code = grpc_code self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value - self.details = details + self.details = details if details else [] self.request_id = request_id - super().__init__() \ No newline at end of file + super().__init__(message) diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 9f30b789..deccf973 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -3,10 +3,14 @@ import re import time import jwt +from urllib.parse import urlparse from skyflow.error import SkyflowError from skyflow.service_account.client.auth_client import AuthClient from skyflow.utils.logger import log_info, log_error_log from skyflow.utils import get_base_url, format_scope, SkyflowMessages +from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField +from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError +from skyflow.utils import is_valid_url from skyflow.utils.constants import CTX_KEY_REGEX @@ -14,6 +18,18 @@ _CTX_KEY_PATTERN = re.compile(CTX_KEY_REGEX) +_SNAKE_TO_CAMEL_CRED_MAP = { + 'private_key': CredentialField.PRIVATE_KEY, + 'client_id': CredentialField.CLIENT_ID, + 'key_id': CredentialField.KEY_ID, + 'token_uri': CredentialField.TOKEN_URI, + 'client_name': CredentialField.CLIENT_NAME, +} + + +def _normalize_credentials(credentials): + return {_SNAKE_TO_CAMEL_CRED_MAP.get(k, k): v for k, v in credentials.items()} + def _validate_and_resolve_ctx(ctx): """Validate ctx value and return resolved value for JWT claims. @@ -43,14 +59,16 @@ def _validate_and_resolve_ctx(ctx): ) def is_expired(token, logger = None): + if token is None: + return True if len(token) == 0: log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True try: decoded = jwt.decode( - token, options={"verify_signature": False, "verify_aud": False}) - if time.time() >= decoded['exp']: + token, options={OptionField.VERIFY_SIGNATURE: False, OptionField.VERIFY_AUD: False}) + if time.time() >= decoded[JwtField.EXP]: log_info(SkyflowMessages.Info.BEARER_TOKEN_EXPIRED.value, logger) log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -62,20 +80,18 @@ def is_expired(token, logger = None): return True def generate_bearer_token(credentials_file_path, options = None, logger = None): + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) try: - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) - credentials_file =open(credentials_file_path, 'r') + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) + except SkyflowError: + raise except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger) - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) - - finally: - credentials_file.close() result = get_service_account_token(credentials, options, logger) return result @@ -90,26 +106,37 @@ def generate_bearer_token_from_creds(credentials, options = None, logger = None) return result def get_service_account_token(credentials, options, logger): + credentials = _normalize_credentials(credentials) try: - private_key = credentials["privateKey"] - except: - log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) + private_key = credentials[CredentialField.PRIVATE_KEY] + except KeyError: + log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: - client_id = credentials["clientID"] - except: + client_id = credentials[CredentialField.CLIENT_ID] + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: - key_id = credentials["keyID"] - except: + key_id = credentials[CredentialField.KEY_ID] + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: - token_uri = credentials["tokenURI"] - except: + token_uri = credentials[CredentialField.TOKEN_URI] + except KeyError: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) + + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + if options and CredentialField.TOKEN_URI_OPTION in options: + token_uri = options[CredentialField.TOKEN_URI_OPTION] + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) signed_token = get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger) base_url = get_base_url(token_uri) @@ -117,81 +144,92 @@ def get_service_account_token(credentials, options, logger): auth_api = auth_client.get_auth_api() formatted_scope = None - if options and "role_ids" in options: - formatted_scope = format_scope(options.get("role_ids")) + if options and OptionField.ROLE_IDS in options: + formatted_scope = format_scope(options.get(OptionField.ROLE_IDS)) - response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + try: + response = auth_api.authentication_service_get_auth_token(assertion = signed_token, + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) - log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) + except UnauthorizedError: + log_error_log(SkyflowMessages.ErrorLogs.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value, invalid_input_error_code) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.FAILED_TO_GET_BEARER_TOKEN.value, logger=logger) + raise SkyflowError(SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value, invalid_input_error_code) return response.access_token, response.token_type def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): payload = { - "iss": client_id, - "key": key_id, - "aud": token_uri, - "sub": client_id, - "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60) + JwtField.ISS: client_id, + JwtField.KEY: key_id, + JwtField.AUD: token_uri, + JwtField.SUB: client_id, + JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and "ctx" in options: - resolved_ctx = _validate_and_resolve_ctx(options.get("ctx")) + if options and OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options.get(OptionField.CTX)) if resolved_ctx is not None: - payload["ctx"] = resolved_ctx + payload[JwtField.CTX] = resolved_ctx try: - return jwt.encode(payload=payload, key=private_key, algorithm="RS256") + return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code) def get_signed_tokens(credentials_obj, options): - try: - expiry_time = int(time.time()) + options.get("time_to_live", 60) - prefix = "signed_token_" - - if options and options.get("data_tokens"): - for token in options["data_tokens"]: - claims = { - "iss": "sdk", - "key": credentials_obj.get("keyID"), - "exp": expiry_time, - "sub": credentials_obj.get("clientID"), - "tok": token, - "iat": int(time.time()), - } - - if "ctx" in options: - resolved_ctx = _validate_and_resolve_ctx(options["ctx"]) - if resolved_ctx is not None: - claims["ctx"] = resolved_ctx - - private_key = credentials_obj.get("privateKey") - signed_jwt = jwt.encode(claims, private_key, algorithm="RS256") - response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) - log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) - return response_object - - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + options = options if options is not None else {} + credentials_obj = _normalize_credentials(credentials_obj) + expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60) + prefix = JWT.SIGNED_TOKEN_PREFIX + + token_uri = credentials_obj.get(CredentialField.TOKEN_URI) + if not isinstance(token_uri, str) or not is_valid_url(token_uri): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) + + resolved_ctx = None + if OptionField.CTX in options: + resolved_ctx = _validate_and_resolve_ctx(options[OptionField.CTX]) + + results = [] + if options and options.get(OptionField.DATA_TOKENS): + for token in options[OptionField.DATA_TOKENS]: + claims = { + JwtField.ISS: JWT.ISSUER_SDK, + JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID), + JwtField.EXP: expiry_time, + JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID), + JwtField.TOK: token, + JwtField.IAT: int(time.time()), + } + if resolved_ctx is not None: + claims[JwtField.CTX] = resolved_ctx + private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) + try: + signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + results.append(get_signed_data_token_response_object(prefix + signed_jwt, token)) + log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) + return results def generate_signed_data_tokens(credentials_file_path, options): log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value) try: - credentials_file =open(credentials_file_path, 'r') + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), + invalid_input_error_code) + except SkyflowError: + raise except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), - invalid_input_error_code) - - finally: - credentials_file.close() - return get_signed_tokens(credentials, options) def generate_signed_data_tokens_from_creds(credentials, options): @@ -204,9 +242,6 @@ def generate_signed_data_tokens_from_creds(credentials, options): raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value, invalid_input_error_code) return get_signed_tokens(json_credentials, options) + def get_signed_data_token_response_object(signed_token, actual_token): - response_object = { - "token": actual_token, - "signed_token": signed_token - } - return response_object.get("token"), response_object.get("signed_token") + return actual_token, signed_token diff --git a/skyflow/utils/__init__.py b/skyflow/utils/__init__.py index f2788b11..664cf65d 100644 --- a/skyflow/utils/__init__.py +++ b/skyflow/utils/__init__.py @@ -1,5 +1,5 @@ from ..utils.enums import LogLevel, Env, TokenType from ._skyflow_messages import SkyflowMessages from ._version import SDK_VERSION -from ._helpers import get_base_url, format_scope +from ._helpers import get_base_url, format_scope, is_valid_url from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_detected_entity_to_entity_info diff --git a/skyflow/utils/_helpers.py b/skyflow/utils/_helpers.py index 97eecabc..12ff1257 100644 --- a/skyflow/utils/_helpers.py +++ b/skyflow/utils/_helpers.py @@ -8,4 +8,11 @@ def get_base_url(url): def format_scope(scopes): if not scopes: return None - return " ".join([f"role:{scope}" for scope in scopes]) \ No newline at end of file + return " ".join([f"role:{scope}" for scope in scopes]) + +def is_valid_url(url): + try: + result = urlparse(url) + return all([result.scheme == "https", result.netloc]) + except Exception: + return False \ No newline at end of file diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 58067673..232bd8b0 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -4,6 +4,7 @@ error_prefix = f"Skyflow Python SDK {SDK_VERSION}" INFO = "INFO" +WARN = "WARN" ERROR = "ERROR" class SkyflowMessages: @@ -16,7 +17,7 @@ class ErrorCodes(Enum): REDACTION_WITH_TOKENS_NOT_SUPPORTED = 400 class Error(Enum): - GENERIC_API_ERROR = f"{error_prefix} Validation error. Invalid configuration. Please add a valid vault configuration." + GENERIC_API_ERROR = f"{error_prefix} API error. Error occurred." EMPTY_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id." INVALID_VAULT_ID = f"{error_prefix} Initialization failed. Invalid vault Id. Specify a valid vault Id as a string." @@ -42,11 +43,12 @@ class Error(Enum): EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Specify a valid file path." EMPTY_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Specify a valid file path." INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Expected file path to be a string." - INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a string." + INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a valid file path." EMPTY_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid token for {{}} with id {{}}.Specify a valid credentials token." EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token." INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." + EXPIRED_BEARER_TOKEN = f"{error_prefix} Initialization failed. Bearer token is invalid or expired." EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." @@ -73,6 +75,9 @@ class Error(Enum): RESPONSE_NOT_JSON = f"{error_prefix} Response {{}} is not valid JSON." API_ERROR = f"{error_prefix} Server returned status code {{}}" + INVALID_JSON_RESPONSE = f"{error_prefix} Invalid JSON response received." + UNKNOWN_ERROR_DEFAULT_MESSAGE = f"{error_prefix} An unknown error occurred." + INVALID_FILE_INPUT = f"{error_prefix} Validation error. Invalid file input. Specify a valid file input." INVALID_DETECT_ENTITIES_TYPE = f"{error_prefix} Validation error. Invalid type of detect entities. Specify detect entities as list of DetectEntities enum." INVALID_TYPE_FOR_DEFAULT_TOKEN_TYPE = f"{error_prefix} Validation error. Invalid type of default token type. Specify default token type as TokenType enum." @@ -86,14 +91,15 @@ class Error(Enum): INVALID_TABLE_NAME_IN_INSERT = f"{error_prefix} Validation error. Invalid table name in insert request. Specify a valid table name." INVALID_TYPE_OF_DATA_IN_INSERT = f"{error_prefix} Validation error. Invalid type of data in insert request. Specify data as a object array." EMPTY_DATA_IN_INSERT = f"{error_prefix} Validation error. Data array cannot be empty. Specify data in insert request." - INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. 'upsert' key cannot be empty in options. At least one object of table and column is required." + INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. Invalid 'upsert' value in options. Specify 'upsert' as a non-empty string containing the column name." INVALID_HOMOGENEOUS_TYPE = f"{error_prefix} Validation error. Invalid type of homogeneous. Specify homogeneous as a string." INVALID_TOKEN_MODE_TYPE = f"{error_prefix} Validation error. Invalid type of token mode. Specify token mode as a TokenMode enum." INVALID_RETURN_TOKENS_TYPE = f"{error_prefix} Validation error. Invalid type of return tokens. Specify return tokens as a boolean." INVALID_CONTINUE_ON_ERROR_TYPE = f"{error_prefix} Validation error. Invalid type of continue on error. Specify continue on error as a boolean." TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE = f"{error_prefix} Validation error. 'token_mode' wasn't specified. Set 'token_mode' to 'ENABLE' to insert tokens." INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT = f"{error_prefix} Validation error. 'token_mode' is set to 'ENABLE_STRICT', but some fields are missing tokens. Specify tokens for all fields." - NO_TOKENS_IN_INSERT = f"{error_prefix} Validation error. Tokens weren't specified for records while 'token_strict' was {{}}. Specify tokens." + MISMATCH_OF_FIELDS_AND_TOKENS = f"{error_prefix} Validation error. Keys for values and tokens are not matching. Ensure each values entry and its corresponding tokens entry have the same keys." + NO_TOKENS_IN_INSERT = f"{error_prefix} Validation error. Tokens weren't specified for records while 'token_mode' was {{}}. Specify tokens." BATCH_INSERT_FAILURE = f"{error_prefix} Insert operation failed." GET_FAILURE = f"{error_prefix} Get operation failed." HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT = f"{error_prefix} Validation error. Homogenous is not supported when upsert is passed." @@ -116,15 +122,16 @@ class Error(Enum): INVOKE_CONNECTION_FAILED = f"{error_prefix} Invoke Connection operation failed." INVALID_IDS_TYPE = f"{error_prefix} Validation error. 'ids' has a value of type {{}}. Specify 'ids' as list." - INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction' has a value of type {{}}. Specify 'redaction' as type Skyflow.RedactionType." - INVALID_COLUMN_NAME = f"{error_prefix} Validation error. 'column' has a value of type {{}}. Specify 'column' as a string." - INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. columnValues key has a value of type {{}}. Specify columnValues key as list." + INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction_type' has a value of type {{}}. Specify 'redaction_type' as type Skyflow.RedactionType." + INVALID_COLUMN_NAME = f"{error_prefix} Validation error. column_name has a value of type {{}}. Specify 'column' as a string." + INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. column_values key has a value of type {{}}. Specify column_values key as list." + INVALID_COLUMN_VALUES = f"{error_prefix} Validation error. column_values key is an empty list. Specify at least one column value when column_name is passed." INVALID_FIELDS_VALUE = f"{error_prefix} Validation error. fields key has a value of type{{}}. Specify fields key as list." - BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"${error_prefix} Validation error. Both offset and limit cannot be present at the same time" + BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"{error_prefix} Validation error. Both offset and limit cannot be present at the same time" INVALID_OFF_SET_VALUE = f"{error_prefix} Validation error. offset key has a value of type {{}}. Specify offset key as integer." INVALID_LIMIT_VALUE = f"{error_prefix} Validation error. limit key has a value of type {{}}. Specify limit key as integer." INVALID_DOWNLOAD_URL_VALUE = f"{error_prefix} Validation error. download_url key has a value of type {{}}. Specify download_url key as boolean." - REDACTION_WITH_TOKENS_NOT_SUPPORTED = f"{error_prefix} Validation error. 'redaction' can't be used when tokens are specified. Remove 'redaction' from payload if tokens are specified." + REDACTION_WITH_TOKENS_NOT_SUPPORTED = f"{error_prefix} Validation error. 'redaction_type' can't be used when tokens are specified. Remove 'redaction_type' from payload if tokens are specified." TOKENS_GET_COLUMN_NOT_SUPPORTED = f"{error_prefix} Validation error. Column name and/or column values can't be used when tokens are specified. Remove unique column values or tokens from the payload." BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED = f"{error_prefix} Validation error. Both Skyflow IDs and column details can't be specified. Either specify Skyflow IDs or unique column details." INVALID_ORDER_BY_VALUE = f"{error_prefix} Validation error. order_by key has a value of type {{}}. Specify order_by key as Skyflow.OrderBy" @@ -132,7 +139,7 @@ class Error(Enum): UPDATE_FIELD_KEY_ERROR = f"{error_prefix} Validation error. Fields are empty in an update payload. Specify at least one field." INVALID_FIELDS_TYPE = f"{error_prefix} Validation error. The 'data' key has a value of type {{}}. Specify 'data' as a dictionary." IDS_KEY_ERROR = f"{error_prefix} Validation error. 'ids' key is missing from the payload. Specify an 'ids' key." - INVALID_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. The 'data' field is invalid. Specify 'data' as a list of dictionaries containing 'token' and 'redaction'." + INVALID_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. The 'data' field is invalid. Specify 'data' as a list of dictionaries containing 'token' and 'redaction_type'." INVALID_DATA_FOR_DETOKENIZE = f"{error_prefix}" EMPTY_TOKENS_LIST_VALUE = f"{error_prefix} Validation error. Tokens are empty in detokenize payload. Specify at lease one token" INVALID_TOKEN_TYPE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens should be of type string." @@ -155,10 +162,13 @@ class Error(Enum): MISSING_CLIENT_ID = f"{error_prefix} Initialization failed. Unable to read client ID in credentials. Verify your client ID." MISSING_KEY_ID = f"{error_prefix} Initialization failed. Unable to read key ID in credentials. Verify your key ID." MISSING_TOKEN_URI = f"{error_prefix} Initialization failed. Unable to read token URI in credentials. Verify your token URI." + INVALID_TOKEN_URI = f"{error_prefix} Initialization failed. Invalid Skyflow credentials. The token URI must be a string and a valid URL." JWT_INVALID_FORMAT = f"{error_prefix} Initialization failed. Invalid private key format. Verify your credentials." JWT_DECODE_ERROR = f"{error_prefix} Validation error. Invalid access token. Verify your credentials." FILE_INVALID_JSON = f"{error_prefix} Initialization failed. File at {{}} is not in valid JSON format. Verify the file contents." INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV = f"{error_prefix} Validation error. Invalid JSON format in SKYFLOW_CREDENTIALS environment variable." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." INVALID_TEXT_IN_DEIDENTIFY= f"{error_prefix} Validation error. The text field is required and must be a non-empty string. Specify a valid text." INVALID_ENTITIES_IN_DEIDENTIFY= f"{error_prefix} Validation error. The entities field must be an array of DetectEntities enums. Specify a valid entities." @@ -280,7 +290,6 @@ class Info(Enum): VALIDATING_FILE_UPLOAD_REQUEST = f"{INFO}: [{error_prefix}] Validating file upload request." FILE_UPLOAD_REQUEST_RESOLVED = f"{INFO}: [{error_prefix}] File upload request resolved." FILE_UPLOAD_SUCCESS = f"{INFO}: [{error_prefix}] File uploaded successfully." - FILE_UPLOAD_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] File upload failed." INVOKE_CONNECTION_TRIGGERED = f"{INFO}: [{error_prefix}] Invoke connection method triggered." VALIDATING_INVOKE_CONNECTION_REQUEST = f"{INFO}: [{error_prefix}] Validating invoke connection request." @@ -310,6 +319,8 @@ class Info(Enum): DETECT_REQUEST_RESOLVED = f"{INFO}: [{error_prefix}] Detect request is resolved." class ErrorLogs(Enum): + INVALID_LOG_LEVEL = f"{ERROR}: [{error_prefix}] Invalid log level. Specify a valid log level." + INVALID_KEY = f"{ERROR}: [{error_prefix}] Invalid key {{}} in config." VAULTID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid vault config. Vault ID is required." EMPTY_VAULTID = f"{ERROR}: [{error_prefix}] Invalid vault config. Vault ID can not be empty." CLUSTER_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid vault config. Cluster ID is required." @@ -334,6 +345,8 @@ class ErrorLogs(Enum): KEY_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Key ID is required." TOKEN_URI_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Token URI is required." INVALID_TOKEN_URI = f"{ERROR}: [{error_prefix}] Invalid value for token URI in credentials." + FAILED_TO_GET_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Failed to generate bearer token." + UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN = f"{ERROR}: [{error_prefix}] Authorization failed while retrieving the bearer token." TABLE_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Table is required." @@ -348,13 +361,14 @@ class ErrorLogs(Enum): EMPTY_OR_NULL_VALUE_IN_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Value can not be null or empty in tokens for key {{}}." EMPTY_OR_NULL_KEY_IN_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Key can not be null or empty in tokens." MISMATCH_OF_FIELDS_AND_TOKENS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Keys for values and tokens are not matching." + FILE_UPLOAD_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] File upload failed." EMPTY_IDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Ids can not be empty." EMPTY_OR_NULL_ID_IN_IDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Id can not be null or empty in ids at index {{}}." TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokenization is not supported when redaction is applied." TOKENIZATION_SUPPORTED_ONLY_WITH_IDS=f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokenization is not supported when column name and values are passed." - TOKENS_NOT_ALLOWED_WITH_BYOT_DISABLE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are not allowed when token_strict is DISABLE." - INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT =f"{ERROR}: [{error_prefix}] Invalid {{}} request. For tokenStrict as ENABLE_STRICT, tokens should be passed for all fields." + TOKENS_NOT_ALLOWED_WITH_BYOT_DISABLE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are not allowed when token_mode is DISABLE." + INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT =f"{ERROR}: [{error_prefix}] Invalid {{}} request. For token_mode as ENABLE_STRICT, tokens should be passed for all fields." TOKENS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Tokens are required." EMPTY_FIELDS = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Fields can not be empty." EMPTY_OFFSET = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Offset ca not be empty." @@ -365,7 +379,7 @@ class ErrorLogs(Enum): SKYFLOW_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id is required." EMPTY_SKYFLOW_ID = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id can not be empty." - COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. ColumnValues are required." + COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. column_values are required." EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Column group can not be null or empty in column values at index %s2." EMPTY_QUERY= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Query can not be empty." @@ -388,6 +402,7 @@ class ErrorLogs(Enum): SAVING_DEIDENTIFY_FILE_FAILED = f"{ERROR}: [{error_prefix}] Error while saving deidentified file to output directory." REIDENTIFY_TEXT_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Reidentify text resulted in failure." DETECT_FILE_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Deidentify file resulted in failure." + EMPTY_FILE_COLUMN_NAME = f"{ERROR}: [{error_prefix}] Empty column name in FILE_UPLOAD" class Interface(Enum): INSERT = "INSERT" @@ -402,7 +417,18 @@ class HttpStatus(Enum): BAD_REQUEST = "Bad Request" class Warning(Enum): - WARNING_MESSAGE = "WARNING MESSAGE" + DETOKENIZE_REDACTION_KEY_DEPRECATED = ( + f"{WARN}: [{error_prefix}] 'redaction' key in detokenize data is deprecated and will be removed in a future version. Use 'redaction_type' instead." + ) + UPDATE_LOG_LEVEL_DEPRECATED = ( + f"{WARN}: [{error_prefix}] Skyflow.update_log_level() is deprecated. " + "Use Skyflow.set_log_level() instead." + ) + FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED = ( + f"{WARN}: [{error_prefix}] FileUploadRequest: argument order changed. " + "Old positional order: (table, skyflow_id, column_name). " + "New order: FileUploadRequest(table, column_name=..., skyflow_id=...)." + ) diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index d8eedca2..7ed7bc99 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -20,7 +20,9 @@ from skyflow.vault.detect import DeidentifyTextResponse, ReidentifyTextResponse from skyflow.vault.detect import EntityInfo, TextIndex from . import SkyflowMessages, SDK_VERSION -from .constants import PROTOCOL +from .constants import (PROTOCOL, HttpHeader, ApiKey, ContentType as ContentTypeConstants, + EncodingType, BooleanString, ResponseField, CredentialField, SdkPrefix, + SdkMetricsKey, ErrorDefaults, HttpStatusCode) from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params @@ -30,9 +32,9 @@ invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): - if config_level_creds: + if config_level_creds is not None: return config_level_creds - if common_skyflow_creds: + if common_skyflow_creds is not None: return common_skyflow_creds dotenv_path = dotenv.find_dotenv(usecwd=True) if dotenv_path: @@ -44,7 +46,7 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False api_key_pattern = re.compile(r'^sky-[a-zA-Z0-9]{5}-[a-fA-F0-9]{32}$') @@ -70,9 +72,9 @@ def parse_path_params(url, path_params): return result -def to_lowercase_keys(dict): +def to_lowercase_keys(data): result = {} - for key, value in dict.items(): + for key, value in data.items(): result[key.lower()] = value return result @@ -96,31 +98,45 @@ def convert_detected_entity_to_entity_info(detected_entity): def construct_invoke_connection_request(request, connection_url, logger) -> PreparedRequest: url = parse_path_params(connection_url.rstrip('/'), request.path_params) - try: - if isinstance(request.headers, dict): - header = to_lowercase_keys(json.loads( - json.dumps(request.headers))) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + header = None + content_type = None - if not 'Content-Type'.lower() in header: - header['content-type'] = ContentType.JSON.value + if request.headers is not None: + try: + if isinstance(request.headers, dict): + header = to_lowercase_keys(json.loads( + json.dumps(request.headers))) + + content_type = header.get(HttpHeader.CONTENT_TYPE_LOWERCASE) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + except SkyflowError: + raise + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - try: - if isinstance(request.body, dict): - json_data, files = get_data_from_content_type( - request.body, header["content-type"] - ) - else: + json_data = None + files = {} + + if request.body is not None: + try: + if isinstance(request.body, dict): + json_data, files = get_data_from_content_type( + request.body, content_type + ) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + except SkyflowError: + raise + except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) - except Exception as e: - raise SkyflowError( SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + + if files and header and content_type == ContentType.FORMDATA.value: + header.pop(HttpHeader.CONTENT_TYPE_LOWERCASE, None) validate_invoke_connection_params(logger, request.query_params, request.path_params) - if not hasattr(request.method, 'value'): + if not hasattr(request.method, ResponseField.VALUE): raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_METHOD.value, invalid_input_error_code) try: @@ -166,21 +182,59 @@ def render_key(parents): def get_data_from_content_type(data, content_type): converted_data = data files = {} + if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) elif content_type == ContentType.FORMDATA.value: - converted_data = r_urlencode(list(), dict(), data) - files = {(None, None)} + converted_data = None + files = {} + for key, value in data.items(): + files[key] = (None, str(value)) elif content_type == ContentType.JSON.value: converted_data = json.dumps(data) + elif content_type == ContentType.XML.value or content_type == 'application/xml' or content_type == 'text/xml': + if isinstance(data, dict): + converted_data = dict_to_xml(data) + else: + converted_data = str(data) + elif content_type == ContentType.HTML.value or content_type == 'text/html': + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) + else: + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) return converted_data, files +def dict_to_xml(data, root_tag='root'): + def build_xml(d, tag='item'): + if isinstance(d, dict): + xml_parts = [f'<{tag}>'] + for key, value in d.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + elif isinstance(d, list): + return ''.join([build_xml(item, tag) for item in d]) + else: + return f'<{tag}>{d}' + + xml_parts = [f'<{root_tag}>'] + for key, value in data.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + + +_CACHED_METRICS: dict = {} _CACHED_METRICS: dict = {} def get_metrics(): - global _CACHED_METRICS if _CACHED_METRICS: return _CACHED_METRICS @@ -199,12 +253,12 @@ def get_metrics(): except Exception: sdk_runtime_details = "" - _CACHED_METRICS = { - 'sdk_name_version': "skyflow-python@" + SDK_VERSION, - 'sdk_client_device_model': sdk_client_device_model, - 'sdk_client_os_details': sdk_client_os_details, - 'sdk_runtime_details': "Python " + sdk_runtime_details, - } + _CACHED_METRICS.update({ + SdkMetricsKey.SDK_NAME_VERSION: SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION, + SdkMetricsKey.SDK_CLIENT_DEVICE_MODEL: sdk_client_device_model, + SdkMetricsKey.SDK_CLIENT_OS_DETAILS: sdk_client_os_details, + SdkMetricsKey.SDK_RUNTIME_DETAILS: SdkPrefix.PYTHON_RUNTIME + sdk_runtime_details, + }) return _CACHED_METRICS def parse_insert_response(api_response, continue_on_error): @@ -212,30 +266,30 @@ def parse_insert_response(api_response, continue_on_error): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) inserted_fields = [] errors = [] insert_response = InsertResponse() if continue_on_error: for idx, response in enumerate(api_response_data.responses): - if response['Status'] == 200: - body = response['Body'] - if 'records' in body: - for record in body['records']: + if response[ResponseField.STATUS] == HttpStatusCode.OK: + body = response[ResponseField.BODY] + if ResponseField.RECORDS in body: + for record in body[ResponseField.RECORDS]: inserted_field = { - 'skyflow_id': record['skyflow_id'], - 'request_index': idx + ResponseField.SKYFLOW_ID: record[ResponseField.SKYFLOW_ID], + ResponseField.REQUEST_INDEX: idx } - if 'tokens' in record: - inserted_field.update(record['tokens']) + if ResponseField.TOKENS in record: + inserted_field.update(record[ResponseField.TOKENS]) inserted_fields.append(inserted_field) - elif response['Status'] == 400: + elif response[ResponseField.STATUS] == HttpStatusCode.BAD_REQUEST: error = { - 'request_index': idx, - 'request_id': request_id, - 'error': response['Body']['error'], - 'http_code': response['Status'], + ResponseField.REQUEST_INDEX: idx, + ResponseField.REQUEST_ID: request_id, + ResponseField.ERROR: response[ResponseField.BODY][ResponseField.ERROR], + ResponseField.HTTP_CODE: response[ResponseField.STATUS], } errors.append(error) @@ -244,7 +298,7 @@ def parse_insert_response(api_response, continue_on_error): else: for record in api_response_data.records: field_data = { - 'skyflow_id': record.skyflow_id + ResponseField.SKYFLOW_ID: record.skyflow_id } if record.tokens: @@ -259,7 +313,7 @@ def parse_insert_response(api_response, continue_on_error): def parse_update_record_response(api_response: V1UpdateRecordResponse): update_response = UpdateResponse() updated_field = dict() - updated_field['skyflow_id'] = api_response.skyflow_id + updated_field[ResponseField.SKYFLOW_ID] = api_response.skyflow_id if api_response.tokens is not None: updated_field.update(api_response.tokens) @@ -289,23 +343,23 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) detokenized_fields = [] errors = [] for record in api_response_data.records: if record.error: errors.append({ - "token": record.token, - "error": record.error, - "request_id": request_id + ResponseField.TOKEN: record.token, + ResponseField.ERROR: record.error, + ResponseField.REQUEST_ID: request_id }) else: value_type = record.value_type if record.value_type else None detokenized_fields.append({ - "token": record.token, - "value": record.value, - "type": value_type + ResponseField.TOKEN: record.token, + ResponseField.VALUE: record.value, + ResponseField.TYPE: value_type }) detokenized_fields = detokenized_fields @@ -318,7 +372,7 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): def parse_tokenize_response(api_response: V1TokenizeResponse): tokenize_response = TokenizeResponse() - tokenized_fields = [{"token": record.token} for record in api_response.records] + tokenized_fields = [{ResponseField.TOKEN: record.token} for record in api_response.records] tokenize_response.tokenized_fields = tokenized_fields @@ -330,7 +384,7 @@ def parse_query_response(api_response: V1GetQueryResponse): for record in api_response.records: field_object = { **record.fields, - "tokenized_data": {} + ResponseField.TOKENIZED_DATA: {} } fields.append(field_object) query_response.fields = fields @@ -340,40 +394,59 @@ def parse_invoke_connection_response(api_response: requests.Response): status_code = api_response.status_code content = api_response.content if isinstance(content, bytes): - content = content.decode('utf-8') + content = content.decode(EncodingType.UTF_8) + try: api_response.raise_for_status() - try: - data = json.loads(content) - metadata = {} - if 'x-request-id' in api_response.headers: - metadata['request_id'] = api_response.headers['x-request-id'] - - return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) - except Exception as e: - raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code) + + content_type = api_response.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE, '').lower() + + if ContentTypeConstants.APPLICATION_JSON in content_type or not content_type: + try: + data = json.loads(content) + except json.JSONDecodeError: + data = content + else: + data = content + + metadata = {} + if HttpHeader.X_REQUEST_ID in api_response.headers: + metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID] + + return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) + except HTTPError: message = SkyflowMessages.Error.API_ERROR.value.format(status_code) + request_id = api_response.headers.get(HttpHeader.X_REQUEST_ID) + try: - error_response = json.loads(content) - request_id = api_response.headers['x-request-id'] - error_from_client = api_response.headers.get('error-from-client') - - status_code = error_response.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = error_response.get('error', {}).get('http_status') - grpc_code = error_response.get('error', {}).get('grpc_code') - details = error_response.get('error', {}).get('details') - message = error_response.get('error', {}).get('message', "An unknown error occurred.") - + error_response = json.loads(content) + error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) + + http_status = None + grpc_code = None + details = None + + error_obj = error_response.get(ResponseField.ERROR) if isinstance(error_response, dict) else None + if isinstance(error_obj, dict): + status_code = error_obj.get(ResponseField.HTTP_CODE, status_code) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS) + message = error_obj.get(ResponseField.MESSAGE, message) + elif isinstance(error_obj, str) and error_obj: + message = error_obj + if error_from_client is not None: - if details is None: details = [] - error_from_client_bool = error_from_client.lower() == 'true' - details.append({'error_from_client': error_from_client_bool}) + if details is None: + details = [] + error_from_client_bool = error_from_client.lower() == BooleanString.TRUE + details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) + except json.JSONDecodeError: - message = SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content) - raise SkyflowError(message, status_code) + raise SkyflowError(message, status_code, request_id) def parse_deidentify_text_response(api_response: DeidentifyStringResponse): entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities] @@ -391,51 +464,79 @@ def log_and_reject_error(description, status_code, request_id, http_status=None, raise SkyflowError(description, status_code, request_id, grpc_code, http_status, details) def handle_exception(error, logger): - # handle invalid cluster ID error scenario - if (isinstance(error, httpx.ConnectError)): - handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) + if isinstance(error, httpx.ConnectError): + description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=logger) + return + + if not hasattr(error, 'headers') or not hasattr(error, 'body') or error.headers is None or error.body is None: + description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=logger) + return - request_id = error.headers.get('x-request-id', 'unknown-request-id') - content_type = error.headers.get('content-type') + request_id = error.headers.get(HttpHeader.X_REQUEST_ID, ErrorDefaults.UNKNOWN_REQUEST_ID) + content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) data = error.body if content_type: - if 'application/json' in content_type: + if ContentTypeConstants.APPLICATION_JSON in content_type: handle_json_error(error, data, request_id, logger) - elif 'text/plain' in content_type: + elif ContentTypeConstants.TEXT_PLAIN in content_type: handle_text_error(error, data, request_id, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) def handle_json_error(err, data, request_id, logger): try: - if isinstance(data, dict): # If data is already a dict + if isinstance(data, dict): description = data elif isinstance(data, ErrorResponse): description = data.dict() else: description = json.loads(data) - status_code = description.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = description.get('error', {}).get('http_status') - grpc_code = description.get('error', {}).get('grpc_code') - details = description.get('error', {}).get('details', []) - description_message = description.get('error', {}).get('message', "An unknown error occurred.") - log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger = logger) + if ResponseField.ERROR in description: + error_obj = description.get(ResponseField.ERROR, {}) + status_code = error_obj.get(ResponseField.HTTP_CODE, HttpStatusCode.INTERNAL_SERVER_ERROR) + http_status = error_obj.get(ResponseField.HTTP_STATUS) + grpc_code = error_obj.get(ResponseField.GRPC_CODE) + details = error_obj.get(ResponseField.DETAILS, []) + description_message = error_obj.get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + elif ResponseField.RESPONSES in description: + responses = description.get(ResponseField.RESPONSES, []) + messages = [] + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + for resp in responses: + resp_status = resp.get(ResponseField.STATUS, HttpStatusCode.INTERNAL_SERVER_ERROR) + resp_body = resp.get(ResponseField.BODY, {}) + if isinstance(resp_status, int) and resp_status >= HttpStatusCode.BAD_REQUEST: + status_code = resp_status + error_msg = resp_body.get(ResponseField.ERROR) + if error_msg: + messages.append(str(error_msg)) + description_message = '; '.join(messages) if messages else SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + http_status = None + grpc_code = None + details = [] + else: + status_code = HttpStatusCode.INTERNAL_SERVER_ERROR + http_status = None + grpc_code = None + details = [] + description_message = SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value + + log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger=logger) except json.JSONDecodeError: - log_and_reject_error("Invalid JSON response received.", err, request_id, logger = logger) + log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger=logger) def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) -def handle_generic_error(err, request_id, logger): - handle_generic_error(err, request_id, err.status, logger = logger) - -def handle_generic_error(err, request_id, status, logger): - description = SkyflowMessages.Error.GENERIC_API_ERROR.value - log_and_reject_error(description, status, request_id, logger = logger) +def handle_generic_error_with_status(err, request_id, status, logger): + description = str(err) if err else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, status, request_id, logger=logger) def encode_column_values(get_request): encoded_column_values = list() diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index af153938..949d3423 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.2' \ No newline at end of file +SDK_VERSION = '2.0.2.dev0+e0253e9' diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 05e520d6..05d28380 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -3,3 +3,289 @@ SKY_META_DATA_HEADER='sky-metadata' CTX_KEY_REGEX=r'^[a-zA-Z0-9_]+$' +class SKYFLOW: + SKYFLOW_ID = 'skyflowId' + X_SKYFLOW_AUTHORIZATION = 'x-skyflow-authorization' + + +class HttpHeader: + CONTENT_TYPE = 'Content-Type' + CONTENT_TYPE_LOWERCASE = 'content-type' + X_REQUEST_ID = 'x-request-id' + ERROR_FROM_CLIENT = 'error-from-client' + AUTHORIZATION = 'Authorization' + X_SKYFLOW_AUTHORIZATION_HEADER = 'X-Skyflow-Authorization' + + +class HttpStatusCode: + OK = 200 + BAD_REQUEST = 400 + UNAUTHORIZED = 401 + INTERNAL_SERVER_ERROR = 500 + + +class ContentType: + APPLICATION_JSON = 'application/json' + APPLICATION_X_WWW_FORM_URLENCODED = 'application/x-www-form-urlencoded' + TEXT_PLAIN = 'text/plain' + + +class DetectStatus: + IN_PROGRESS = 'IN_PROGRESS' + SUCCESS = 'SUCCESS' + FAILED = 'FAILED' + UNKNOWN = 'UNKNOWN' + +class Detect: + WAIT_TIME = 64 + +class FileExtension: + JSON = 'json' + MP3 = 'mp3' + WAV = 'wav' + PDF = 'pdf' + TXT = 'txt' + DOC = 'doc' + DOCX = 'docx' + JPG = 'jpg' + JPEG = 'jpeg' + PNG = 'png' + BMP = 'bmp' + TIF = 'tif' + TIFF = 'tiff' + PPT = 'ppt' + PPTX = 'pptx' + CSV = 'csv' + XLS = 'xls' + XLSX = 'xlsx' + XML = 'xml' + + +class FileProcessing: + PROCESSED_PREFIX = 'processed-' + DEIDENTIFIED_PREFIX = 'deidentified.' + ENTITIES = 'entities' + + +class EncodingType: + UTF8 = 'utf8' + UTF_8 = 'utf-8' + BASE64 = 'base64' + BINARY = 'binary' + + +class JWT: + ALGORITHM_RS256 = 'RS256' + GRANT_TYPE_JWT_BEARER = 'urn:ietf:params:oauth:grant-type:jwt-bearer' + ISSUER_SDK = 'sdk' + SIGNED_TOKEN_PREFIX = 'signed_token_' + ROLE_PREFIX = 'role:' + + +class ApiKey: + SKY_PREFIX = 'sky-' + LENGTH = 42 + + +class UrlProtocol: + HTTPS = 'https' + HTTP = 'http' + + +class BooleanString: + TRUE = 'true' + FALSE = 'false' + + +class ResponseField: + STATUS = 'Status' + BODY = 'Body' + RECORDS = 'records' + TOKENS = 'tokens' + ERROR = 'error' + SKYFLOW_ID = 'skyflow_id' + REQUEST_INDEX = 'request_index' + REQUEST_ID = 'request_id' + HTTP_CODE = 'http_code' + HTTP_STATUS = 'http_status' + GRPC_CODE = 'grpc_code' + DETAILS = 'details' + MESSAGE = 'message' + ERROR_FROM_CLIENT = 'error_from_client' + TOKEN = 'token' + VALUE = 'value' + TYPE = 'type' + TOKENIZED_DATA = 'tokenized_data' + SIGNED_TOKEN = 'signed_token' + RESPONSES = 'responses' + + +class CredentialField: + PRIVATE_KEY = 'privateKey' + CLIENT_ID = 'clientID' + KEY_ID = 'keyID' + TOKEN_URI = 'tokenURI' + TOKEN_URI_OPTION = 'token_uri' + CLIENT_NAME = 'clientName' + CREDENTIALS_STRING = 'credentials_string' + API_KEY = 'api_key' + TOKEN = 'token' + PATH = 'path' + CONTEXT = 'context' + ROLES = 'roles' + + +class JwtField: + ISS = 'iss' + KEY = 'key' + AUD = 'aud' + SUB = 'sub' + EXP = 'exp' + CTX = 'ctx' + TOK = 'tok' + IAT = 'iat' + + +class OptionField: + ROLE_IDS = 'role_ids' + DATA_TOKENS = 'data_tokens' + TIME_TO_LIVE = 'time_to_live' + ROLES = 'roles' + CTX = 'ctx' + VAULT_ID = 'vault_id' + CONNECTION_ID = 'connection_id' + CONNECTION_URL = 'connection_url' + VAULT_CLIENT = 'vault_client' + VAULT_CONTROLLER = 'vault_controller' + DETECT_CONTROLLER = 'detect_controller' + CONTROLLER = 'controller' + VERIFY_SIGNATURE = 'verify_signature' + VERIFY_AUD = 'verify_aud' + + +class ConfigField: + CREDENTIALS = 'credentials' + CLUSTER_ID = 'cluster_id' + ENV = 'env' + VAULT_ID = 'vault_id' + + +class RequestParameter: + VALUE = 'value' + COLUMN_GROUP = 'column_group' + REDACTION = 'redaction' + REDACTION_TYPE = 'redaction_type' + + +class FileUploadField: + TABLE = 'table' + SKYFLOW_ID = 'skyflow_id' + COLUMN_NAME = 'column_name' + FILE_PATH = 'file_path' + BASE64 = 'base64' + FILE_OBJECT = 'file_object' + FILE_NAME = 'file_name' + FILE = 'file' + NAME = 'name' + + +class DeidentifyFileRequestField: + ENTITIES = 'entities' + ALLOW_REGEX_LIST = 'allow_regex_list' + RESTRICT_REGEX_LIST = 'restrict_regex_list' + OUTPUT_PROCESSED_IMAGE = 'output_processed_image' + OUTPUT_OCR_TEXT = 'output_ocr_text' + MASKING_METHOD = 'masking_method' + PIXEL_DENSITY = 'pixel_density' + DENSITY = 'density' + MAX_RESOLUTION = 'max_resolution' + OUTPUT_PROCESSED_AUDIO = 'output_processed_audio' + OUTPUT_TRANSCRIPTION = 'output_transcription' + BLEEP = 'bleep' + OUTPUT_DIRECTORY = 'output_directory' + WAIT_TIME = 'wait_time' + + +class DeidentifyField: + TEXT = 'text' + ENTITY_TYPES = 'entity_types' + TOKEN_TYPE = 'token_type' + ALLOW_REGEX = 'allow_regex' + RESTRICT_REGEX = 'restrict_regex' + TRANSFORMATIONS = 'transformations' + FORMAT = 'format' + OUTPUT = 'output' + STATUS = 'status' + RUN_ID = 'run_id' + WORD_CHARACTER_COUNT = 'word_character_count' + WORD_COUNT = 'word_count' + CHARACTER_COUNT = 'character_count' + SIZE = 'size' + DURATION = 'duration' + PAGES = 'pages' + SLIDES = 'slides' + PROCESSED_FILE = 'processed_file' + PROCESSED_FILE_TYPE = 'processed_file_type' + PROCESSED_FILE_EXTENSION = 'processed_file_extension' + REDACTED_FILE = 'redacted_file' + SHIFT_DATES = 'shift_dates' + DEFAULT = 'default' + ENTITY_UNQ_COUNTER = 'entity_unq_counter' + ENTITY_UNIQUE_COUNTER = 'entity_unique_counter' + ENTITY_ONLY = 'entity_only' + VAULT_TOKEN = 'vault_token' + ENTITIES = 'entities' + MAX_DAYS = 'max_days' + MIN_DAYS = 'min_days' + MAX = 'max' + MIN = 'min' + FILE = 'file' + TYPE = 'type' + EXTENSION = 'extension' + IN_PROGRESS = 'IN_PROGRESS' + REQUEST_OPTIONS = 'request_options' + BLEEP_GAIN = 'bleep_gain' + BLEEP_FREQUENCY = 'bleep_frequency' + BLEEP_START_PADDING = 'bleep_start_padding' + BLEEP_STOP_PADDING = 'bleep_stop_padding' + DENSITY = 'density' + TOKEN_FORMAT = 'token_format' + PROCESSED_FILE_RESPONSE_KEY = 'processedFile' + PROCESSED_FILE_TYPE_RESPONSE_KEY = 'processedFileType' + PROCESSED_FILE_EXTENSION_RESPONSE_KEY = 'processedFileExtension' + + +class RequestOperation: + INSERT = 'INSERT' + DELETE = 'DELETE' + GET = 'GET' + UPDATE = 'UPDATE' + QUERY = 'QUERY' + TOKENIZE = 'TOKENIZE' + DETOKENIZE = 'DETOKENIZE' + FILE_UPLOAD = 'FILE_UPLOAD' + + +class ConfigType: + VAULT = 'vault' + CONNECTION = 'connection' + + +class SqlCommand: + SELECT = 'SELECT' + + +class SdkPrefix: + SKYFLOW_PYTHON = 'skyflow-python@' + PYTHON_RUNTIME = 'Python ' + + +class SdkMetricsKey: + SDK_NAME_VERSION = 'sdk_name_version' + SDK_CLIENT_DEVICE_MODEL = 'sdk_client_device_model' + SDK_CLIENT_OS_DETAILS = 'sdk_client_os_details' + SDK_RUNTIME_DETAILS = 'sdk_runtime_details' + + +class ErrorDefaults: + UNKNOWN_REQUEST_ID = 'unknown-request-id' diff --git a/skyflow/utils/enums/content_types.py b/skyflow/utils/enums/content_types.py index 362c286a..f2db5b92 100644 --- a/skyflow/utils/enums/content_types.py +++ b/skyflow/utils/enums/content_types.py @@ -5,4 +5,5 @@ class ContentType(Enum): PLAINTEXT = 'text/plain' XML = 'text/xml' URLENCODED = 'application/x-www-form-urlencoded' - FORMDATA = 'multipart/form-data' \ No newline at end of file + FORMDATA = 'multipart/form-data' + HTML = 'text/html' \ No newline at end of file diff --git a/skyflow/utils/enums/detect_output_transcriptions.py b/skyflow/utils/enums/detect_output_transcriptions.py index 4e14f911..a398a3d8 100644 --- a/skyflow/utils/enums/detect_output_transcriptions.py +++ b/skyflow/utils/enums/detect_output_transcriptions.py @@ -4,4 +4,5 @@ class DetectOutputTranscriptions(Enum): DIARIZED_TRANSCRIPTION = "diarized_transcription" MEDICAL_DIARIZED_TRANSCRIPTION = "medical_diarized_transcription" MEDICAL_TRANSCRIPTION = "medical_transcription" - TRANSCRIPTION = "transcription" \ No newline at end of file + TRANSCRIPTION = "transcription" + PLAINTEXT_TRANSCRIPTION = "plaintext_transcription" \ No newline at end of file diff --git a/skyflow/utils/logger/__init__.py b/skyflow/utils/logger/__init__.py index 2993b8fc..bce55608 100644 --- a/skyflow/utils/logger/__init__.py +++ b/skyflow/utils/logger/__init__.py @@ -1,2 +1,2 @@ from ._logger import Logger -from ._log_helpers import log_error, log_info, log_error_log \ No newline at end of file +from ._log_helpers import log_error, log_info, log_warn, log_error_log, set_active_log_level \ No newline at end of file diff --git a/skyflow/utils/logger/_log_helpers.py b/skyflow/utils/logger/_log_helpers.py index fdb11ea9..1343b55f 100644 --- a/skyflow/utils/logger/_log_helpers.py +++ b/skyflow/utils/logger/_log_helpers.py @@ -1,5 +1,13 @@ from ..enums import LogLevel from . import Logger +from ..constants import ResponseField + +_active_log_level = LogLevel.ERROR + + +def set_active_log_level(level): + global _active_log_level + _active_log_level = level def log_info(message, logger = None): @@ -8,6 +16,11 @@ def log_info(message, logger = None): logger.info(message) +def log_warn(message, logger=None): + if not logger: + logger = Logger(_active_log_level) + logger.warn(message) + def log_error_log(message, logger=None): if not logger: logger = Logger(LogLevel.ERROR) @@ -18,17 +31,17 @@ def log_error(message, http_code, request_id=None, grpc_code=None, http_status=N logger = Logger(LogLevel.ERROR) log_data = { - 'http_code': http_code, - 'message': message + ResponseField.HTTP_CODE: http_code, + ResponseField.MESSAGE: message } if grpc_code is not None: - log_data['grpc_code'] = grpc_code + log_data[ResponseField.GRPC_CODE] = grpc_code if http_status is not None: - log_data['http_status'] = http_status + log_data[ResponseField.HTTP_STATUS] = http_status if request_id is not None: - log_data['request_id'] = request_id + log_data[ResponseField.REQUEST_ID] = request_id if details is not None: - log_data['details'] = details + log_data[ResponseField.DETAILS] = details logger.error(log_data) \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index acca531f..42abe188 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -6,62 +6,83 @@ MaskingMethod from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages -from skyflow.utils.logger import log_info, log_error_log +from skyflow.utils.constants import ( + ApiKey, ResponseField, RequestParameter, + FileUploadField, + DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField, Detect +) +from skyflow.utils.logger import log_info, log_warn, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest from skyflow.vault.detect._file_input import FileInput - -valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] -valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] -valid_credentials_keys = ["path", "roles", "context", "token", "credentials_string"] +from skyflow.utils._helpers import is_valid_url + +valid_vault_config_keys = [ + ConfigField.VAULT_ID, + ConfigField.CLUSTER_ID, + ConfigField.CREDENTIALS, + ConfigField.ENV +] +valid_connection_config_keys = [ + OptionField.CONNECTION_ID, + OptionField.CONNECTION_URL, + ConfigField.CREDENTIALS +] +valid_credentials_keys = [ + CredentialField.PATH, + CredentialField.ROLES, + CredentialField.CONTEXT, + CredentialField.TOKEN, + CredentialField.CREDENTIALS_STRING +] invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def validate_required_field(logger, config, field_name, expected_type, empty_error, invalid_error): field_value = config.get(field_name) if field_name not in config or not isinstance(field_value, expected_type): - if field_name == "vault_id": - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) - if field_name == "cluster_id": - logger.error(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value) - if field_name == "connection_id": - logger.error(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value) - if field_name == "connection_url": - logger.error(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value) + if field_name == ConfigField.VAULT_ID: + log_error_log(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value, logger) + if field_name == ConfigField.CLUSTER_ID: + log_error_log(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value, logger) + if field_name == OptionField.CONNECTION_ID: + log_error_log(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value, logger) + if field_name == OptionField.CONNECTION_URL: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value, logger) raise SkyflowError(invalid_error, invalid_input_error_code) if isinstance(field_value, str) and not field_value.strip(): - if field_name == "vault_id": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value) - if field_name == "cluster_id": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value) - if field_name == "connection_id": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value) - if field_name == "connection_url": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value) - if field_name == "path": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value) - if field_name == "credentials_string": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value) - if field_name == "token": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value) - if field_name == "api_key": - logger.error(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value) + if field_name == ConfigField.VAULT_ID: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value, logger) + if field_name == ConfigField.CLUSTER_ID: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value, logger) + if field_name == OptionField.CONNECTION_ID: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value, logger) + if field_name == OptionField.CONNECTION_URL: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value, logger) + if field_name == CredentialField.PATH: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value, logger) + if field_name == CredentialField.CREDENTIALS_STRING: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value, logger) + if field_name == CredentialField.TOKEN: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value, logger) + if field_name == CredentialField.API_KEY: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value, logger) raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if not api_key.startswith('sky-'): + if not api_key.startswith(ApiKey.SKY_PREFIX): log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger=logger) return False - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False return True def validate_credentials(logger, credentials, config_id_type=None, config_id=None): - key_present = [k for k in ["path", "token", "credentials_string", "api_key"] if credentials.get(k)] + key_present = [k for k in [CredentialField.PATH, CredentialField.TOKEN, CredentialField.CREDENTIALS_STRING, CredentialField.API_KEY] if credentials.get(k)] if len(key_present) == 0: error_message = ( @@ -69,6 +90,7 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) elif len(key_present) > 1: error_message = ( @@ -76,79 +98,90 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non if config_id_type and config_id else SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value ) + log_error_log(error_message, logger) raise SkyflowError(error_message, invalid_input_error_code) - if "roles" in credentials: + if CredentialField.ROLES in credentials: validate_required_field( - logger, credentials, "roles", list, + logger, credentials, CredentialField.ROLES, list, SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE.value, SkyflowMessages.Error.EMPTY_ROLES_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_ROLES.value ) - if "context" in credentials: + if CredentialField.CONTEXT in credentials: validate_required_field( - logger, credentials, "context", str, + logger, credentials, CredentialField.CONTEXT, str, SkyflowMessages.Error.EMPTY_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CONTEXT.value, SkyflowMessages.Error.INVALID_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CONTEXT.value ) - if "credentials_string" in credentials: + if CredentialField.CREDENTIALS_STRING in credentials: validate_required_field( - logger, credentials, "credentials_string", str, + logger, credentials, CredentialField.CREDENTIALS_STRING, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING.value, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value ) - elif "path" in credentials: + elif CredentialField.PATH in credentials: validate_required_field( - logger, credentials, "path", str, + logger, credentials, CredentialField.PATH, str, SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH.value, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value ) - elif "token" in credentials: + elif CredentialField.TOKEN in credentials: validate_required_field( - logger, credentials, "token", str, + logger, credentials, CredentialField.TOKEN, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) - if is_expired(credentials.get("token"), logger): + if is_expired(credentials.get(CredentialField.TOKEN), logger): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value, logger) raise SkyflowError( - SkyflowMessages.Error.EXPIRED_TOKEN.value - if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_TOKEN.value, + SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value + if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value, invalid_input_error_code ) - elif "api_key" in credentials: + elif CredentialField.API_KEY in credentials: validate_required_field( - logger, credentials, "api_key", str, + logger, credentials, CredentialField.API_KEY, str, SkyflowMessages.Error.EMPTY_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_API_KEY.value, SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value ) - if not validate_api_key(credentials.get("api_key"), logger): + if not validate_api_key(credentials.get(CredentialField.API_KEY), logger): raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) + + if CredentialField.TOKEN_URI_OPTION in credentials: + token_uri = credentials.get(CredentialField.TOKEN_URI_OPTION) + if ( + token_uri is None + or not isinstance(token_uri, str) + or not is_valid_url(token_uri) + ): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_TOKEN_URI.value, logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_URI.value, invalid_input_error_code) def validate_log_level(logger, log_level): if not isinstance(log_level, LogLevel): - raise SkyflowError( SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) - - if log_level is None: - raise SkyflowError(SkyflowMessages.Error.EMPTY_LOG_LEVEL.value, invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.INVALID_LOG_LEVEL.value, logger) + raise SkyflowError(SkyflowMessages.Error.INVALID_LOG_LEVEL.value, invalid_input_error_code) def validate_keys(logger, config, config_keys): for key in config.keys(): if key not in config_keys: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_KEY.value.format(key), logger) raise SkyflowError(SkyflowMessages.Error.INVALID_KEY.value.format(key), invalid_input_error_code) def validate_vault_config(logger, config): @@ -157,28 +190,28 @@ def validate_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, "vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) # Validate cluster_id (string, not empty) validate_required_field( - logger, config, "cluster_id", str, + logger, config, ConfigField.CLUSTER_ID, str, SkyflowMessages.Error.EMPTY_CLUSTER_ID.value.format(vault_id), SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id) ) # Validate credentials (dict, not empty) - if "credentials" in config and not config.get("credentials"): - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS in config and not config.get(ConfigField.CREDENTIALS): + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - if "credentials" in config and config.get("credentials"): - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + if ConfigField.CREDENTIALS in config and config.get(ConfigField.CREDENTIALS): + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) # Validate env (optional, should be one of LogLevel values) - if "env" in config and config.get("env") not in Env: - logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: + log_error_log(SkyflowMessages.ErrorLogs.ENV_IS_REQUIRED.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) return True @@ -189,23 +222,23 @@ def validate_update_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, "vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) - if "cluster_id" in config and not config.get("cluster_id"): + if ConfigField.CLUSTER_ID in config and not config.get(ConfigField.CLUSTER_ID): raise SkyflowError(SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id), invalid_input_error_code) - if "env" in config and config.get("env") not in Env: + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) return True @@ -214,23 +247,23 @@ def validate_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id" , str, + logger, config, OptionField.CONNECTION_ID , str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) return True @@ -239,193 +272,218 @@ def validate_update_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id", str, + logger, config, OptionField.CONNECTION_ID, str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials")) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS)) return True def validate_file_from_request(file_input: FileInput): if file_input is None: + log_error_log(SkyflowMessages.Error.INVALID_FILE_INPUT.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - - has_file = hasattr(file_input, 'file') and file_input.file is not None - has_file_path = hasattr(file_input, 'file_path') and file_input.file_path is not None - + + has_file = hasattr(file_input, FileUploadField.FILE) and file_input.file is not None + has_file_path = hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None + # Must provide exactly one of file or file_path if (has_file and has_file_path) or (not has_file and not has_file_path): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value, invalid_input_error_code) - + if has_file: file = file_input.file # Validate file object has required attributes - if not hasattr(file, 'name') or not isinstance(file.name, str) or not file.name.strip(): + if not hasattr(file, FileUploadField.NAME) or not isinstance(file.name, str) or not file.name.strip(): + log_error_log(SkyflowMessages.Error.INVALID_FILE_TYPE.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) - + # Validate file name file_name, _ = os.path.splitext(os.path.basename(file.name)) if not file_name or not file_name.strip(): + log_error_log(SkyflowMessages.Error.INVALID_FILE_NAME.value) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_NAME.value, invalid_input_error_code) - + elif has_file_path: file_path = file_input.file_path if not isinstance(file_path, str) or not file_path.strip(): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) - + if not os.path.exists(file_path) or not os.path.isfile(file_path): + log_error_log(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): - if not hasattr(request, 'file') or request.file is None: + if not hasattr(request, FileUploadField.FILE) or request.file is None: + log_error_log(SkyflowMessages.Error.INVALID_FILE_INPUT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - + # Validate file input first validate_file_from_request(request.file) # Optional: entities - if hasattr(request, 'entities') and request.entities is not None: + if hasattr(request, DeidentifyFileRequestField.ENTITIES) and request.entities is not None: if not isinstance(request.entities, list): + log_error_log(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) if not all(isinstance(entity, DetectEntities) for entity in request.entities): + log_error_log(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) # Optional: allow_regex_list - if hasattr(request, 'allow_regex_list') and request.allow_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.ALLOW_REGEX_LIST) and request.allow_regex_list is not None: if not isinstance(request.allow_regex_list, list) or not all(isinstance(x, str) for x in request.allow_regex_list): + log_error_log(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Optional: restrict_regex_list - if hasattr(request, 'restrict_regex_list') and request.restrict_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.RESTRICT_REGEX_LIST) and request.restrict_regex_list is not None: if not isinstance(request.restrict_regex_list, list) or not all(isinstance(x, str) for x in request.restrict_regex_list): + log_error_log(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) # Optional: token_format if request.token_format is not None and not isinstance(request.token_format, TokenFormat): + log_error_log(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, invalid_input_error_code) # Optional: transformations if request.transformations is not None and not isinstance(request.transformations, Transformations): + log_error_log(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) # Optional: output_processed_image - if hasattr(request, 'output_processed_image') and request.output_processed_image is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE) and request.output_processed_image is not None: if not isinstance(request.output_processed_image, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, invalid_input_error_code) # Optional: output_ocr_text - if hasattr(request, 'output_ocr_text') and request.output_ocr_text is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT) and request.output_ocr_text is not None: if not isinstance(request.output_ocr_text, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, invalid_input_error_code) # Optional: masking_method - # Optional: masking_method - if hasattr(request, 'masking_method') and request.masking_method is not None: + if hasattr(request, DeidentifyFileRequestField.MASKING_METHOD) and request.masking_method is not None: if not isinstance(request.masking_method, MaskingMethod): + log_error_log(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, invalid_input_error_code) # Optional: pixel_density - if hasattr(request, 'pixel_density') and request.pixel_density is not None: + if hasattr(request, DeidentifyFileRequestField.PIXEL_DENSITY) and request.pixel_density is not None: if not isinstance(request.pixel_density, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, invalid_input_error_code) # Optional: max_resolution - if hasattr(request, 'max_resolution') and request.max_resolution is not None: + if hasattr(request, DeidentifyFileRequestField.MAX_RESOLUTION) and request.max_resolution is not None: if not isinstance(request.max_resolution, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, invalid_input_error_code) # Optional: output_processed_audio - if hasattr(request, 'output_processed_audio') and request.output_processed_audio is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO) and request.output_processed_audio is not None: if not isinstance(request.output_processed_audio, bool): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, invalid_input_error_code) # Optional: output_transcription - if hasattr(request, 'output_transcription') and request.output_transcription is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION) and request.output_transcription is not None: if not isinstance(request.output_transcription, DetectOutputTranscriptions): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, invalid_input_error_code) # Optional: bleep - if hasattr(request, 'bleep') and request.bleep is not None: + if hasattr(request, DeidentifyFileRequestField.BLEEP) and request.bleep is not None: if not isinstance(request.bleep, Bleep): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, invalid_input_error_code) - + # Validate gain if request.bleep.gain is not None and not isinstance(request.bleep.gain, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_GAIN.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_GAIN.value, invalid_input_error_code) - + # Validate frequency if request.bleep.frequency is not None and not isinstance(request.bleep.frequency, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value, invalid_input_error_code) - + # Validate start_padding if request.bleep.start_padding is not None and not isinstance(request.bleep.start_padding, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value, invalid_input_error_code) - + # Validate stop_padding if request.bleep.stop_padding is not None and not isinstance(request.bleep.stop_padding, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, invalid_input_error_code) # Optional: output_directory - if hasattr(request, 'output_directory') and request.output_directory is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_DIRECTORY) and request.output_directory is not None: if not isinstance(request.output_directory, str): + log_error_log(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, invalid_input_error_code) if not os.path.isdir(request.output_directory): + log_error_log(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), logger) raise SkyflowError(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), invalid_input_error_code) # Optional: wait_time - if hasattr(request, 'wait_time') and request.wait_time is not None: + if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): + log_error_log(SkyflowMessages.Error.INVALID_WAIT_TIME.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 or request.wait_time > 64: + if request.wait_time < 0 or request.wait_time > Detect.WAIT_TIME: + log_error_log(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, logger) raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not isinstance(request.values, list) or not all(isinstance(v, dict) for v in request.values): - log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) - if not len(request.values): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format("INSERT"), logger=logger) + if not request.values: + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code) for i, item in enumerate(request.values, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger) - - if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) if request.homogeneous is not None and not isinstance(request.homogeneous, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_HOMOGENEOUS_TYPE.value, invalid_input_error_code) if request.upsert and request.homogeneous: - log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), logger = logger) - raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), logger = logger) + raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), invalid_input_error_code) if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): @@ -441,15 +499,15 @@ def validate_insert_request(logger, request): for i, item in enumerate(request.tokens, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format("INSERT"), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format(RequestOperation.INSERT), logger=logger) if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format("INSERT", key), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format(RequestOperation.INSERT, key), logger=logger) if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -459,43 +517,43 @@ def validate_insert_request(logger, request): raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE_STRICT: - if len(request.values) != len(request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("INSERT"), logger = logger) + if not request.tokens or len(request.values) != len(request.tokens): + log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) for v, t in zip(request.values, request.tokens): if set(v.keys()) != set(t.keys()): - log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format("INSERT"), logger=logger) - raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format(RequestOperation.INSERT), logger=logger) + raise SkyflowError(SkyflowMessages.Error.MISMATCH_OF_FIELDS_AND_TOKENS.value, invalid_input_error_code) def validate_delete_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not request.ids: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code) def validate_query_request(logger, request): - if not request.query: - log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format("QUERY"), logger = logger) - raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not isinstance(request.query, str): query_type = str(type(request.query)) raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code) + if not request.query: + log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger=logger) + raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) + if not request.query.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format("QUERY"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not request.query.upper().startswith("SELECT"): + if not request.query.upper().startswith(SqlCommand.SELECT): command = request.query - raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) def validate_get_request(logger, request): redaction_type = request.redaction_type @@ -508,23 +566,23 @@ def validate_get_request(logger, request): download_url = request.download_url if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not skyflow_ids and not column_name and not column_values: - log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) if skyflow_ids and (not isinstance(skyflow_ids, list) or not skyflow_ids): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_IDS_TYPE.value.format(type(skyflow_ids)), invalid_input_error_code) if skyflow_ids: for index, skyflow_id in enumerate(skyflow_ids): if skyflow_id is None or skyflow_id == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format("GET", index), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format(RequestOperation.GET, index), logger=logger) if not isinstance(request.return_tokens, bool): @@ -534,7 +592,7 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction_type)), invalid_input_error_code) if fields is not None and (not isinstance(fields, list) or not fields): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(fields)), invalid_input_error_code) if offset is not None and limit is not None: @@ -543,13 +601,13 @@ def validate_get_request(logger, request): invalid_input_error_code) if offset is not None and not isinstance(offset, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value(type(offset)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_OFF_SET_VALUE.value.format(type(offset)), invalid_input_error_code) if limit is not None and not isinstance(limit, str): - raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value(type(limit)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_LIMIT_VALUE.value.format(type(limit)), invalid_input_error_code) if download_url is not None and not isinstance(download_url, bool): - raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value(type(download_url)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_DOWNLOAD_URL_VALUE.value.format(type(download_url)), invalid_input_error_code) if column_name is not None and (not isinstance(column_name, str) or not column_name.strip()): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) @@ -560,61 +618,58 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if request.return_tokens and redaction_type: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value, invalid_input_error_code) if (column_name or column_values) and request.return_tokens: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format("GET"), + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.TOKENS_GET_COLUMN_NOT_SUPPORTED.value, invalid_input_error_code) if column_values and not column_name: - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if column_name and not column_values: log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format("GET"), logger = logger) - SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUES.value, invalid_input_error_code) if (column_name or column_values) and skyflow_ids: - log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code) def validate_update_request(logger, request): - skyflow_id = "" - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + if not isinstance(request.data, dict): + raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value.format(type(request.data)), invalid_input_error_code) - try: - skyflow_id = request.data.get("skyflow_id") - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format("UPDATE"), logger=logger) + if not len(request.data.items()): + raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) - if not skyflow_id.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format("UPDATE"), logger = logger) + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} + + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) + if skyflow_id is None: + log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) + elif not skyflow_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger=logger) if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("UPDATE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not isinstance(request.return_tokens, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code) - if not isinstance(request.data, dict): - raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_TYPE.value(type(request.data)), invalid_input_error_code) - - if not len(request.data.items()): - raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code) - if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value, invalid_input_error_code) if request.tokens: if not isinstance(request.tokens, dict) or not request.tokens: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -627,14 +682,14 @@ def validate_update_request(logger, request): if request.token_mode == TokenMode.ENABLE_STRICT: if len(field) != len(request.tokens): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) if set(field.keys()) != set(request.tokens.keys()): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError( SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, @@ -645,23 +700,33 @@ def validate_detokenize_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value, invalid_input_error_code) if not isinstance(request.data, list): - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - if not len(request.data): - log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format("DETOKENIZE"), logger = logger) - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("DETOKENIZE"), logger = logger) + if not request.data: + log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format(RequestOperation.DETOKENIZE), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.DETOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code) for item in request.data: - if 'token' not in item: + if ResponseField.TOKEN not in item: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - token = item.get('token') - redaction = item.get('redaction', None) + token = item.get(ResponseField.TOKEN) + + has_redaction = RequestParameter.REDACTION in item + has_redaction_type = RequestParameter.REDACTION_TYPE in item + + if has_redaction: + log_warn(SkyflowMessages.Warning.DETOKENIZE_REDACTION_KEY_DEPRECATED.value, logger) + + if has_redaction_type: + redaction = item.get(RequestParameter.REDACTION_TYPE) + else: + redaction = item.get(RequestParameter.REDACTION, None) if not isinstance(token, str) or not token: - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE"), + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format(RequestOperation.DETOKENIZE), invalid_input_error_code) if redaction is not None and not isinstance(redaction, RedactionType): @@ -673,23 +738,23 @@ def validate_tokenize_request(logger, request): if not isinstance(parameters, list): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(parameters)), invalid_input_error_code) - if not len(parameters): + if not parameters: raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value, invalid_input_error_code) for i, param in enumerate(parameters): if not isinstance(param, dict): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER.value.format(i, type(param)), invalid_input_error_code) - allowed_keys = {"value", "column_group"} + allowed_keys = {RequestParameter.VALUE, RequestParameter.COLUMN_GROUP} if set(param.keys()) != allowed_keys: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER_KEY.value.format(i), invalid_input_error_code) - if not param.get("value"): - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.VALUE): + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_VALUE.value.format(i), invalid_input_error_code) - if not param.get("column_group"): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.COLUMN_GROUP): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_COLUMN_GROUP.value.format(i), invalid_input_error_code) @@ -698,32 +763,30 @@ def validate_file_upload_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) # Table - table = getattr(request, "table", None) + table = getattr(request, FileUploadField.TABLE, None) if table is None: raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) elif table.strip() == "": raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) # Skyflow ID - skyflow_id = getattr(request, "skyflow_id", None) - if skyflow_id is None: - raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) - elif skyflow_id.strip() == "": - raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD"), invalid_input_error_code) + skyflow_id = getattr(request, FileUploadField.SKYFLOW_ID, None) + if skyflow_id is not None and skyflow_id.strip() == "": + raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format(RequestOperation.FILE_UPLOAD), invalid_input_error_code) # Column Name - column_name = getattr(request, "column_name", None) + column_name = getattr(request, FileUploadField.COLUMN_NAME, None) if column_name is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) elif column_name.strip() == "": - logger.error("Empty column name in FILE_UPLOAD") + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FILE_COLUMN_NAME.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) # File-related attributes - file_path = getattr(request, "file_path", None) - base64_str = getattr(request, "base64", None) - file_object = getattr(request, "file_object", None) - file_name = getattr(request, "file_name", None) + file_path = getattr(request, FileUploadField.FILE_PATH, None) + base64_str = getattr(request, FileUploadField.BASE64, None) + file_object = getattr(request, FileUploadField.FILE_OBJECT, None) + file_name = getattr(request, FileUploadField.FILE_NAME, None) # Check file_path first if present if not is_none_or_empty(file_path): @@ -775,46 +838,57 @@ def validate_invoke_connection_params(logger, query_params, path_params): except TypeError: raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_PARAMS.value, invalid_input_error_code) -def validate_deidentify_text_request(self, request: DeidentifyTextRequest): +def validate_deidentify_text_request(logger, request: DeidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): + log_error_log(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value, invalid_input_error_code) # Validate entities if present if request.entities is not None and not isinstance(request.entities, list): + log_error_log(SkyflowMessages.Error.INVALID_ENTITIES_IN_DEIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ENTITIES_IN_DEIDENTIFY.value, invalid_input_error_code) # Validate allowed_regex_list if present if request.allow_regex_list is not None and not isinstance(request.allow_regex_list, list): + log_error_log(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Validate restricted_regex_list if present if request.restrict_regex_list is not None and not isinstance(request.restrict_regex_list, list): + log_error_log(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) # Validate token_format if present if request.token_format is not None and not isinstance(request.token_format, TokenFormat): + log_error_log(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value, invalid_input_error_code) # Validate transformations if present if request.transformations is not None and not isinstance(request.transformations, Transformations): + log_error_log(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) -def validate_reidentify_text_request(self, request: ReidentifyTextRequest): +def validate_reidentify_text_request(logger, request: ReidentifyTextRequest): if not request.text or not isinstance(request.text, str) or not request.text.strip(): + log_error_log(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value, invalid_input_error_code) # Validate redacted_entities if present if request.redacted_entities is not None and not isinstance(request.redacted_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) # Validate masked_entities if present if request.masked_entities is not None and not isinstance(request.masked_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) # Validate plain_text_entities if present if request.plain_text_entities is not None and not isinstance(request.plain_text_entities, list): + log_error_log(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value, invalid_input_error_code) -def validate_get_detect_run_request(self, request: GetDetectRunRequest): - if not request.run_id or not isinstance(request.run_id, str) or not request.run_id.strip(): +def validate_get_detect_run_request(logger, request: GetDetectRunRequest): + if request.run_id is None or not isinstance(request.run_id, str) or not request.run_id.strip(): + log_error_log(SkyflowMessages.ErrorLogs.INVALID_RUN_ID.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_RUN_ID.value, invalid_input_error_code) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index 0304c11a..8023646c 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -1,7 +1,9 @@ +from skyflow.error import SkyflowError from skyflow.generated.rest.client import Skyflow from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages from skyflow.utils.logger import log_info +from skyflow.utils.constants import OptionField, CredentialField, ConfigField class VaultClient: @@ -34,18 +36,18 @@ def initialize_client_configuration(self): needs_reinit = self.__api_client is None or self.__is_config_updated if needs_reinit: - self.__credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger=self.__logger) - self.__vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env"), - self.__config.get("vault_id"), + self.__credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger=self.__logger) + self.__vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), logger=self.__logger) - self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials + self.__is_static_token = CredentialField.TOKEN in self.__credentials or CredentialField.API_KEY in self.__credentials bearer_token = self.get_bearer_token(self.__credentials) if needs_reinit: self.initialize_api_client(self.__vault_url, bearer_token) def initialize_api_client(self, vault_url, bearer_token): - token_provider = lambda: self.__bearer_token if self.__bearer_token else bearer_token # noqa: E731 + token_provider = lambda: self.__bearer_token if self.__bearer_token is not None else bearer_token # noqa: E731 self.__api_client = Skyflow(base_url=vault_url, token=token_provider) def get_records_api(self): @@ -64,28 +66,30 @@ def get_detect_file_api(self): return self.__api_client.files def get_vault_id(self): - return self.__config.get("vault_id") + return self.__config.get(ConfigField.VAULT_ID) def get_bearer_token(self, credentials): - if 'api_key' in credentials: - return credentials.get('api_key') - elif 'token' in credentials: - return credentials.get("token") + if CredentialField.API_KEY in credentials: + return credentials.get(CredentialField.API_KEY) + elif CredentialField.TOKEN in credentials: + return credentials.get(CredentialField.TOKEN) options = { - "role_ids": self.__config.get("roles"), - "ctx": self.__config.get("ctx") + OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), + OptionField.CTX: self.__config.get(OptionField.CTX) } + if CredentialField.TOKEN_URI_OPTION in credentials and credentials.get(CredentialField.TOKEN_URI_OPTION): + options[CredentialField.TOKEN_URI_OPTION] = credentials.get(CredentialField.TOKEN_URI_OPTION) if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token): - if 'path' in credentials: + if CredentialField.PATH in credentials: self.__bearer_token, _ = generate_bearer_token( - credentials.get("path"), + credentials.get(CredentialField.PATH), options, self.__logger ) else: - credentials_string = credentials.get('credentials_string') + credentials_string = credentials.get(CredentialField.CREDENTIALS_STRING) log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, self.__logger) self.__bearer_token, _ = generate_bearer_token_from_creds( credentials_string, diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 81c6ea10..2ce0c104 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,6 +5,8 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader, OptionField, ConfigField +from skyflow.utils import get_credentials class Connection: @@ -12,20 +14,22 @@ def __init__(self, vault_client): self.__vault_client = vault_client def invoke(self, request: InvokeConnectionRequest): - session = requests.Session() - - config = self.__vault_client.get_config() - bearer_token = self.__vault_client.get_bearer_token(config.get("credentials")) - - connection_url = config.get("connection_url") log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) + config = self.__vault_client.get_config() + connection_url = config.get(OptionField.CONNECTION_URL) invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) + + credentials = get_credentials(config.get(ConfigField.CREDENTIALS), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) + + bearer_token = self.__vault_client.get_bearer_token(credentials) + + session = requests.Session() - if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: - invoke_connection_request.headers['x-skyflow-authorization'] = bearer_token + if not HttpHeader.X_SKYFLOW_AUTHORIZATION_HEADER.lower() in invoke_connection_request.headers: + invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token - invoke_connection_request.headers['sky-metadata'] = json.dumps(get_metrics()) + invoke_connection_request.headers[SKY_META_DATA_HEADER] = json.dumps(get_metrics()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index cb5e8836..754e7799 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,7 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField, Detect as DetectConstants) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -29,44 +30,44 @@ def __get_headers(self): } return headers - def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: + def __build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[str, Any]: deidentify_text_body = {} parsed_entity_types = request.entities - deidentify_text_body['text'] = request.text - deidentify_text_body['entity_types'] = parsed_entity_types - deidentify_text_body['token_type'] = self.__get_token_format(request) - deidentify_text_body['allow_regex'] = request.allow_regex_list - deidentify_text_body['restrict_regex'] = request.restrict_regex_list - deidentify_text_body['transformations'] = self.__get_transformations(request) + deidentify_text_body[DeidentifyField.TEXT] = request.text + deidentify_text_body[DeidentifyField.ENTITY_TYPES] = parsed_entity_types + deidentify_text_body[DeidentifyField.TOKEN_TYPE] = self.__get_token_format(request) + deidentify_text_body[DeidentifyField.ALLOW_REGEX] = request.allow_regex_list + deidentify_text_body[DeidentifyField.RESTRICT_REGEX] = request.restrict_regex_list + deidentify_text_body[DeidentifyField.TRANSFORMATIONS] = self.__get_transformations(request) return deidentify_text_body - def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: + def __build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[str, Any]: parsed_format = Format( redacted=request.redacted_entities, masked=request.masked_entities, plaintext=request.plain_text_entities ) reidentify_text_body = {} - reidentify_text_body['text'] = request.text - reidentify_text_body['format'] = parsed_format + reidentify_text_body[DeidentifyField.TEXT] = request.text + reidentify_text_body[DeidentifyField.FORMAT] = parsed_format return reidentify_text_body def _get_file_extension(self, filename: str): return filename.split('.')[-1].lower() if '.' in filename else '' - def __poll_for_processed_file(self, run_id, max_wait_time=64): - max_wait_time = 64 if max_wait_time is None else max_wait_time + def __poll_for_processed_file(self, run_id, max_wait_time=None): + max_wait_time = DetectConstants.WAIT_TIME if max_wait_time is None else max_wait_time files_api = self.__vault_client.get_detect_file_api().with_raw_response current_wait_time = 1 # Start with 1 second try: while True: response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()}).data status = response.status - if status == 'IN_PROGRESS': + if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: - return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS') + return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: next_wait_time = current_wait_time * 2 if next_wait_time >= max_wait_time: @@ -76,42 +77,54 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): wait_time = next_wait_time current_wait_time = next_wait_time time.sleep(wait_time) - elif status == 'SUCCESS' or status == 'FAILED': + elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: - raise e + handle_exception(e, self.__vault_client.get_logger()) def __save_deidentify_file_response_output(self, response: DetectRunsResponse, output_directory: str, original_file_name: str, name_without_ext: str): - if not response or not hasattr(response, 'output') or not response.output or not output_directory: + if not response or not hasattr(response, DeidentifyField.OUTPUT) or not response.output or not output_directory: return if not os.path.exists(output_directory): return - deidentify_file_prefix = "processed-" + deidentify_file_prefix = FileProcessing.PROCESSED_PREFIX output_list = response.output base_original_filename = os.path.basename(original_file_name) base_name_without_ext = os.path.splitext(base_original_filename)[0] + real_output_dir = os.path.realpath(output_directory) for idx, output in enumerate(output_list): try: - processed_file = get_attribute(output, 'processedFile', 'processed_file') - processed_file_type = get_attribute(output, 'processedFileType', 'processed_file_type') - processed_file_extension = get_attribute(output, 'processedFileExtension', 'processed_file_extension') + processed_file = get_attribute(output, DeidentifyField.PROCESSED_FILE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE) + processed_file_type = get_attribute(output, DeidentifyField.PROCESSED_FILE_TYPE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_TYPE) + processed_file_extension = get_attribute(output, DeidentifyField.PROCESSED_FILE_EXTENSION_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_EXTENSION) if not processed_file: continue decoded_data = base64.b64decode(processed_file) - - if idx == 0 or processed_file_type == 'redacted_file': + + # Sanitize extension from API response to prevent path traversal (CWE-22). + # Avoid os.path.basename here to keep basename mock-free in tests. + safe_ext = None + if processed_file_extension: + raw_ext = str(processed_file_extension).replace('\\', '/').split('/')[-1].lstrip('.') + safe_ext = ''.join(c for c in raw_ext if c.isalnum() or c in ('-', '_')) or 'bin' + + if idx == 0 or processed_file_type == DeidentifyField.REDACTED_FILE: output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) - if processed_file_extension: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") + if safe_ext: + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext}") else: - output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") - + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{safe_ext or 'bin'}") + + if not os.path.realpath(output_file_name).startswith(real_output_dir + os.sep): + log_error_log(SkyflowMessages.ErrorLogs.SAVING_DEIDENTIFY_FILE_FAILED.value, self.__vault_client.get_logger()) + continue + with open(output_file_name, 'wb') as f: f.write(decoded_data) except Exception as e: @@ -119,62 +132,62 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o handle_exception(e, self.__vault_client.get_logger()) def __parse_deidentify_file_response(self, data, run_id=None, status=None): - output = getattr(data, "output", []) - status_val = getattr(data, "status", None) or status - run_id_val = getattr(data, "run_id", None) or run_id + output = getattr(data, DeidentifyField.OUTPUT, []) + status_val = getattr(data, DeidentifyField.STATUS, None) or status + run_id_val = getattr(data, DeidentifyField.RUN_ID, None) or run_id word_count = None char_count = None - word_character_count = getattr(data, "word_character_count", None) + word_character_count = getattr(data, DeidentifyField.WORD_CHARACTER_COUNT, None) if word_character_count and isinstance(word_character_count, WordCharacterCount): - word_count = word_character_count.word_count - char_count = word_character_count.character_count + word_count = getattr(word_character_count, DeidentifyField.WORD_COUNT, None) + char_count = getattr(word_character_count, DeidentifyField.CHARACTER_COUNT, None) - size = getattr(data, "size", None) + size = getattr(data, DeidentifyField.SIZE, None) size = float(size) if size is not None else None - duration = getattr(data, "duration", None) - pages = getattr(data, "pages", None) - slides = getattr(data, "slides", None) + duration = getattr(data, DeidentifyField.DURATION, None) + pages = getattr(data, DeidentifyField.PAGES, None) + slides = getattr(data, DeidentifyField.SLIDES, None) def output_to_dict_list(output): result = [] for o in output: if isinstance(o, dict): result.append({ - "file": o.get("processed_file"), - "type": o.get("processed_file_type"), - "extension": o.get("processed_file_extension") + DeidentifyField.FILE: o.get(DeidentifyField.PROCESSED_FILE), + DeidentifyField.TYPE: o.get(DeidentifyField.PROCESSED_FILE_TYPE), + DeidentifyField.EXTENSION: o.get(DeidentifyField.PROCESSED_FILE_EXTENSION) }) else: result.append({ - "file": getattr(o, "processed_file", None), - "type": getattr(o, "processed_file_type", None), - "extension": getattr(o, "processed_file_extension", None) + DeidentifyField.FILE: getattr(o, DeidentifyField.PROCESSED_FILE, None), + DeidentifyField.TYPE: getattr(o, DeidentifyField.PROCESSED_FILE_TYPE, None), + DeidentifyField.EXTENSION: getattr(o, DeidentifyField.PROCESSED_FILE_EXTENSION, None) }) return result output_list = output_to_dict_list(output) first_output = output_list[0] if output_list else {} - entities = [o for o in output_list if o.get("type") == "entities"] + entities = [o for o in output_list if o.get(DeidentifyField.TYPE) == FileProcessing.ENTITIES] - base64_string = first_output.get("file", None) - extension = first_output.get("extension", None) + base64_string = first_output.get(DeidentifyField.FILE, None) + extension = first_output.get(DeidentifyField.EXTENSION, None) if base64_string is not None: - file_bytes = base64.b64decode(base64_string) - file_obj = io.BytesIO(file_bytes) - file_obj.name = f"deidentified.{extension}" if extension else "processed_file" + file_bytes = base64.b64decode(base64_string) + file_obj = io.BytesIO(file_bytes) + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get("type", "UNKNOWN"), + type=first_output.get(DeidentifyField.TYPE, None), extension=extension, word_count=word_count, char_count=char_count, @@ -188,25 +201,26 @@ def output_to_dict_list(output): ) def __get_token_format(self, request): - if not hasattr(request, "token_format") or request.token_format is None: + if not hasattr(request, DeidentifyField.TOKEN_FORMAT) or request.token_format is None: return None return { - 'default': getattr(request.token_format, "default", None), - 'entity_unq_counter': getattr(request.token_format, "entity_unique_counter", None), - 'entity_only': getattr(request.token_format, "entity_only", None), + DeidentifyField.DEFAULT: getattr(request.token_format, DeidentifyField.DEFAULT, None), + DeidentifyField.ENTITY_UNQ_COUNTER: getattr(request.token_format, DeidentifyField.ENTITY_UNIQUE_COUNTER, None), + DeidentifyField.ENTITY_ONLY: getattr(request.token_format, DeidentifyField.ENTITY_ONLY, None), + DeidentifyField.VAULT_TOKEN: getattr(request.token_format, DeidentifyField.VAULT_TOKEN, None) } def __get_transformations(self, request): - if not hasattr(request, "transformations") or request.transformations is None: + if not hasattr(request, DeidentifyField.TRANSFORMATIONS) or request.transformations is None: return None - shift_dates = getattr(request.transformations, "shift_dates", None) + shift_dates = getattr(request.transformations, DeidentifyField.SHIFT_DATES, None) if shift_dates is None: return None return { - 'shift_dates': { - 'max_days': getattr(shift_dates, "max", None), - 'min_days': getattr(shift_dates, "min", None), - 'entity_types': getattr(shift_dates, "entities", None) + DeidentifyField.SHIFT_DATES: { + DeidentifyField.MAX_DAYS: getattr(shift_dates, DeidentifyField.MAX, None), + DeidentifyField.MIN_DAYS: getattr(shift_dates, DeidentifyField.MIN, None), + DeidentifyField.ENTITY_TYPES: getattr(shift_dates, DeidentifyField.ENTITIES, None) } } @@ -216,18 +230,18 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - deidentify_text_body = self.___build_deidentify_text_body(request) + deidentify_text_body = self.__build_deidentify_text_body(request) try: log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.deidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=deidentify_text_body['text'], - entity_types=deidentify_text_body['entity_types'], - allow_regex=deidentify_text_body['allow_regex'], - restrict_regex=deidentify_text_body['restrict_regex'], - token_type=deidentify_text_body['token_type'], - transformations=deidentify_text_body['transformations'], + text=deidentify_text_body[DeidentifyField.TEXT], + entity_types=deidentify_text_body[DeidentifyField.ENTITY_TYPES], + allow_regex=deidentify_text_body[DeidentifyField.ALLOW_REGEX], + restrict_regex=deidentify_text_body[DeidentifyField.RESTRICT_REGEX], + token_type=deidentify_text_body[DeidentifyField.TOKEN_TYPE], + transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], request_options={'additional_headers': self.__get_headers()} ) deidentify_text_response = parse_deidentify_text_response(api_response) @@ -244,14 +258,14 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() detect_api = self.__vault_client.get_detect_text_api() - reidentify_text_body = self.___build_reidentify_text_body(request) + reidentify_text_body = self.__build_reidentify_text_body(request) try: log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.reidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=reidentify_text_body['text'], - format=reidentify_text_body['format'], + text=reidentify_text_body[DeidentifyField.TEXT], + format=reidentify_text_body[DeidentifyField.FORMAT], request_options={'additional_headers': self.__get_headers()} ) reidentify_text_response = parse_reidentify_text_response(api_response) @@ -264,14 +278,16 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo def __get_file_from_request(self, request: DeidentifyFileRequest): file_input = request.file - - # Check for file - if hasattr(file_input, 'file') and file_input.file is not None: + + if hasattr(file_input, FileUploadField.FILE) and file_input.file is not None: return file_input.file - - # Check for file_path if file is not provided - if hasattr(file_input, 'file_path') and file_input.file_path is not None: - return open(file_input.file_path, 'rb') + + if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: + with open(file_input.file_path, 'rb') as f: + content = f.read() + bio = io.BytesIO(content) + bio.name = file_input.file_path + return bio def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) @@ -279,151 +295,152 @@ def deidentify_file(self, request: DeidentifyFileRequest): self.__initialize() files_api = self.__vault_client.get_detect_file_api().with_raw_response file_obj = self.__get_file_from_request(request) - file_name = getattr(file_obj, 'name', None) + file_name = getattr(file_obj, FileUploadField.NAME, None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode('utf-8') + base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) try: - if file_extension == 'txt': - req_file = FileDataDeidentifyText(base_64=base64_string, data_format="txt") + if file_extension == FileExtension.TXT: + req_file = FileDataDeidentifyText(base_64=base64_string, data_format=FileExtension.TXT) api_call = files_api.deidentify_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['mp3', 'wav']: + elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio + bleep = request.bleep api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'output_transcription': getattr(request, 'output_transcription', None), - 'output_processed_audio': getattr(request, 'output_processed_audio', None), - 'bleep_gain': getattr(request, 'bleep', None).gain if getattr(request, 'bleep', None) is not None else None, - 'bleep_frequency': getattr(request, 'bleep', None).frequency if getattr(request, 'bleep', None) is not None else None, - 'bleep_start_padding': getattr(request, 'bleep', None).start_padding if getattr(request, 'bleep', None) is not None else None, - 'bleep_stop_padding': getattr(request, 'bleep', None).stop_padding if getattr(request, 'bleep', None) is not None else None, - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION: getattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO, None), + DeidentifyField.BLEEP_GAIN: bleep.gain if bleep is not None else None, + DeidentifyField.BLEEP_FREQUENCY: bleep.frequency if bleep is not None else None, + DeidentifyField.BLEEP_START_PADDING: bleep.start_padding if bleep is not None else None, + DeidentifyField.BLEEP_STOP_PADDING: bleep.stop_padding if bleep is not None else None, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension == 'pdf': + elif file_extension == FileExtension.PDF: req_file = FileDataDeidentifyPdf(base_64=base64_string) api_call = files_api.deidentify_pdf api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'max_resolution': getattr(request, 'max_resolution', None), - 'density': getattr(request, 'pixel_density', None), - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MAX_RESOLUTION: getattr(request, DeidentifyFileRequestField.MAX_RESOLUTION, None), + DeidentifyFileRequestField.DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']: + elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: req_file = FileDataDeidentifyImage(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_image api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'masking_method': getattr(request, 'masking_method', None), - 'output_ocr_text': getattr(request, 'output_ocr_text', None), - 'output_processed_image': getattr(request, 'output_processed_image', None), - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MASKING_METHOD: getattr(request, DeidentifyFileRequestField.MASKING_METHOD, None), + DeidentifyFileRequestField.OUTPUT_OCR_TEXT: getattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE, None), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['ppt', 'pptx']: + elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: req_file = FileDataDeidentifyPresentation(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_presentation api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['csv', 'xls', 'xlsx']: + elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: req_file = FileDataDeidentifySpreadsheet(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_spreadsheet api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['doc', 'docx']: + elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: req_file = FileDataDeidentifyDocument(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_document api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } - elif file_extension in ['json', 'xml']: + elif file_extension in [FileExtension.JSON, FileExtension.XML]: req_file = FileDataDeidentifyStructuredText(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_structured_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } else: req_file = FileData(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_file api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': {'additional_headers': self.__get_headers()} + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: {'additional_headers': self.__get_headers()} } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) api_response = api_call(**api_kwargs) - run_id = getattr(api_response.data, 'run_id', None) + run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == 'SUCCESS': + if request.output_directory and processed_response.status == DetectStatus.SUCCESS and file_name: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -450,8 +467,8 @@ def get_detect_run(self, request: GetDetectRunRequest): vault_id=self.__vault_client.get_vault_id(), request_options={'additional_headers': self.__get_headers()} ) - if response.data.status == 'IN_PROGRESS': - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) + if response.data.status == DetectStatus.IN_PROGRESS: + parsed_response = DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) @@ -459,5 +476,4 @@ def get_detect_run(self, request: GetDetectRunRequest): except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value, self.__vault_client.get_logger()) - handle_exception(e, self.__vault_client.get_logger()) - + handle_exception(e, self.__vault_client.get_logger()) \ No newline at end of file diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index c757730a..6c47fe3e 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -8,7 +8,7 @@ from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter, FileUploadField from skyflow.utils.enums import RequestMethod from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log @@ -82,16 +82,14 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return (request.file_name, decoded_bytes) elif request.file_object is not None: - if hasattr(request.file_object, "name") and request.file_object.name: + if hasattr(request.file_object, FileUploadField.NAME) and request.file_object.name: file_name = os.path.basename(request.file_object.name) return (file_name, request.file_object) return None def __get_headers(self): - if not hasattr(self, '_cached_headers'): - self._cached_headers = {SKY_META_DATA_HEADER: json.dumps(get_metrics())} - return self._cached_headers + return {SKY_META_DATA_HEADER: json.dumps(get_metrics())} def insert(self, request: InsertRequest): log_info(SkyflowMessages.Info.VALIDATE_INSERT_REQUEST.value, self.__vault_client.get_logger()) @@ -124,7 +122,7 @@ def update(self, request: UpdateRequest): validate_update_request(self.__vault_client.get_logger(), request) log_info(SkyflowMessages.Info.UPDATE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} record = V1FieldRecords(fields=field, tokens = request.tokens) records_api = self.__vault_client.get_records_api() @@ -133,7 +131,7 @@ def update(self, request: UpdateRequest): api_response = records_api.record_service_update_record( self.__vault_client.get_vault_id(), request.table, - id=request.data.get("skyflow_id"), + id=request.data.get(ResponseField.SKYFLOW_ID), record=record, tokenization=request.return_tokens, byot=request.token_mode.value, @@ -224,8 +222,8 @@ def detokenize(self, request: DetokenizeRequest): self.__initialize() tokens_list = [ V1DetokenizeRecordRequest( - token=item.get('token'), - redaction=item.get('redaction', RedactionType.DEFAULT) + token=item.get(ResponseField.TOKEN), + redaction=item.get(RequestParameter.REDACTION_TYPE) or item.get(RequestParameter.REDACTION, RedactionType.DEFAULT) ) for item in request.data ] @@ -252,7 +250,7 @@ def tokenize(self, request: TokenizeRequest): self.__initialize() records_list = [ - V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"]) + V1TokenizeRecordRequest(value=item[RequestParameter.VALUE], column_group=item[RequestParameter.COLUMN_GROUP]) for item in request.values ] tokens_api = self.__vault_client.get_tokens_api() @@ -295,4 +293,4 @@ def upload_file(self, request: FileUploadRequest): return upload_response except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.FILE_UPLOAD_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger()) - handle_exception(e, self.__vault_client.get_logger()) + handle_exception(e, self.__vault_client.get_logger()) \ No newline at end of file diff --git a/skyflow/vault/data/_file_upload_request.py b/skyflow/vault/data/_file_upload_request.py index d1bd4a44..c5c08b51 100644 --- a/skyflow/vault/data/_file_upload_request.py +++ b/skyflow/vault/data/_file_upload_request.py @@ -1,14 +1,23 @@ -from typing import BinaryIO +from typing import BinaryIO, Optional + +from skyflow.utils import SkyflowMessages +from skyflow.utils.logger import log_warn + class FileUploadRequest: def __init__(self, table: str, - skyflow_id: str, - column_name: str, - file_path: str= None, - base64: str= None, - file_object: BinaryIO= None, - file_name: str= None): + *args, + column_name: Optional[str] = None, + skyflow_id: Optional[str] = None, + file_path: Optional[str] = None, + base64: Optional[str] = None, + file_object: Optional[BinaryIO] = None, + file_name: Optional[str] = None): + if args: + log_warn(SkyflowMessages.Warning.FILE_UPLOAD_REQUEST_ARG_ORDER_DEPRECATED.value) + skyflow_id = args[0] if args else skyflow_id + column_name = args[1] if len(args) > 1 else column_name self.table = table self.skyflow_id = skyflow_id self.column_name = column_name diff --git a/skyflow/vault/data/_get_response.py b/skyflow/vault/data/_get_response.py index cf1b0805..a1640254 100644 --- a/skyflow/vault/data/_get_response.py +++ b/skyflow/vault/data/_get_response.py @@ -1,6 +1,6 @@ class GetResponse: def __init__(self, data=None, errors = None): - self.data = data if data else [] + self.data = data if data is not None else [] self.errors = errors def __repr__(self): diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index b340e21c..e56f2113 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -1,22 +1,24 @@ import io +from typing import Optional from skyflow.vault.detect._file import File class DeidentifyFileResponse: def __init__( self, - file_base64: str = None, - file: io.BytesIO = None, - type: str = None, - extension: str = None, - word_count: int = None, - char_count: int = None, - size_in_kb: float = None, - duration_in_seconds: float = None, - page_count: int = None, - slide_count: int = None, - entities: list = None, # list of dicts with keys 'file' and 'extension' - run_id: str = None, - status: str = None, + file_base64: Optional[str] = None, + file: Optional[io.BytesIO] = None, + type: Optional[str] = None, + extension: Optional[str] = None, + word_count: Optional[int] = None, + char_count: Optional[int] = None, + size_in_kb: Optional[float] = None, + duration_in_seconds: Optional[float] = None, + page_count: Optional[int] = None, + slide_count: Optional[int] = None, + entities: Optional[list] = None, + run_id: Optional[str] = None, + status: Optional[str] = None, + errors: Optional[list] = None, ): self.file_base64 = file_base64 self.file = File(file) if file else None @@ -31,6 +33,7 @@ def __init__( self.entities = entities if entities is not None else [] self.run_id = run_id self.status = status + self.errors = errors def __repr__(self): return ( @@ -40,7 +43,7 @@ def __repr__(self): f"char_count={self.char_count!r}, size_in_kb={self.size_in_kb!r}, " f"duration_in_seconds={self.duration_in_seconds!r}, page_count={self.page_count!r}, " f"slide_count={self.slide_count!r}, entities={self.entities!r}, " - f"run_id={self.run_id!r}, status={self.status!r})" + f"run_id={self.run_id!r}, status={self.status!r}, errors={self.errors!r})" ) def __str__(self): diff --git a/skyflow/vault/detect/_deidentify_text_response.py b/skyflow/vault/detect/_deidentify_text_response.py index cdb6632e..227b43bc 100644 --- a/skyflow/vault/detect/_deidentify_text_response.py +++ b/skyflow/vault/detect/_deidentify_text_response.py @@ -1,19 +1,21 @@ -from typing import List +from typing import List, Optional from ._entity_info import EntityInfo class DeidentifyTextResponse: - def __init__(self, + def __init__(self, processed_text: str, - entities: List[EntityInfo], + entities: List[EntityInfo], word_count: int, - char_count: int): + char_count: int, + errors: Optional[list] = None): self.processed_text = processed_text self.entities = entities self.word_count = word_count self.char_count = char_count + self.errors = errors def __repr__(self): - return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count})" + return f"DeidentifyTextResponse(processed_text='{self.processed_text}', entities={self.entities}, word_count={self.word_count}, char_count={self.char_count}, errors={self.errors})" def __str__(self): return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/detect/_reidentify_text_response.py b/skyflow/vault/detect/_reidentify_text_response.py index 50c3876d..73ad3f5d 100644 --- a/skyflow/vault/detect/_reidentify_text_response.py +++ b/skyflow/vault/detect/_reidentify_text_response.py @@ -1,9 +1,12 @@ +from typing import Optional + class ReidentifyTextResponse: - def __init__(self, processed_text: str): + def __init__(self, processed_text: str, errors: Optional[list] = None): self.processed_text = processed_text + self.errors = errors def __repr__(self) -> str: - return f"ReidentifyTextResponse(processed_text='{self.processed_text}')" + return f"ReidentifyTextResponse(processed_text='{self.processed_text}', errors={self.errors})" def __str__(self) -> str: return self.__repr__() \ No newline at end of file diff --git a/tests/client/test_skyflow.py b/tests/client/test_skyflow.py index 3e3681bb..5b7ea675 100644 --- a/tests/client/test_skyflow.py +++ b/tests/client/test_skyflow.py @@ -1,42 +1,42 @@ import unittest -from unittest.mock import patch +from unittest.mock import patch, Mock from skyflow import LogLevel, Env from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow import Skyflow +from skyflow.vault.client.client import VaultClient +from skyflow.vault.data import FileUploadRequest VALID_VAULT_CONFIG = { "vault_id": "VAULT_ID", "cluster_id": "CLUSTER_ID", "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_VAULT_CONFIG = { "cluster_id": "CLUSTER_ID", # Missing vault_id "env": Env.DEV, - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } VALID_CONNECTION_CONFIG = { "connection_id": "CONNECTION_ID", "connection_url": "https://CONNECTION_URL", - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } INVALID_CONNECTION_CONFIG = { "connection_url": "https://CONNECTION_URL", # Missing connection_id - "credentials": {"path": "/path/to/valid_credentials.json"} + "credentials": {"path": "/path/to/valid_credentials.json"}, } -VALID_CREDENTIALS = { - "path": "/path/to/valid_credentials.json" -} +VALID_CREDENTIALS = {"path": "/path/to/valid_credentials.json"} -class TestSkyflow(unittest.TestCase): +class TestSkyflow(unittest.TestCase): def setUp(self): self.builder = Skyflow.builder() @@ -49,8 +49,10 @@ def test_add_already_exists_vault_config(self): builder = self.builder.add_vault_config(VALID_VAULT_CONFIG) with self.assertRaises(SkyflowError) as context: builder.add_vault_config(VALID_VAULT_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id"))) - + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(VALID_VAULT_CONFIG.get("vault_id")), + ) def test_add_vault_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -61,11 +63,11 @@ def test_add_vault_config_invalid(self): def test_remove_vault_config_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - result = self.builder.remove_vault_config(VALID_VAULT_CONFIG['vault_id']) + result = self.builder.remove_vault_config(VALID_VAULT_CONFIG["vault_id"]) - self.assertNotIn(VALID_VAULT_CONFIG['vault_id'], self.builder._Builder__vault_configs) + self.assertNotIn(VALID_VAULT_CONFIG["vault_id"], self.builder._Builder__vault_configs) - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_vault_config_invalid(self, mock_log_error): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -73,8 +75,7 @@ def test_remove_vault_config_invalid(self, mock_log_error): self.builder.remove_vault_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_VAULT_ID.value) - - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_vault_config_valid(self, mock_validate): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -94,7 +95,7 @@ def test_get_vault(self): def test_get_vault_with_vault_id_none(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() - vault = self.builder.get_vault_config(None) + vault = self.builder.get_vault_config(None) config = vault.get("vault_client").get_config() self.assertEqual(self.builder._Builder__vault_list[0], config) @@ -107,19 +108,23 @@ def test_get_vault_with_empty_vault_list_when_vault_id_is_none_raises_error(self def test_get_vault_with_invalid_vault_id_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_vault_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_vault_with_invalid_vault_id_and_non_empty_list_raises_error(self): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_vault_config('invalid_vault_id') - - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id")) + self.builder.get_vault_config("invalid_vault_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id"), + ) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_build_calls_validate_vault_config(self, mock_validate_vault_config): self.builder.add_vault_config(VALID_VAULT_CONFIG) self.builder.build() @@ -143,7 +148,9 @@ def test_add_already_exists_connection_config(self): with self.assertRaises(SkyflowError) as context: builder.add_connection_config(VALID_CONNECTION_CONFIG) - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id)) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id) + ) def test_add_connection_config_invalid(self): with self.assertRaises(SkyflowError) as context: @@ -158,8 +165,7 @@ def test_remove_connection_config_valid(self): self.assertNotIn(VALID_CONNECTION_CONFIG.get("connection_id"), self.builder._Builder__connection_configs) - - @patch('skyflow.utils.logger.log_error') + @patch("skyflow.utils.logger.log_error") def test_remove_connection_config_invalid(self, mock_log_error): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -167,7 +173,7 @@ def test_remove_connection_config_invalid(self, mock_log_error): self.builder.remove_connection_config("invalid_id") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONNECTION_ID.value) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_update_connection_config_valid(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() @@ -194,16 +200,21 @@ def test_get_connection_config_with_connection_id_none(self): def test_get_connection_with_empty_connection_list_raises_error(self): self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_id') - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_id')) + self.builder.get_connection_config("invalid_id") + self.assertEqual( + context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_id") + ) def test_get_connection_with_invalid_connection_id_raises_error(self): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() with self.assertRaises(SkyflowError) as context: - self.builder.get_connection_config('invalid_connection_id') + self.builder.get_connection_config("invalid_connection_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format('invalid_connection_id')) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("invalid_connection_id"), + ) def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(self): self.builder.build() @@ -212,13 +223,12 @@ def test_get_connection_with_invalid_connection_id_and_empty_list_raises_Error(s self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_CONFIGS.value) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_build_calls_validate_connection_config(self, mock_validate): self.builder.add_connection_config(VALID_CONNECTION_CONFIG) self.builder.build() mock_validate.assert_called_once_with(self.builder._Builder__logger, VALID_CONNECTION_CONFIG) - def test_build_valid(self): self.builder.add_vault_config(VALID_VAULT_CONFIG).add_connection_config(VALID_CONNECTION_CONFIG) client = self.builder.build() @@ -236,30 +246,31 @@ def test_invalid_credentials(self): self.assertEqual(VALID_CREDENTIALS, self.builder._Builder__skyflow_credentials) self.assertEqual(builder, self.builder) - @patch('skyflow.client.skyflow.validate_vault_config') + @patch("skyflow.client.skyflow.validate_vault_config") def test_skyflow_client_add_remove_vault_config(self, mock_validate_vault_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['vault_id'] = "VAULT_ID" + new_config["vault_id"] = "VAULT_ID" skyflow_client.add_vault_config(new_config) - assert mock_validate_vault_config.call_count == 2 + self.assertEqual(mock_validate_vault_config.call_count, 2) - self.assertEqual("VAULT_ID", - skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id")) + self.assertEqual("VAULT_ID", skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id")) - skyflow_client.remove_vault_config(new_config['vault_id']) + skyflow_client.remove_vault_config(new_config["vault_id"]) with self.assertRaises(SkyflowError) as context: - skyflow_client.get_vault_config(new_config['vault_id']).get("vault_id") + skyflow_client.get_vault_config(new_config["vault_id"]).get("vault_id") - self.assertEqual(context.exception.message, SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format( - new_config['vault_id'])) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format(new_config["vault_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() new_config = VALID_VAULT_CONFIG.copy() - new_config['env'] = Env.SANDBOX + new_config["env"] = Env.SANDBOX skyflow_client.update_vault_config(new_config) mock_update_config.assert_called_once() @@ -267,29 +278,33 @@ def test_skyflow_client_update_and_get_vault_config(self, mock_update_config): self.assertEqual(VALID_VAULT_CONFIG.get("vault_id"), vault.get("vault_id")) - @patch('skyflow.client.skyflow.validate_connection_config') + @patch("skyflow.client.skyflow.validate_connection_config") def test_skyflow_client_add_remove_connection_config(self, mock_validate_connection_config): skyflow_client = self.builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_id'] = "CONNECTION_ID" + new_config["connection_id"] = "CONNECTION_ID" skyflow_client.add_connection_config(new_config) - assert mock_validate_connection_config.call_count == 2 - self.assertEqual("CONNECTION_ID", skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id")) + self.assertEqual(mock_validate_connection_config.call_count, 2) + self.assertEqual( + "CONNECTION_ID", skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + ) skyflow_client.remove_connection_config("CONNECTION_ID") with self.assertRaises(SkyflowError) as context: - skyflow_client.get_connection_config(new_config['connection_id']).get("connection_id") - - self.assertEqual(context.exception.message, SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config['connection_id'])) + skyflow_client.get_connection_config(new_config["connection_id"]).get("connection_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format(new_config["connection_id"]), + ) - @patch('skyflow.vault.client.client.VaultClient.update_config') + @patch("skyflow.vault.client.client.VaultClient.update_config") def test_skyflow_client_update_and_get_connection_config(self, mock_update_config): builder = self.builder skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() new_config = VALID_CONNECTION_CONFIG.copy() - new_config['connection_url'] = 'updated_url' + new_config["connection_url"] = "updated_url" skyflow_client.update_connection_config(new_config) mock_update_config.assert_called_once() @@ -305,28 +320,165 @@ def test_skyflow_add_and_update_skyflow_credentials(self): self.assertEqual(VALID_CREDENTIALS, builder._Builder__skyflow_credentials) new_credentials = VALID_CREDENTIALS.copy() - new_credentials['path'] = 'path/to/new_credentials' + new_credentials["path"] = "path/to/new_credentials" skyflow_client.update_skyflow_credentials(new_credentials) self.assertEqual(new_credentials, builder._Builder__skyflow_credentials) - def test_skyflow_add_and_update_log_level(self): builder = self.builder - skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() + skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).build() skyflow_client.set_log_level(LogLevel.INFO) self.assertEqual(LogLevel.INFO, builder._Builder__log_level) - skyflow_client.update_log_level(LogLevel.ERROR) - self.assertEqual(LogLevel.ERROR, builder._Builder__log_level) - - - @patch('skyflow.client.Skyflow.Builder.get_vault_config') + @patch("skyflow.client.Skyflow.Builder.get_vault_config") def test_skyflow_vault_and_connection_method(self, mock_get_vault_config): builder = self.builder - skyflow_client = builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + skyflow_client = ( + builder.add_connection_config(VALID_CONNECTION_CONFIG).add_vault_config(VALID_VAULT_CONFIG).build() + ) skyflow_client.vault() skyflow_client.connection() - mock_get_vault_config.assert_called_once() \ No newline at end of file + mock_get_vault_config.assert_called_once() + + def test_detect_returns_detect_controller(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + from skyflow.vault.controller import Detect + result = skyflow_client.detect() + self.assertIsInstance(result, Detect) + + def test_detect_with_explicit_vault_id(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + from skyflow.vault.controller import Detect + result = skyflow_client.detect(VALID_VAULT_CONFIG["vault_id"]) + self.assertIsInstance(result, Detect) + + def test_detect_with_invalid_vault_id_raises_error(self): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + with self.assertRaises(SkyflowError) as context: + skyflow_client.detect("invalid_vault_id") + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("invalid_vault_id"), + ) + + @patch("skyflow.vault.client.client.VaultClient.update_config") + def test_update_vault_config_with_invalid_vault_id_raises_error(self, _mock): + skyflow_client = self.builder.add_vault_config(VALID_VAULT_CONFIG).build() + invalid_config = VALID_VAULT_CONFIG.copy() + invalid_config["vault_id"] = "non_existent_vault_id" + with self.assertRaises(SkyflowError) as context: + skyflow_client.update_vault_config(invalid_config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.VAULT_ID_NOT_IN_CONFIG_LIST.value.format("non_existent_vault_id"), + ) + + @patch("skyflow.vault.client.client.VaultClient.update_config") + def test_update_connection_config_with_invalid_connection_id_raises_error(self, _mock): + skyflow_client = self.builder.add_connection_config(VALID_CONNECTION_CONFIG).build() + invalid_config = VALID_CONNECTION_CONFIG.copy() + invalid_config["connection_id"] = "non_existent_connection_id" + with self.assertRaises(SkyflowError) as context: + skyflow_client.update_connection_config(invalid_config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.CONNECTION_ID_NOT_IN_CONFIG_LIST.value.format("non_existent_connection_id"), + ) + + +class TestVaultClient(unittest.TestCase): + def _make_client(self): + client = VaultClient({"vault_id": "test_vault"}) + client._VaultClient__api_client = Mock() + return client + + def test_get_detect_text_api_returns_strings(self): + client = self._make_client() + result = client.get_detect_text_api() + self.assertEqual(result, client._VaultClient__api_client.strings) + + def test_get_detect_file_api_returns_files(self): + client = self._make_client() + result = client.get_detect_file_api() + self.assertEqual(result, client._VaultClient__api_client.files) + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.is_expired", return_value=True) + def test_get_bearer_token_passes_token_uri_option(self, _mock_expired, mock_gen): + mock_gen.return_value = ("test_token", "bearer") + client = VaultClient({"vault_id": "test_vault"}) + credentials = { + "credentials_string": '{"clientID":"id","privateKey":"pk","keyID":"kid","tokenURI":"https://token.uri"}', + "token_uri": "https://custom-token-uri.com/token", + } + client.get_bearer_token(credentials) + options_passed = mock_gen.call_args[0][1] + self.assertIn("token_uri", options_passed) + self.assertEqual(options_passed["token_uri"], "https://custom-token-uri.com/token") + + +class TestUpdateLogLevelDeprecation(unittest.TestCase): + def _build_client(self): + return Skyflow.builder().add_vault_config(VALID_VAULT_CONFIG).build() + + def test_update_log_level_emits_deprecation_warning(self): + client = self._build_client() + with patch('skyflow.client.skyflow.log_warn') as mock_warn: + client.update_log_level(LogLevel.INFO) + mock_warn.assert_called_once() + self.assertIn("set_log_level", mock_warn.call_args[0][0]) + + def test_update_log_level_delegates_to_set_log_level(self): + client = self._build_client() + client.update_log_level(LogLevel.INFO) + self.assertEqual(client.get_log_level(), LogLevel.INFO) + + +class TestFileUploadRequestDeprecation(unittest.TestCase): + def test_keyword_args_no_warning(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest( + table="table", + column_name="col", + skyflow_id="sky123", + ) + mock_warn.assert_not_called() + self.assertEqual(req.table, "table") + self.assertEqual(req.column_name, "col") + self.assertEqual(req.skyflow_id, "sky123") + + def test_only_table_positional_no_warning(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest("table", column_name="col", skyflow_id="sky123") + mock_warn.assert_not_called() + self.assertEqual(req.table, "table") + + def test_old_positional_order_emits_deprecation_warning(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest("table", "sky123", "col") + mock_warn.assert_called_once() + self.assertIn("FileUploadRequest", mock_warn.call_args[0][0]) + + def test_old_positional_order_remaps_args_correctly(self): + req = FileUploadRequest("table", "sky123", "col") + self.assertEqual(req.skyflow_id, "sky123") + self.assertEqual(req.column_name, "col") + + def test_single_positional_arg_emits_warning_and_sets_skyflow_id(self): + with patch('skyflow.vault.data._file_upload_request.log_warn') as mock_warn: + req = FileUploadRequest("table", "sky123") + mock_warn.assert_called_once() + self.assertEqual(req.skyflow_id, "sky123") + self.assertIsNone(req.column_name) + + def test_optional_fields_default_to_none(self): + req = FileUploadRequest(table="table") + self.assertIsNone(req.skyflow_id) + self.assertIsNone(req.column_name) + self.assertIsNone(req.file_path) + self.assertIsNone(req.base64) + self.assertIsNone(req.file_object) + self.assertIsNone(req.file_name) diff --git a/tests/service_account/test__utils.py b/tests/service_account/test__utils.py index 73747d69..856d26bb 100644 --- a/tests/service_account/test__utils.py +++ b/tests/service_account/test__utils.py @@ -5,35 +5,57 @@ from unittest.mock import patch import os from skyflow.error import SkyflowError -from skyflow.service_account import is_expired, generate_bearer_token, \ - generate_bearer_token_from_creds +from skyflow.service_account import is_expired, generate_bearer_token, generate_bearer_token_from_creds from skyflow.utils import SkyflowMessages -from skyflow.service_account._utils import get_service_account_token, get_signed_jwt, generate_signed_data_tokens, get_signed_data_token_response_object, generate_signed_data_tokens_from_creds, _validate_and_resolve_ctx +from skyflow.service_account._utils import ( + get_service_account_token, + get_signed_jwt, + generate_signed_data_tokens, + get_signed_data_token_response_object, + generate_signed_data_tokens_from_creds, + _validate_and_resolve_ctx, + _normalize_credentials, + get_signed_tokens, +) creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) VALID_CREDENTIALS_STRING = json.dumps(credentials) -CREDENTIALS_WITHOUT_CLIENT_ID = { - 'privateKey': 'private_key' -} +CREDENTIALS_WITHOUT_CLIENT_ID = {"privateKey": "private_key"} -CREDENTIALS_WITHOUT_KEY_ID = { - 'privateKey': 'private_key', - 'clientID': 'client_id' -} +CREDENTIALS_WITHOUT_KEY_ID = {"privateKey": "private_key", "clientID": "client_id"} -CREDENTIALS_WITHOUT_TOKEN_URI = { - 'privateKey': 'private_key', - 'clientID': 'client_id', - 'keyID': 'key_id' -} +CREDENTIALS_WITHOUT_TOKEN_URI = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id"} VALID_SERVICE_ACCOUNT_CREDS = credentials +# Snake-case version of the real credentials (keys remapped to snake_case) +SNAKE_CASE_CREDS = { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], +} + +SNAKE_CASE_CREDS_STRING = json.dumps( + { + "private_key": credentials["privateKey"], + "client_id": credentials["clientID"], + "key_id": credentials["keyID"], + "token_uri": credentials["tokenURI"], + } +) + + class TestServiceAccountUtils(unittest.TestCase): + # ── is_expired ──────────────────────────────────────────────────────────── + + def test_is_expired_none_token(self): + self.assertTrue(is_expired(None)) + def test_is_expired_empty_token(self): self.assertTrue(is_expired("")) @@ -44,7 +66,7 @@ def test_is_expired_non_expired_token(self): def test_is_expired_expired_token(self): past_time = time.time() - 1000 - token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") + token = jwt.encode({"exp": past_time}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) @patch("skyflow.utils.logger._log_helpers.log_error_log") @@ -53,6 +75,8 @@ def test_is_expired_general_exception(self, mock_jwt_decode, mock_log_error): token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") self.assertTrue(is_expired(token)) + # ── generate_bearer_token ───────────────────────────────────────────────── + @patch("builtins.open", side_effect=FileNotFoundError) def test_generate_bearer_token_invalid_file_path(self, mock_open): with self.assertRaises(SkyflowError) as context: @@ -72,6 +96,8 @@ def test_generate_bearer_token_valid_file_path(self, mock_generate_bearer_token) generate_bearer_token(creds_path) mock_generate_bearer_token.assert_called_once() + # ── generate_bearer_token_from_creds ────────────────────────────────────── + @patch("skyflow.service_account._utils.get_service_account_token") def test_generate_bearer_token_from_creds_with_valid_json_string(self, mock_generate_bearer_token): generate_bearer_token_from_creds(VALID_CREDENTIALS_STRING) @@ -82,10 +108,11 @@ def test_generate_bearer_token_from_creds_invalid_json(self): generate_bearer_token_from_creds("invalid_json") self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) + # ── get_service_account_token ───────────────────────────────────────────── + def test_get_service_account_token_missing_private_key(self): - incomplete_credentials = {} with self.assertRaises(SkyflowError) as context: - get_service_account_token(incomplete_credentials, {}, None) + get_service_account_token({}, {}, None) self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) def test_get_service_account_token_missing_client_id_key(self): @@ -107,6 +134,102 @@ def test_get_service_account_token_with_valid_credentials(self): access_token, _ = get_service_account_token(VALID_SERVICE_ACCOUNT_CREDS, {}, None) self.assertTrue(access_token) + def test_get_service_account_token_with_snake_case_creds(self): + access_token, _ = get_service_account_token(SNAKE_CASE_CREDS, {}, None) + self.assertTrue(access_token) + + def test_get_service_account_token_missing_private_key_snake(self): + creds = { + "client_id": "id", + "key_id": "kid", + "token_uri": "https://example.com", + } + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_PRIVATE_KEY.value) + + def test_get_service_account_token_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_service_account_token_invalid_token_uri_in_options(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": "not-a-valid-url"} + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, options, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_role_ids_formats_scope(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"role_ids": ["role1", "role2"]} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + access_token, token_type = get_service_account_token(creds, options, None) + self.assertEqual(access_token, "token") + self.assertEqual(token_type, "bearer") + args, kwargs = mock_auth_api.authentication_service_get_auth_token.call_args + self.assertIn("scope", kwargs) + self.assertEqual(kwargs["scope"], "role:role1 role:role2") + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_unauthorized_error(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + from skyflow.generated.rest.errors.unauthorized_error import UnauthorizedError + + mock_auth_api.authentication_service_get_auth_token.side_effect = UnauthorizedError("unauthorized") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.UNAUTHORIZED_ERROR_IN_GETTING_BEARER_TOKEN.value + ) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_generic_exception(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.side_effect = Exception("some error") + with self.assertRaises(SkyflowError) as context: + get_service_account_token(creds, {}, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.FAILED_TO_GET_BEARER_TOKEN.value) + + # ── get_signed_jwt ──────────────────────────────────────────────────────── @patch("jwt.encode", side_effect=Exception) def test_get_signed_jwt_invalid_format(self, mock_jwt_encode): @@ -135,25 +258,157 @@ def test_get_signed_jwt_with_empty_string_ctx_not_added(self, mock_jwt_encode): payload = mock_jwt_encode.call_args.kwargs["payload"] self.assertNotIn("ctx", payload) + # ── get_signed_data_token_response_object ───────────────────────────────── + def test_get_signed_data_token_response_object(self): token = "sample_token" signed_token = "signed_sample_token" response = get_signed_data_token_response_object(signed_token, token) + self.assertIsInstance(response, tuple) self.assertEqual(response[0], token) self.assertEqual(response[1], signed_token) + # ── get_signed_tokens ───────────────────────────────────────────────────── + + @patch("jwt.encode", side_effect=Exception("jwt error")) + def test_get_signed_tokens_jwt_encode_exception(self, mock_jwt_encode): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"data_tokens": ["token1"]} + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) + + def test_get_signed_tokens_returns_list_one_per_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + def test_get_signed_tokens_items_are_tuples_with_token_and_signed_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1", "token2"]}) + for item in result: + self.assertIsInstance(item, tuple) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[1][0], "token2") + self.assertTrue(result[0][1].startswith("signed_token_")) + self.assertTrue(result[1][1].startswith("signed_token_")) + + def test_get_signed_tokens_returns_list_single_token(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": ["token1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + def test_get_signed_tokens_empty_data_tokens_returns_empty_list(self): + result = generate_signed_data_tokens(creds_path, {"data_tokens": []}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_string_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": "my_ctx"}) + call_args = mock_jwt_encode.call_args + claims = call_args[0][0] if call_args[0] else call_args.kwargs.get("args", [None])[0] + # jwt.encode(claims, key, algorithm=...) — first positional arg is claims + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], "my_ctx") + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_dict_ctx_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + ctx_dict = {"role": "admin", "dept": "eng"} + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ctx_dict}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertEqual(claims_arg["ctx"], ctx_dict) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_empty_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": ""}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + @patch("skyflow.service_account._utils.jwt.encode") + def test_get_signed_tokens_with_none_ctx_not_in_claims(self, mock_jwt_encode): + mock_jwt_encode.return_value = "signed" + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "https://valid-url.com", + } + get_signed_tokens(creds, {"data_tokens": ["tok1"], "ctx": None}) + claims_arg = mock_jwt_encode.call_args[0][0] + self.assertNotIn("ctx", claims_arg) + + def test_get_signed_tokens_invalid_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + "tokenURI": "not-a-url", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_missing_token_uri(self): + creds = { + "privateKey": "key", + "clientID": "id", + "keyID": "kid", + } + with self.assertRaises(SkyflowError) as context: + get_signed_tokens(creds, {"data_tokens": ["tok1"]}) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_get_signed_tokens_with_snake_case_creds(self): + result = get_signed_tokens(SNAKE_CASE_CREDS, {"data_tokens": ["token1", "token2"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + + # ── generate_signed_data_tokens (file path) ─────────────────────────────── + def test_generate_signed_data_tokens_from_file_path(self): - creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") - options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'} + options = {"data_tokens": ["token1", "token2"], "ctx": "ctx"} result = generate_signed_data_tokens(creds_path, options) self.assertEqual(len(result), 2) def test_generate_signed_data_tokens_from_invalid_file_path(self): options = {"data_tokens": ["token1", "token2"]} with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens('credentials1.json', options) + generate_signed_data_tokens("credentials1.json", options) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value) + def test_generate_signed_data_tokens_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "department": "finance"}} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 1) + + # ── generate_signed_data_tokens_from_creds (string) ────────────────────── + def test_generate_signed_data_tokens_from_creds(self): options = {"data_tokens": ["token1", "token2"]} result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) @@ -161,22 +416,95 @@ def test_generate_signed_data_tokens_from_creds(self): def test_generate_signed_data_tokens_from_creds_with_invalid_string(self): options = {"data_tokens": ["token1", "token2"]} - credentials_string = '{' with self.assertRaises(SkyflowError) as context: - result = generate_signed_data_tokens_from_creds(credentials_string, options) + generate_signed_data_tokens_from_creds("{", options) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value) - # ctx JSON object support tests + def test_generate_signed_data_tokens_from_creds_with_dict_ctx(self): + options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "level": 3}} + result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) + self.assertEqual(len(result), 1) + + # ── snake_case end-to-end ───────────────────────────────────────────────── + + def test_generate_signed_data_tokens_with_snake_creds_file(self): + """generate_signed_data_tokens reads the file (camelCase) but the normalize fn is a no-op for camelCase.""" + options = {"data_tokens": ["token1", "token2"]} + result = generate_signed_data_tokens(creds_path, options) + self.assertEqual(len(result), 2) + + def test_generate_signed_data_tokens_from_creds_snake(self): + result = generate_signed_data_tokens_from_creds(SNAKE_CASE_CREDS_STRING, options={"data_tokens": ["t1"]}) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + + # ── _normalize_credentials ──────────────────────────────────────────────── + + def test_normalize_credentials_snake_case(self): + snake = { + "private_key": "pk", + "client_id": "cid", + "key_id": "kid", + "token_uri": "https://uri", + "client_name": "name", + } + result = _normalize_credentials(snake) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertEqual(result["clientName"], "name") + self.assertNotIn("private_key", result) + self.assertNotIn("client_id", result) + self.assertNotIn("key_id", result) + self.assertNotIn("token_uri", result) + self.assertNotIn("client_name", result) + + def test_normalize_credentials_camel_case_unchanged(self): + camel = { + "privateKey": "pk", + "clientID": "cid", + "keyID": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(camel) + self.assertEqual(result, camel) + + def test_normalize_credentials_mixed_keys(self): + mixed = { + "private_key": "pk", + "clientID": "cid", + "key_id": "kid", + "tokenURI": "https://uri", + } + result = _normalize_credentials(mixed) + self.assertEqual(result["privateKey"], "pk") + self.assertEqual(result["clientID"], "cid") + self.assertEqual(result["keyID"], "kid") + self.assertEqual(result["tokenURI"], "https://uri") + self.assertNotIn("private_key", result) + self.assertNotIn("key_id", result) + + def test_normalize_credentials_unknown_key_passes_through(self): + creds = {"unknown_field": "value", "anotherField": "val2"} + result = _normalize_credentials(creds) + self.assertEqual(result["unknown_field"], "value") + self.assertEqual(result["anotherField"], "val2") + + def test_normalize_credentials_empty_dict(self): + self.assertEqual(_normalize_credentials({}), {}) + + # ── _validate_and_resolve_ctx ───────────────────────────────────────────── def test_validate_and_resolve_ctx_none(self): self.assertIsNone(_validate_and_resolve_ctx(None)) def test_validate_and_resolve_ctx_empty_string(self): - self.assertIsNone(_validate_and_resolve_ctx('')) - self.assertIsNone(_validate_and_resolve_ctx(' ')) + self.assertIsNone(_validate_and_resolve_ctx("")) + self.assertIsNone(_validate_and_resolve_ctx(" ")) def test_validate_and_resolve_ctx_valid_string(self): - self.assertEqual(_validate_and_resolve_ctx('user_12345'), 'user_12345') + self.assertEqual(_validate_and_resolve_ctx("user_12345"), "user_12345") def test_validate_and_resolve_ctx_empty_dict(self): self.assertIsNone(_validate_and_resolve_ctx({})) @@ -190,19 +518,16 @@ def test_validate_and_resolve_ctx_dict_with_alphanumeric_keys(self): self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) def test_validate_and_resolve_ctx_dict_with_invalid_key_hyphen(self): - ctx = {"valid_key": "value", "invalid-key": "value"} with self.assertRaises(SkyflowError): - _validate_and_resolve_ctx(ctx) + _validate_and_resolve_ctx({"valid_key": "value", "invalid-key": "value"}) def test_validate_and_resolve_ctx_dict_with_invalid_key_space(self): - ctx = {"invalid key": "value"} with self.assertRaises(SkyflowError): - _validate_and_resolve_ctx(ctx) + _validate_and_resolve_ctx({"invalid key": "value"}) def test_validate_and_resolve_ctx_dict_with_invalid_key_dot(self): - ctx = {"invalid.key": "value"} with self.assertRaises(SkyflowError): - _validate_and_resolve_ctx(ctx) + _validate_and_resolve_ctx({"invalid.key": "value"}) def test_validate_and_resolve_ctx_valid_type_int(self): self.assertEqual(_validate_and_resolve_ctx(42), 42) @@ -228,13 +553,40 @@ def test_validate_and_resolve_ctx_dict_with_nested_objects(self): ctx = {"role": "admin", "metadata": {"level": 2, "tags": ["a", "b"]}} self.assertEqual(_validate_and_resolve_ctx(ctx), ctx) - def test_generate_signed_data_tokens_with_dict_ctx(self): - creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") - options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "department": "finance"}} - result = generate_signed_data_tokens(creds_path, options) - self.assertEqual(len(result), 2) + # ── additional coverage gaps ────────────────────────────────────────────── - def test_generate_signed_data_tokens_from_creds_with_dict_ctx(self): - options = {"data_tokens": ["token1"], "ctx": {"role": "admin", "level": 3}} - result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options) - self.assertEqual(len(result), 2) \ No newline at end of file + @patch("skyflow.service_account._utils.jwt.decode", side_effect=jwt.ExpiredSignatureError) + def test_is_expired_expired_signature_error(self, mock_decode): + token = jwt.encode({"exp": time.time() + 1000}, key="test", algorithm="HS256") + self.assertTrue(is_expired(token)) + + @patch("skyflow.service_account._utils.AuthClient") + @patch("skyflow.service_account._utils.get_signed_jwt") + def test_get_service_account_token_with_token_uri_option_override(self, mock_get_signed_jwt, mock_auth_client): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + override_uri = "https://override-url.com" + options = {"token_uri": override_uri} + mock_get_signed_jwt.return_value = "signed" + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + get_service_account_token(creds, options, None) + mock_get_signed_jwt.assert_called_once() + call_args = mock_get_signed_jwt.call_args + self.assertEqual(call_args[0][3], override_uri) + + @patch("json.load", side_effect=json.JSONDecodeError("bad json", "", 0)) + def test_generate_signed_data_tokens_from_file_invalid_json(self, mock_load): + invalid_path = os.path.join(os.path.dirname(__file__), "invalid_creds.json") + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(invalid_path, {"data_tokens": ["t1"]}) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.FILE_INVALID_JSON.value.format(invalid_path), + ) \ No newline at end of file diff --git a/tests/utils/test__helpers.py b/tests/utils/test__helpers.py index 8b55abf3..6016c798 100644 --- a/tests/utils/test__helpers.py +++ b/tests/utils/test__helpers.py @@ -1,5 +1,5 @@ import unittest -from skyflow.utils import get_base_url, format_scope +from skyflow.utils import get_base_url, format_scope, is_valid_url VALID_URL = "https://example.com/path?query=1" BASE_URL = "https://example.com" @@ -35,4 +35,28 @@ def test_format_scope_single_scope(self): def test_format_scope_special_characters(self): scopes_with_special_chars = ["admin", "user:write", "read-only"] expected_result = "role:admin role:user:write role:read-only" - self.assertEqual(format_scope(scopes_with_special_chars), expected_result) \ No newline at end of file + self.assertEqual(format_scope(scopes_with_special_chars), expected_result) + + def test_is_valid_url_valid(self): + self.assertTrue(is_valid_url("https://example.com")) + self.assertTrue(is_valid_url("https://example.com/path")) + + def test_is_valid_url_invalid(self): + self.assertFalse(is_valid_url("http://example.com")) + self.assertFalse(is_valid_url("ftp://example.com")) + self.assertFalse(is_valid_url("example.com")) + self.assertFalse(is_valid_url("invalid-url")) + self.assertFalse(is_valid_url("")) + + def test_is_valid_url_none(self): + self.assertFalse(is_valid_url(None)) + + def test_is_valid_url_no_scheme(self): + self.assertFalse(is_valid_url("www.example.com")) + + def test_is_valid_url_exception(self): + class BadStr: + def __str__(self): + raise Exception("bad str") + + self.assertFalse(is_valid_url(BadStr())) \ No newline at end of file diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6eaacf47..1363ad7d 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1,38 +1,65 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock, PropertyMock import os -import json -from unittest.mock import MagicMock from urllib.parse import quote +import tempfile, json from requests import PreparedRequest from requests.models import HTTPError from skyflow.error import SkyflowError from skyflow.generated.rest import ErrorResponse -from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \ - parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \ - parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \ - handle_exception, validate_api_key, encode_column_values, parse_deidentify_text_response, \ - parse_reidentify_text_response, convert_detected_entity_to_entity_info -from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error +from skyflow.service_account import ( + generate_bearer_token, + generate_signed_data_tokens, + generate_signed_data_tokens_from_creds, + generate_bearer_token_from_creds, +) +from skyflow.utils import ( + get_credentials, + SkyflowMessages, + get_vault_url, + construct_invoke_connection_request, + parse_insert_response, + parse_update_record_response, + parse_delete_response, + parse_get_response, + parse_detokenize_response, + parse_tokenize_response, + parse_query_response, + parse_invoke_connection_response, + handle_exception, + validate_api_key, + encode_column_values, + parse_deidentify_text_response, + parse_reidentify_text_response, + convert_detected_entity_to_entity_info, +) +from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics, handle_json_error, r_urlencode from skyflow.utils.enums import EnvUrls, Env, ContentType from skyflow.vault.connection import InvokeConnectionResponse from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse from skyflow.vault.tokens import DetokenizeResponse, TokenizeResponse creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json") -with open(creds_path, 'r') as file: +with open(creds_path, "r") as file: credentials = json.load(file) TEST_ERROR_MESSAGE = "Test error message." VALID_ENV_CREDENTIALS = credentials -class TestUtils(unittest.TestCase): +class TestUtils(unittest.TestCase): @patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": json.dumps(VALID_ENV_CREDENTIALS)}) def test_get_credentials_env_variable(self): credentials = get_credentials() - credentials_string = credentials.get('credentials_string') - self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n')) + credentials_string = credentials.get("credentials_string") + self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace("\n", "\\n")) + + @patch("skyflow.utils._utils.dotenv.find_dotenv", return_value=None) + @patch.dict(os.environ, {}, clear=True) + def test_get_credentials_no_credentials_raises(self, mock_find_dotenv): + with self.assertRaises(SkyflowError) as context: + get_credentials(config_level_creds=None, common_skyflow_creds=None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) def test_get_credentials_with_config_level_creds(self): test_creds = {"authToken": "test_token"} @@ -58,11 +85,13 @@ def test_get_vault_url_with_invalid_cluster_id(self): valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id)) + self.assertEqual( + context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(valid_vault_id) + ) def test_get_vault_url_with_invalid_env(self): valid_cluster_id = "cluster_id" - valid_env =EnvUrls.DEV + valid_env = EnvUrls.DEV valid_vault_id = "vault123" with self.assertRaises(SkyflowError) as context: url = get_vault_url(valid_cluster_id, valid_env, valid_vault_id) @@ -77,7 +106,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): "http_code": 400, "http_status": "Bad Request", "grpc_code": 3, - "details": ["detail1"] + "details": ["detail1"], } } @@ -88,13 +117,7 @@ def test_handle_json_error_with_dict_data(self, mock_log_and_reject_error): handle_json_error(mock_error, error_dict, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "Dict error message", - 400, - request_id, - "Bad Request", - 3, - ["detail1"], - logger=mock_logger + "Dict error message", 400, request_id, "Bad Request", 3, ["detail1"], logger=mock_logger ) @patch("skyflow.utils._utils.log_and_reject_error") @@ -107,7 +130,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ "http_code": 403, "http_status": "Forbidden", "grpc_code": 7, - "details": ["detail2"] + "details": ["detail2"], } } @@ -118,13 +141,7 @@ def test_handle_json_error_with_error_response_object(self, mock_log_and_reject_ handle_json_error(mock_error, mock_error_response, request_id, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "ErrorResponse message", - 403, - request_id, - "Forbidden", - 7, - ["detail2"], - logger=mock_logger + "ErrorResponse message", 403, request_id, "Forbidden", 7, ["detail2"], logger=mock_logger ) def test_parse_path_params(self): @@ -138,13 +155,56 @@ def test_to_lowercase_keys(self): expected_output = {"key1": "value1", "key2": "value2"} self.assertEqual(to_lowercase_keys(input_dict), expected_output) + def test_r_urlencode_with_list_input(self): + pairs = {} + r_urlencode([], pairs, ["a", "b"]) + self.assertIn("[0]", pairs) + self.assertIn("[1]", pairs) + self.assertEqual(pairs["[0]"], "a") + self.assertEqual(pairs["[1]"], "b") + + def test_r_urlencode_with_tuple_input(self): + pairs = {} + r_urlencode([], pairs, ("x", "y")) + self.assertIn("[0]", pairs) + self.assertEqual(pairs["[0]"], "x") + def test_get_metrics(self): metrics = get_metrics() - self.assertIn('sdk_name_version', metrics) - self.assertIn('sdk_client_device_model', metrics) - self.assertIn('sdk_client_os_details', metrics) - self.assertIn('sdk_runtime_details', metrics) + self.assertIn("sdk_name_version", metrics) + self.assertIn("sdk_client_device_model", metrics) + self.assertIn("sdk_client_os_details", metrics) + self.assertIn("sdk_runtime_details", metrics) + + def test_get_metrics_platform_node_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + with patch("skyflow.utils._utils.platform") as mock_platform: + mock_platform.node.side_effect = OSError("no node") + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_device_model"], "") + utils_module._CACHED_METRICS.clear() + + def test_get_metrics_sys_attribute_exception(self): + import skyflow.utils._utils as utils_module + + utils_module._CACHED_METRICS.clear() + + class _RaisingSys: + @property + def platform(self): + raise RuntimeError("no platform") + + @property + def version(self): + raise RuntimeError("no version") + with patch("skyflow.utils._utils.sys", _RaisingSys()): + metrics = utils_module.get_metrics() + self.assertEqual(metrics["sdk_client_os_details"], "") + self.assertIn("sdk_runtime_details", metrics) + utils_module._CACHED_METRICS.clear() def test_construct_invoke_connection_request_valid(self): mock_connection_request = Mock() @@ -164,7 +224,7 @@ def test_construct_invoke_connection_request_valid(self): self.assertEqual(result.url, expected_url) self.assertEqual(result.method, "POST") - self.assertEqual(result.headers['Content-Type'], ContentType.JSON.value) + self.assertEqual(result.headers["Content-Type"], ContentType.JSON.value) self.assertEqual(result.body, json.dumps(mock_connection_request.body)) @@ -230,9 +290,7 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): mock_connection_request = Mock() mock_connection_request.path_params = {"param1": "value1"} mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} - mock_connection_request.body = { - "name": (None, "John Doe") - } + mock_connection_request.body = {"name": (None, "John Doe")} mock_connection_request.method.value = "POST" mock_connection_request.query_params = {"query": "test"} @@ -242,13 +300,27 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self): self.assertIsInstance(result, PreparedRequest) + def test_parse_insert_response_with_tokens_continue_on_error(self): + api_response = Mock() + api_response.headers = {"x-request-id": "req-1"} + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1", "tokens": {"col1": "tok1"}}]}}, + ] + ) + result = parse_insert_response(api_response, continue_on_error=True) + self.assertEqual(result.inserted_fields[0]["col1"], "tok1") + self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1") + def test_parse_insert_response(self): api_response = Mock() api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - api_response.data = Mock(responses=[ - {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, - {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}} - ]) + api_response.data = Mock( + responses=[ + {"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}}, + {"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}}, + ] + ) result = parse_insert_response(api_response, continue_on_error=True) self.assertEqual(len(result.inserted_fields), 1) self.assertEqual(len(result.errors), 1) @@ -262,17 +334,19 @@ def test_parse_insert_response(self): def test_parse_insert_response_continue_on_error_false(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), - Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}) - ]) + mock_api_response.data = Mock( + records=[ + Mock(skyflow_id="id_1", tokens={"token1": "token_value1"}), + Mock(skyflow_id="id_2", tokens={"token2": "token_value2"}), + ] + ) result = parse_insert_response(mock_api_response, continue_on_error=False) self.assertIsInstance(result, InsertResponse) expected_inserted_fields = [ {"skyflow_id": "id_1", "token1": "token_value1"}, - {"skyflow_id": "id_2", "token2": "token_value2"} + {"skyflow_id": "id_2", "token2": "token_value2"}, ] self.assertEqual(result.inserted_fields, expected_inserted_fields) @@ -283,8 +357,8 @@ def test_parse_update_record_response(self): api_response.skyflow_id = "id1" api_response.tokens = {"token1": "value1"} result = parse_update_record_response(api_response) - self.assertEqual(result.updated_field['skyflow_id'], "id1") - self.assertEqual(result.updated_field['token1'], "value1") + self.assertEqual(result.updated_field["skyflow_id"], "id1") + self.assertEqual(result.updated_field["token1"], "value1") def test_parse_delete_response_successful(self): mock_api_response = Mock() @@ -302,42 +376,39 @@ def test_parse_delete_response_successful(self): def test_parse_get_response_successful(self): mock_api_response = Mock() mock_api_response.records = [ - Mock(fields={'field1': 'value1', 'field2': 'value2'}), - Mock(fields={'field1': 'value3', 'field2': 'value4'}) + Mock(fields={"field1": "value1", "field2": "value2"}), + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_get_response(mock_api_response) self.assertIsInstance(result, GetResponse) - expected_data = [ - {'field1': 'value1', 'field2': 'value2'}, - {'field1': 'value3', 'field2': 'value4'} - ] + expected_data = [{"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"}] self.assertEqual(result.data, expected_data) - # self.assertEqual(result.errors, None) + self.assertIsNone(result.errors) def test_parse_detokenize_response_with_mixed_records(self): mock_api_response = Mock() mock_api_response.headers = {"x-request-id": "12345", "content-type": "application/json"} - mock_api_response.data = Mock(records=[ - Mock(token="token1", value="value1", value_type="Type1", error=None), - Mock(token="token2", value=None, value_type=None, error="Some error"), - Mock(token="token3", value="value3", value_type="Type2", error=None), - ]) + mock_api_response.data = Mock( + records=[ + Mock(token="token1", value="value1", value_type="Type1", error=None), + Mock(token="token2", value=None, value_type=None, error="Some error"), + Mock(token="token3", value="value3", value_type="Type2", error=None), + ] + ) result = parse_detokenize_response(mock_api_response) self.assertIsInstance(result, DetokenizeResponse) expected_detokenized_fields = [ {"token": "token1", "value": "value1", "type": "Type1"}, - {"token": "token3", "value": "value3", "type": "Type2"} + {"token": "token3", "value": "value3", "type": "Type2"}, ] - expected_errors = [ - {"token": "token2", "error": "Some error", "request_id": "12345"} - ] + expected_errors = [{"token": "token2", "error": "Some error", "request_id": "12345"}] self.assertEqual(result.detokenized_fields, expected_detokenized_fields) self.assertEqual(result.errors, expected_errors) @@ -353,11 +424,7 @@ def test_parse_tokenize_response_with_valid_records(self): result = parse_tokenize_response(mock_api_response) self.assertIsInstance(result, TokenizeResponse) - expected_tokenized_fields = [ - {"token": "token1"}, - {"token": "token2"}, - {"token": "token3"} - ] + expected_tokenized_fields = [{"token": "token1"}, {"token": "token2"}, {"token": "token3"}] self.assertEqual(result.tokenized_fields, expected_tokenized_fields) @@ -365,7 +432,7 @@ def test_parse_query_response_with_valid_records(self): mock_api_response = Mock() mock_api_response.records = [ Mock(fields={"field1": "value1", "field2": "value2"}), - Mock(fields={"field1": "value3", "field2": "value4"}) + Mock(fields={"field1": "value3", "field2": "value4"}), ] result = parse_query_response(mock_api_response) @@ -374,7 +441,7 @@ def test_parse_query_response_with_valid_records(self): expected_fields = [ {"field1": "value1", "field2": "value2", "tokenized_data": {}}, - {"field1": "value3", "field2": "value4", "tokenized_data": {}} + {"field1": "value3", "field2": "value4", "tokenized_data": {}}, ] self.assertEqual(result.fields, expected_fields) @@ -382,7 +449,7 @@ def test_parse_query_response_with_valid_records(self): @patch("requests.Response") def test_parse_invoke_connection_response_successful(self, mock_response): mock_response.status_code = 200 - mock_response.content = json.dumps({"key": "value"}).encode('utf-8') + mock_response.content = json.dumps({"key": "value"}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} result = parse_invoke_connection_response(mock_response) @@ -394,19 +461,23 @@ def test_parse_invoke_connection_response_successful(self, mock_response): @patch("requests.Response") def test_parse_invoke_connection_response_json_decode_error(self, mock_response): - + """Test that non-JSON content in successful response is returned as string.""" mock_response.status_code = 200 - mock_response.content = "Non-JSON Content".encode('utf-8') + mock_response.content = "Non-JSON Content".encode("utf-8") + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() - with self.assertRaises(SkyflowError) as context: - parse_invoke_connection_response(mock_response) + result = parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Non-JSON Content")) + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Non-JSON Content") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) @patch("requests.Response") def test_parse_invoke_connection_response_http_error_with_json_error_message(self, mock_response): mock_response.status_code = 404 - mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode('utf-8') + mock_response.content = json.dumps({"error": {"message": "Not Found"}}).encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("404 Error") @@ -417,10 +488,38 @@ def test_parse_invoke_connection_response_http_error_with_json_error_message(sel self.assertEqual(context.exception.message, "Not Found") self.assertEqual(context.exception.request_id, "1234") + @patch("requests.Response") + def test_parse_invoke_connection_response_with_error_from_client_header(self, mock_response): + from requests.models import HTTPError + + mock_response.status_code = 400 + mock_response.content = json.dumps( + { + "error": { + "message": "Client error", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": 3, + "details": None, + } + } + ).encode("utf-8") + mock_response.headers = { + "x-request-id": "rid-1", + "error-from-client": "true", + } + mock_response.raise_for_status.side_effect = HTTPError("400") + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + err = context.exception + self.assertEqual(err.message, "Client error") + self.assertIsNotNone(err.details) + self.assertTrue(any(d.get("error_from_client") is True for d in err.details)) + @patch("requests.Response") def test_parse_invoke_connection_response_http_error_without_json_error_message(self, mock_response): mock_response.status_code = 500 - mock_response.content = "Internal Server Error".encode('utf-8') + mock_response.content = "Internal Server Error".encode("utf-8") mock_response.headers = {"x-request-id": "1234"} mock_response.raise_for_status.side_effect = HTTPError("500 Error") @@ -428,37 +527,32 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Internal Server Error")) + self.assertEqual(context.exception.message, SkyflowMessages.Error.API_ERROR.value.format(500)) + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_json_error(self, mock_log_and_reject_error): mock_error = Mock() - mock_error.headers = { - 'x-request-id': '1234', - 'content-type': 'application/json' - } - mock_error.body = json.dumps({ - "error": { - "message": "JSON error occurred.", - "http_code": 400, - "http_status": "Bad Request", - "grpc_code": "8", - "details": "Detailed message" + mock_error.headers = {"x-request-id": "1234", "content-type": "application/json"} + mock_error.body = json.dumps( + { + "error": { + "message": "JSON error occurred.", + "http_code": 400, + "http_status": "Bad Request", + "grpc_code": "8", + "details": "Detailed message", + } } - }).encode('utf-8') + ).encode("utf-8") mock_logger = Mock() handle_exception(mock_error, mock_logger) mock_log_and_reject_error.assert_called_once_with( - "JSON error occurred.", - 400, - "1234", - "Bad Request", - "8", - "Detailed message", - logger=mock_logger + "JSON error occurred.", 400, "1234", "Bad Request", "8", "Detailed message", logger=mock_logger ) def test_validate_api_key_valid_key(self): @@ -494,12 +588,7 @@ def test_parse_deidentify_text_response(self): mock_entity.value = "sensitive_value" mock_entity.entity_type = "EMAIL" mock_entity.entity_scores = {"EMAIL": 0.95} - mock_entity.location = Mock( - start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 - ) + mock_entity.location = Mock(start_index=10, end_index=20, start_index_processed=15, end_index_processed=25) mock_api_response = Mock() mock_api_response.processed_text = "Sample processed text" @@ -556,10 +645,7 @@ def test__convert_detected_entity_to_entity_info(self): mock_detected_entity.entity_type = "EMAIL" mock_detected_entity.entity_scores = {"EMAIL": 0.95} mock_detected_entity.location = Mock( - start_index=10, - end_index=20, - start_index_processed=15, - end_index_processed=25 + start_index=10, end_index=20, start_index_processed=15, end_index_processed=25 ) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -580,12 +666,7 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): mock_detected_entity.value = None mock_detected_entity.entity_type = "UNKNOWN" mock_detected_entity.entity_scores = {} - mock_detected_entity.location = Mock( - start_index=0, - end_index=0, - start_index_processed=0, - end_index_processed=0 - ) + mock_detected_entity.location = Mock(start_index=0, end_index=0, start_index_processed=0, end_index_processed=0) result = convert_detected_entity_to_entity_info(mock_detected_entity) @@ -597,3 +678,925 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.text_index.end, 0) self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_connect_error(self, mock_log_and_reject_error): + """Test handling httpx.ConnectError.""" + import httpx + + mock_error = httpx.ConnectError("Connection refused") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Connection refused", SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_headers_attribute(self, mock_log_and_reject_error): + """Test handling error without headers attribute.""" + mock_error = Exception("Generic error") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Generic error", SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_body_attribute(self, mock_log_and_reject_error): + """Test handling error without body attribute.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "12345"} + delattr(mock_error, "body") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_text_plain_error(self, mock_log_and_reject_error): + """Test handling text/plain content type error.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "1234", "content-type": "text/plain"} + mock_error.body = "Plain text error message" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with("Plain text error message", 500, "1234", logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_error): + """Test handling generic error with unknown content type.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "1234", "content-type": "application/xml"} + mock_error.body = "XML error" + mock_error.status = 503 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, "1234", logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_content_type(self, mock_log_and_reject_error): + """Test handling error without content-type header.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "1234"} + mock_error.body = "Some error" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 500, "1234", logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): + """Test handling JSON error when data is a JSON string.""" + error_json_string = json.dumps( + { + "error": { + "message": "String JSON error", + "http_code": 422, + "http_status": "Unprocessable Entity", + "grpc_code": 3, + "details": ["validation failed"], + } + } + ) + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-3" + + handle_json_error(mock_error, error_json_string, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "String JSON error", 422, request_id, "Unprocessable Entity", 3, ["validation failed"], logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_invalid_json(self, mock_log_and_reject_error): + """Test handling JSON decode error.""" + invalid_json = "This is not valid JSON" + mock_error = Mock() + mock_error.status = 500 + mock_logger = Mock() + request_id = "test-request-id-4" + + handle_json_error(mock_error, invalid_json, request_id, mock_logger) + + # Should call with INVALID_JSON_RESPONSE error + mock_log_and_reject_error.assert_called_once() + self.assertEqual(mock_log_and_reject_error.call_args[0][0], SkyflowMessages.Error.INVALID_JSON_RESPONSE.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_missing_error_field(self, mock_log_and_reject_error): + """Test handling JSON error with missing error field.""" + error_dict = {"message": "Error without error wrapper"} + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-5" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + # Should use defaults for missing fields + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + # Default message when error field is missing + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + # Default status code + self.assertEqual(args[1], 500) + self.assertEqual(args[2], request_id) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_text_error_with_status(self, mock_log_and_reject_error): + """Test handle_text_error extracts status correctly.""" + mock_error = Mock() + mock_error.status = 404 + mock_logger = Mock() + request_id = "test-request-id-6" + error_data = "Resource not found" + + from skyflow.utils._utils import handle_text_error + + handle_text_error(mock_error, error_data, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with("Resource not found", 404, request_id, logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_generic_error_with_status(self, mock_log_and_reject_error): + """Test handle_generic_error_with_status.""" + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-7" + status = 503 + + from skyflow.utils._utils import handle_generic_error_with_status + + handle_generic_error_with_status(mock_error, request_id, status, mock_logger) + + mock_log_and_reject_error.assert_called_once_with(str(mock_error), 503, request_id, logger=mock_logger) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_none_error(self, mock_log_and_reject_error): + """Test handling None error object.""" + mock_logger = Mock() + + handle_exception(None, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + SkyflowMessages.ErrorCodes.SERVER_ERROR.value, + None, + logger=mock_logger, + ) + + # failed + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_error): + """Test handling empty string error.""" + mock_logger = Mock() + mock_error = Mock() + mock_error.headers = None + mock_error.body = None + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + # Should use str(error) or default message + self.assertEqual(mock_log_and_reject_error.call_args[0][1], SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_responses_key(self, mock_log_and_reject_error): + """Test handle_json_error when body has 'responses' key (batch/continue_on_error path).""" + error_dict = { + "responses": [ + {"Status": 400, "Body": {"error": "record not found"}}, + {"Status": 400, "Body": {"error": "invalid field"}}, + ] + } + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-responses" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertIn("record not found", args[0]) + self.assertIn("invalid field", args[0]) + self.assertEqual(args[1], 400) + self.assertIsNone(args[3]) # http_status + self.assertIsNone(args[4]) # grpc_code + self.assertEqual(args[5], []) # details + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_responses_no_error_messages(self, mock_log_and_reject_error): + """Test handle_json_error with responses key but no error body — falls back to default message.""" + error_dict = { + "responses": [ + {"Status": 200, "Body": {"records": [{"skyflow_id": "abc"}]}}, + ] + } + mock_error = Mock() + request_id = "test-request-id-responses-empty" + + handle_json_error(mock_error, error_dict, request_id, None) + + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): + """Test handling JSON error when data is bytes.""" + error_dict = {"error": {"message": "Bytes error", "http_code": 401, "http_status": "Unauthorized"}} + error_bytes = json.dumps(error_dict).encode("utf-8") + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-8" + + handle_json_error(mock_error, error_bytes, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Bytes error", 401, request_id, "Unauthorized", None, [], logger=mock_logger + ) + + # Add these new test methods to the TestUtils class: + + def test_construct_invoke_connection_request_with_no_headers(self): + """Test construct_invoke_connection_request when headers are None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param1": "value1"} + mock_connection_request.headers = None + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {"query": "test"} + + connection_url = "https://example.com/{param1}/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Headers should be None when not provided + self.assertIsNone(result.headers.get("Content-Type")) + + def test_construct_invoke_connection_request_with_xml_content_type(self): + """Test construct_invoke_connection_request with XML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/xml"} + mock_connection_request.body = {"root": {"child": "value"}} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertEqual(result.headers["content-type"], "application/xml") + # Body should be converted to XML + self.assertIn("", result.body) + self.assertIn("value", result.body) + + def test_construct_invoke_connection_request_with_html_content_type(self): + """Test construct_invoke_connection_request with HTML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "text/html"} + mock_connection_request.body = {"message": "Hello"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertEqual(result.headers["content-type"], "text/html") + # Body should be JSON string for HTML + self.assertEqual(result.body, json.dumps({"message": "Hello"})) + + def test_construct_invoke_connection_request_multipart_removes_content_type(self): + """Test that Content-Type is removed for multipart/form-data.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} + mock_connection_request.body = {"field1": "value1", "field2": "value2"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Content-Type should be auto-generated by requests library + self.assertIn("multipart/form-data", result.headers.get("Content-Type", "")) + self.assertIn("boundary=", result.headers.get("Content-Type", "")) + + def test_construct_invoke_connection_request_with_no_body(self): + """Test construct_invoke_connection_request when body is None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertIsNone(result.body) + + def test_get_data_from_content_type_url_encoded(self): + """Test get_data_from_content_type with URL encoded content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key1": "value1", "key2": "value2"} + content_type = ContentType.URLENCODED.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, "key1=value1&key2=value2") + self.assertEqual(files, {}) + + def test_get_data_from_content_type_form_data(self): + """Test get_data_from_content_type with form data content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"field1": "value1", "field2": "value2"} + content_type = ContentType.FORMDATA.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIsNone(converted_data) + self.assertEqual(files["field1"], (None, "value1")) + self.assertEqual(files["field2"], (None, "value2")) + + def test_get_data_from_content_type_json(self): + """Test get_data_from_content_type with JSON content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = ContentType.JSON.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_xml_with_dict(self): + """Test get_data_from_content_type with XML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"root": {"child": "value"}} + content_type = "application/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIn("", converted_data) + self.assertIn("value", converted_data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_xml_with_string(self): + """Test get_data_from_content_type with XML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "value" + content_type = "text/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_dict(self): + """Test get_data_from_content_type with HTML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"message": "Hello"} + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_string(self): + """Test get_data_from_content_type with HTML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "Hello" + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_dict(self): + """Test get_data_from_content_type with unknown content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = "application/custom" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_string(self): + """Test get_data_from_content_type with unknown content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "plain text data" + content_type = "text/plain" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_dict_to_xml_simple_dict(self): + """Test dict_to_xml with simple dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"name": "John", "age": "30"} + result = dict_to_xml(data) + + self.assertIn("John", result) + self.assertIn("30", result) + self.assertTrue(result.startswith("")) + self.assertTrue(result.endswith("")) + + def test_dict_to_xml_nested_dict(self): + """Test dict_to_xml with nested dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"person": {"name": "John", "age": "30"}} + result = dict_to_xml(data) + + self.assertIn("", result) + self.assertIn("John", result) + self.assertIn("30", result) + + def test_dict_to_xml_with_list(self): + """Test dict_to_xml with list values.""" + from skyflow.utils._utils import dict_to_xml + + data = {"items": ["item1", "item2", "item3"]} + result = dict_to_xml(data) + + self.assertIn("item1", result) + self.assertIn("item2", result) + self.assertIn("item3", result) + + @patch("requests.Response") + def test_parse_invoke_connection_response_xml_content(self, mock_response): + """Test parsing XML response content.""" + mock_response.status_code = 200 + mock_response.content = b"success" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/xml"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_url_encoded_content(self, mock_response): + """Test parsing URL encoded response content.""" + mock_response.status_code = 200 + mock_response.content = b"card_number=4111111111111111&cvv=123" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/x-www-form-urlencoded"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "card_number=4111111111111111&cvv=123") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_content(self, mock_response): + """Test parsing HTML response content.""" + mock_response.status_code = 200 + mock_response.content = b"Success" + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_error(self, mock_response): + """Test parsing HTML error response.""" + html_error = "

Error 500

" + mock_response.status_code = 500 + mock_response.content = html_error.encode("utf-8") + mock_response.headers = {"x-request-id": "1234", "content-type": "text/html"} + mock_response.raise_for_status = Mock(side_effect=HTTPError("500 Error")) + + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.API_ERROR.value.format(500)) + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") + + @patch("requests.Response") + def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, mock_response): + """Test that JSON decode error falls back to returning string content.""" + mock_response.status_code = 200 + mock_response.content = b"Not valid JSON but still success" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/json"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Not valid JSON but still success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_no_content_type_with_json(self, mock_response): + """Test parsing response with no content-type but valid JSON.""" + mock_response.status_code = 200 + mock_response.content = json.dumps({"success": True}).encode("utf-8") + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, {"success": True}) + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_no_content_type_with_text(self, mock_response): + """Test parsing response with no content-type and non-JSON content.""" + mock_response.status_code = 200 + mock_response.content = b"Plain text response" + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Plain text response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_bytes_content(self, mock_response): + """Test parsing response with bytes content.""" + mock_response.status_code = 200 + mock_response.content = b"Binary data response" + mock_response.headers = {"x-request-id": "1234", "content-type": "application/octet-stream"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Binary data response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + def test_construct_invoke_connection_request_headers_json_error(self): + """Test exception handling when json.dumps fails for headers.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + + class UnserializableObject: + def __repr__(self): + raise TypeError("Object is not JSON serializable") + + mock_connection_request.headers = {"key": UnserializableObject()} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("json.dumps", side_effect=TypeError("Object is not JSON serializable")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_headers_generic_exception(self): + """Test generic exception handling for headers processing.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/json"} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("skyflow.utils._utils.to_lowercase_keys", side_effect=Exception("Generic error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_processing_exception(self): + """Test exception handling when body processing fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("skyflow.utils._utils.get_data_from_content_type", side_effect=Exception("Body processing error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_json_dumps_exception(self): + """Test exception handling when json.dumps fails in get_data_from_content_type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + + class UnserializableObject: + pass + + mock_connection_request.body = {"key": UnserializableObject()} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_invalid_url_exception(self): + """Test exception handling when requests.Request.prepare() fails with invalid URL.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("requests.Request") as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Invalid URL structure") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_prepare_exception(self): + """Test exception handling when prepare() method fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch("requests.Request") as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Prepare failed") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_URL.value.format(connection_url)) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_not_dict_raises_error(self): + """Test that non-dict body raises SkyflowError which is caught and re-raised.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = "not a dict" # Invalid body type + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + @patch("skyflow.utils._utils.validate_invoke_connection_params") + def test_construct_invoke_connection_request_validation_exception(self, mock_validate): + """Test that validation exceptions are properly propagated.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param": "value"} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {"query": "value"} + + connection_url = "https://example.com/endpoint" + + mock_validate.side_effect = SkyflowError("Validation failed", 400) + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, "Validation failed") + self.assertEqual(context.exception.http_code, 400) + + def test_generate_bearer_token_invalid_token_uri_type(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": 12345, # invalid type + } + + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_bearer_token(tmp.name) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_options_override_token_uri(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + # Patch AuthClient and jwt.encode to avoid real HTTP and signing + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + generate_bearer_token(tmp.name, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_bearer_token_from_creds_invalid_token_uri_type(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_bearer_token_from_creds(creds_str) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_bearer_token_from_creds_options_override_token_uri(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"token_uri": "https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("skyflow.service_account._utils.get_signed_jwt") as mock_get_signed_jwt: + mock_get_signed_jwt.return_value = "signed" + with patch("skyflow.service_account._utils.AuthClient") as mock_auth_client: + mock_auth_api = mock_auth_client.return_value.get_auth_api.return_value + mock_auth_api.authentication_service_get_auth_token.return_value = type( + "obj", (), {"access_token": "token", "token_type": "bearer"} + ) + generate_bearer_token_from_creds(creds_str, options) + args, kwargs = mock_get_signed_jwt.call_args + self.assertEqual(args[3], options["token_uri"]) + + def test_generate_signed_data_tokens_invalid_token_uri_type(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + options = {"data_tokens": ["token1"]} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens(tmp.name, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_type(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": 12345} + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_from_creds_invalid_token_uri_url(self): + creds = {"privateKey": "private_key", "clientID": "client_id", "keyID": "key_id", "tokenURI": "not_a_url"} + options = {"data_tokens": ["token1"]} + creds_str = json.dumps(creds) + with self.assertRaises(SkyflowError) as context: + generate_signed_data_tokens_from_creds(creds_str, options) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_generate_signed_data_tokens_options_override_token_uri(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + json.dump(creds, tmp) + tmp.flush() + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens(tmp.name, options) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") + + def test_generate_signed_data_tokens_from_creds_options_override_token_uri(self): + creds = { + "privateKey": "private_key", + "clientID": "client_id", + "keyID": "key_id", + "tokenURI": "https://valid-url.com", + } + options = {"data_tokens": ["token1"], "token_uri": "https://another-valid-url.com"} + creds_str = json.dumps(creds) + with patch("jwt.encode") as mock_jwt_encode: + mock_jwt_encode.return_value = "signed" + result = generate_signed_data_tokens_from_creds(creds_str, options) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 1) + self.assertEqual(result[0][0], "token1") + self.assertEqual(result[0][1], "signed_token_signed") diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index b1247ebc..9ee05cef 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -12,13 +12,15 @@ validate_insert_request, validate_delete_request, validate_query_request, validate_get_detect_run_request, validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request, validate_invoke_connection_params, - validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request + validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request, + validate_file_upload_request ) from skyflow.utils import SkyflowMessages from skyflow.utils.enums import DetectEntities, RedactionType from skyflow.vault.data import GetRequest, UpdateRequest from skyflow.vault.detect import DeidentifyTextRequest, Transformations, DateTransformation, ReidentifyTextRequest, \ - FileInput, DeidentifyFileRequest + FileInput, DeidentifyFileRequest, Bleep +from skyflow.vault.data._file_upload_request import FileUploadRequest from skyflow.vault.tokens import DetokenizeRequest from skyflow.vault.connection._invoke_connection_request import InvokeConnectionRequest @@ -116,7 +118,7 @@ def test_validate_credentials_with_expired_token(self): with patch('skyflow.service_account.is_expired', return_value=True): with self.assertRaises(SkyflowError) as context: validate_credentials(self.logger, credentials) - self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value) def test_validate_credentials_empty_credentials(self): credentials = {} @@ -205,15 +207,6 @@ def test_validate_update_vault_config_valid(self): } self.assertTrue(validate_update_vault_config(self.logger, config)) - def test_validate_update_vault_config_missing_credentials(self): - config = { - "vault_id": "vault123", - "cluster_id": "cluster123" - } - with self.assertRaises(SkyflowError) as context: - validate_update_vault_config(self.logger, config) - self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123")) - def test_validate_update_vault_config_invalid_cluster_id(self): config = { "vault_id": "vault123", @@ -226,6 +219,18 @@ def test_validate_update_vault_config_invalid_cluster_id(self): validate_update_vault_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format("vault123")) + def test_validate_update_vault_config_missing_credentials(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123") + ) + def test_validate_connection_config_valid(self): config = { "connection_id": "conn123", @@ -259,6 +264,18 @@ def test_validate_connection_config_empty_connection_id(self): validate_connection_config(self.logger, config) self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value) + def test_validate_connection_config_missing_credentials(self): + config = { + "connection_id": "conn123", + "connection_url": "https://example.com", + } + with self.assertRaises(SkyflowError) as context: + validate_connection_config(self.logger, config) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", "conn123") + ) + def test_validate_update_connection_config_valid(self): config = { "connection_id": "conn123", @@ -1040,7 +1057,436 @@ def test_validate_detokenize_request_invalid_continue_on_error_type(self): self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value) def test_validate_detokenize_request_invalid_redaction_type(self): - request = DetokenizeRequest(data=[{"token": "token123", "redaction": "invalid"}], continue_on_error=False) + request = DetokenizeRequest(data=[{"token": "token123", "redaction_type": "invalid"}], continue_on_error=False) with self.assertRaises(SkyflowError) as context: validate_detokenize_request(self.logger, request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) + + def test_validate_detokenize_request_deprecated_redaction_key_emits_warn(self): + from unittest.mock import patch + request = DetokenizeRequest(data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT}], continue_on_error=False) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_called_once() + self.assertIn("redaction_type", mock_warn.call_args[0][0]) + + def test_validate_detokenize_request_both_keys_prioritizes_redaction_type_and_warns(self): + from unittest.mock import patch + request = DetokenizeRequest( + data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT, "redaction_type": RedactionType.MASKED}], + continue_on_error=False + ) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_called_once() + + def test_validate_detokenize_request_redaction_type_only_no_warn(self): + from unittest.mock import patch + request = DetokenizeRequest(data=[{"token": "token123", "redaction_type": RedactionType.PLAIN_TEXT}], continue_on_error=False) + with patch('skyflow.utils.validations._validations.log_warn') as mock_warn: + validate_detokenize_request(self.logger, request) + mock_warn.assert_not_called() + + + def test_validate_deidentify_file_request_wait_time_negative(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=-1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_greater_than_64(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=65, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_lower(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=0, + entities=[DetectEntities.SSN] + ) + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_upper(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=64, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_float(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=32.5, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_float_out_of_range(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=64.1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + def test_validate_credentials_with_valid_token_uri(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + } + # Should not raise + validate_credentials(self.logger, credentials) + + def test_validate_credentials_with_invalid_token_uri_type(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 # Not a string + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_credentials_with_invalid_token_uri_url(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_valid_token_uri(self): + from skyflow.utils.enums import Env + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "https://valid-url.com" + }, + "env": Env.DEV + } + # Should not raise + self.assertTrue(validate_update_vault_config(self.logger, config)) + + def test_validate_update_vault_config_with_invalid_token_uri_type(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": 12345 + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + def test_validate_update_vault_config_with_invalid_token_uri_url(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef", + "token_uri": "not_a_url" + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_URI.value) + + # --- validate_file_from_request --- + + def test_validate_file_from_request_none_input(self): + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_INPUT.value) + + def test_validate_file_from_request_file_without_name_attr(self): + file_obj = MagicMock(spec=[]) # no attributes at all + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_file_with_empty_name(self): + file_obj = MagicMock() + file_obj.name = " " # whitespace-only name + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_TYPE.value) + + def test_validate_file_from_request_extension_only_name(self): + file_obj = MagicMock() + # A trailing-slash path gives os.path.basename() == "", so splitext returns ("", "") + file_obj.name = "/some/directory/" + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_from_request_empty_string_file_path(self): + file_input = MagicMock() + file_input.file = None + file_input.file_path = "" # empty string — has_file_path=True, so goes to elif branch + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) + + # --- validate_deidentify_file_request bleep sub-fields --- + + def test_validate_deidentify_file_request_invalid_bleep_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, bleep="not_a_bleep") + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_TYPE.value) + + def test_validate_deidentify_file_request_invalid_bleep_gain(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(gain="loud") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_GAIN.value) + + def test_validate_deidentify_file_request_invalid_bleep_frequency(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(frequency="high") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_FREQUENCY.value) + + def test_validate_deidentify_file_request_invalid_bleep_start_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(start_padding="early") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_START_PADDING.value) + + def test_validate_deidentify_file_request_invalid_bleep_stop_padding(self): + file_input = FileInput(file_path=self.temp_file_path) + bleep = Bleep(stop_padding="late") + request = DeidentifyFileRequest(file=file_input, bleep=bleep) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value) + + # --- validate_deidentify_file_request output_directory --- + + def test_validate_deidentify_file_request_invalid_output_directory_type(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=123) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value) + + def test_validate_deidentify_file_request_output_directory_not_found(self): + file_input = FileInput(file_path=self.temp_file_path) + nonexistent = "/tmp/skyflow_nonexistent_dir_12345" + request = DeidentifyFileRequest(file=file_input, output_directory=nonexistent) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(nonexistent) + ) + + def test_validate_deidentify_file_request_valid_output_directory(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest(file=file_input, output_directory=self.temp_dir_path) + validate_deidentify_file_request(self.logger, request) + + # --- validate_file_upload_request --- + + def test_validate_file_upload_request_none(self): + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_table(self): + request = MagicMock() + request.table = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_file_upload_request_empty_table(self): + request = MagicMock() + request.table = " " + request.column_name = "file_col" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TABLE_VALUE.value) + + def test_validate_file_upload_request_none_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = None + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(None)) + ) + + def test_validate_file_upload_request_empty_column_name(self): + request = MagicMock() + request.table = "test_table" + request.skyflow_id = None + request.column_name = "" + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type("")) + ) + + def test_validate_file_upload_request_empty_skyflow_id(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + skyflow_id=" ", + file_path=self.temp_file_path + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD") + ) + + def test_validate_file_upload_request_invalid_file_object_seek(self): + file_obj = MagicMock() + file_obj.seek.side_effect = OSError("seek failed") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=file_obj + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_OBJECT.value) + + def test_validate_file_upload_request_valid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path=self.temp_file_path + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_invalid_file_path(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_path="/nonexistent/path/file.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_PATH.value) + + def test_validate_file_upload_request_valid_base64(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64=encoded, + file_name="sample.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_base64_without_file_name(self): + import base64 + encoded = base64.b64encode(b"file content").decode("utf-8") + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64=encoded + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_NAME.value) + + def test_validate_file_upload_request_invalid_base64_string(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col", + base64="not-valid-base64!!!", + file_name="sample.txt" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_BASE64_STRING.value) + + def test_validate_file_upload_request_valid_file_object(self): + with open(self.temp_file_path, "rb") as f: + request = FileUploadRequest( + table="test_table", + column_name="file_col", + file_object=f + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_upload_request_missing_file_source(self): + request = FileUploadRequest( + table="test_table", + column_name="file_col" + ) + with self.assertRaises(SkyflowError) as context: + validate_file_upload_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + # --- validate_deidentify_text_request transformations --- + + def test_validate_deidentify_text_request_invalid_transformations(self): + request = DeidentifyTextRequest( + text="test text", + transformations="invalid_type" + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value) + + # --- validate_reidentify_text_request masked_entities --- + + def test_validate_reidentify_text_request_invalid_masked_entities(self): + request = ReidentifyTextRequest( + text="test text", + masked_entities="invalid_type" + ) + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_MASKED_ENTITIES_IN_REIDENTIFY.value) \ No newline at end of file diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 9d0d2520..4df508c7 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -1,5 +1,8 @@ import unittest -from unittest.mock import patch, MagicMock, call +from unittest.mock import patch, MagicMock + +from skyflow.error import SkyflowError +from skyflow.utils import SkyflowMessages from skyflow.vault.client.client import VaultClient CONFIG = { @@ -321,4 +324,4 @@ def test_get_query_api(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/vault/connection/__init__.py b/tests/vault/connection/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/connection/test_responses.py b/tests/vault/connection/test_responses.py new file mode 100644 index 00000000..72bd0c56 --- /dev/null +++ b/tests/vault/connection/test_responses.py @@ -0,0 +1,26 @@ +import unittest +from skyflow.vault.connection._invoke_connection_response import InvokeConnectionResponse + + +class TestInvokeConnectionResponse(unittest.TestCase): + def test_repr(self): + r = InvokeConnectionResponse(data={"key": "val"}, metadata={"m": 1}, errors=None) + self.assertIn("ConnectionResponse", repr(r)) + + def test_str(self): + r = InvokeConnectionResponse(data={"key": "val"}) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = InvokeConnectionResponse() + self.assertIsNone(r.data) + self.assertEqual(r.metadata, {}) + self.assertIsNone(r.errors) + + def test_metadata_defaults_to_empty_dict_when_falsy(self): + r = InvokeConnectionResponse(metadata=None) + self.assertEqual(r.metadata, {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__audit_binlookup.py b/tests/vault/controller/test__audit_binlookup.py new file mode 100644 index 00000000..978eb032 --- /dev/null +++ b/tests/vault/controller/test__audit_binlookup.py @@ -0,0 +1,27 @@ +import unittest +from skyflow.vault.controller._audit import Audit +from skyflow.vault.controller._bin_look_up import BinLookUp + + +class TestAudit(unittest.TestCase): + def test_instantiation(self): + a = Audit() + self.assertIsNotNone(a) + + def test_list_returns_none(self): + a = Audit() + self.assertIsNone(a.list()) + + +class TestBinLookUp(unittest.TestCase): + def test_instantiation(self): + b = BinLookUp() + self.assertIsNotNone(b) + + def test_get_returns_none(self): + b = BinLookUp() + self.assertIsNone(b.get()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 4ccad1c7..f073264c 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -1,9 +1,11 @@ +import json import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock import requests from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages, parse_invoke_connection_response -from skyflow.utils.enums import RequestMethod +from skyflow.utils._utils import get_data_from_content_type, construct_invoke_connection_request +from skyflow.utils.enums import RequestMethod, ContentType from skyflow.utils._version import SDK_VERSION from skyflow.vault.connection import InvokeConnectionRequest from skyflow.vault.controller import Connection @@ -30,10 +32,16 @@ def setUp(self): self.mock_vault_client = Mock() self.mock_vault_client.get_config.return_value = VAULT_CONFIG self.mock_vault_client.get_bearer_token.return_value = VALID_BEARER_TOKEN + self.mock_vault_client.get_logger.return_value = Mock() + self.mock_vault_client.get_common_skyflow_credentials.return_value = None self.connection = Connection(self.mock_vault_client) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_success(self, mock_send): + def test_invoke_success(self, mock_send, mock_get_credentials): + # Mock get_credentials to return credentials + mock_get_credentials.return_value = {"api_key": "test_api_key"} + # Mocking successful response mock_response = Mock() mock_response.status_code = SUCCESS_STATUS_CODE @@ -60,9 +68,36 @@ def test_invoke_success(self, mock_send): } self.assertEqual(vars(response), expected_response) self.mock_vault_client.get_bearer_token.assert_called_once() + mock_get_credentials.assert_called_once() + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_invalid_headers(self, mock_send): + def test_invoke_with_x_skyflow_authorization_already_present(self, mock_send, mock_get_credentials): + """Test that X-Skyflow-Authorization is not overwritten if already present in headers.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + custom_auth = "custom_bearer_token" + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers={"x-skyflow-authorization": custom_auth} + ) + + response = self.connection.invoke(request) + + # Verify bearer token from vault_client is NOT used + self.assertIsNotNone(response) + + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_headers(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=VALID_BODY, @@ -75,8 +110,10 @@ def test_invoke_invalid_headers(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) - @patch('requests.Session.send') - def test_invoke_invalid_body(self, mock_send): + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_body(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=INVALID_BODY, @@ -89,11 +126,16 @@ def test_invoke_invalid_body(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_request_error(self, mock_send): + def test_invoke_request_error(self, mock_send, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_response = Mock() mock_response.status_code = FAILURE_STATUS_CODE - mock_response.content = ERROR_RESPONSE_CONTENT + mock_response.content = ERROR_RESPONSE_CONTENT.encode('utf-8') # Convert to bytes + mock_response.headers = {"x-request-id": "test-request-id"} + mock_response.raise_for_status.side_effect = requests.HTTPError("400 Error") mock_send.return_value = mock_response request = InvokeConnectionRequest( @@ -106,9 +148,100 @@ def test_invoke_request_error(self, mock_send): with self.assertRaises(SkyflowError) as context: self.connection.invoke(request) - self.assertEqual(context.exception.message, f'Skyflow Python SDK {SDK_VERSION} Response {ERROR_RESPONSE_CONTENT} is not valid JSON.') - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(ERROR_RESPONSE_CONTENT)) - self.assertEqual(context.exception.http_code, 400) + + expected_message = SkyflowMessages.Error.API_ERROR.value.format(FAILURE_STATUS_CODE) + self.assertEqual(context.exception.message, expected_message) + self.assertEqual(context.exception.http_code, FAILURE_STATUS_CODE) + self.assertEqual(context.exception.request_id, "test-request-id") + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_send_exception(self, mock_send, mock_get_credentials): + """Test handling of generic exception from session.send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_send.side_effect = Exception("Network error") + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_skyflow_error_re_raised(self, mock_send, mock_get_credentials): + """Test that SkyflowError is re-raised without wrapping.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + original_error = SkyflowError("Original error", 401) + mock_send.side_effect = original_error + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + # Should be the same original error + self.assertEqual(context.exception.message, "Original error") + self.assertEqual(context.exception.http_code, 401) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_close_called(self, mock_send, mock_get_credentials): + """Test that session.close() is called after send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + with patch('requests.Session.close') as mock_close: + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + response = self.connection.invoke(request) + + # Verify close was called + mock_close.assert_called_once() + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.get_metrics') + @patch('requests.Session.send') + def test_invoke_adds_sky_metadata_header(self, mock_send, mock_get_metrics, mock_get_credentials): + """Test that sky-metadata header is added to request.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_get_metrics.return_value = {"sdk_version": SDK_VERSION} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + response = self.connection.invoke(request) + + # Verify get_metrics was called + mock_get_metrics.assert_called_once() + self.assertIsNotNone(response) def test_parse_invoke_connection_response_error_from_client(self): mock_response = Mock(spec=requests.Response) @@ -128,3 +261,415 @@ def test_parse_invoke_connection_response_error_from_client(self): self.assertTrue(any(detail.get('error_from_client') == True for detail in exception.details)) self.assertEqual(exception.request_id, '12345') + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.construct_invoke_connection_request') + def test_invoke_construct_request_called(self, mock_construct, mock_get_credentials): + """Test that construct_invoke_connection_request is called with correct parameters.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_prepared_request = Mock(spec=requests.PreparedRequest) + mock_prepared_request.headers = {} + mock_construct.return_value = mock_prepared_request + + with patch('requests.Session.send') as mock_send: + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + self.connection.invoke(request) + + # Verify construct was called with connection_url from config + mock_construct.assert_called_once_with( + request, + VAULT_CONFIG["connection_url"], + self.mock_vault_client.get_logger() + ) + + +class TestGetDataFromContentType(unittest.TestCase): + """Tests for get_data_from_content_type covering all supported content types.""" + + DATA = {'key': 'value', 'num': 42} + + # ── JSON ────────────────────────────────────────────────────────────────── + def test_json_content_type_returns_json_string(self): + data, files = get_data_from_content_type(self.DATA, ContentType.JSON.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + # ── URL-encoded ─────────────────────────────────────────────────────────── + def test_urlencoded_content_type_returns_encoded_string(self): + data, files = get_data_from_content_type({'k': 'v'}, ContentType.URLENCODED.value) + self.assertIn('k=v', data) + self.assertEqual(files, {}) + + def test_urlencoded_nested_dict(self): + payload = {'a': {'b': 'c'}} + data, files = get_data_from_content_type(payload, ContentType.URLENCODED.value) + self.assertIsInstance(data, str) + self.assertIn('c', data) + self.assertEqual(files, {}) + + # ── Form-data ───────────────────────────────────────────────────────────── + def test_formdata_content_type_returns_files_dict(self): + data, files = get_data_from_content_type({'f1': 'v1', 'f2': 'v2'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertEqual(files, {'f1': (None, 'v1'), 'f2': (None, 'v2')}) + + def test_formdata_converts_values_to_str(self): + data, files = get_data_from_content_type({'num': 99}, ContentType.FORMDATA.value) + self.assertEqual(files['num'], (None, '99')) + + def test_formdata_single_key(self): + data, files = get_data_from_content_type({'only': 'one'}, ContentType.FORMDATA.value) + self.assertIsNone(data) + self.assertIn('only', files) + + # ── XML ─────────────────────────────────────────────────────────────────── + def test_xml_text_xml_content_type_wraps_in_root(self): + data, files = get_data_from_content_type({'key': 'value'}, 'text/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertIn('', data) + self.assertEqual(files, {}) + + def test_xml_application_xml_content_type(self): + data, files = get_data_from_content_type({'key': 'value'}, 'application/xml') + self.assertIn('', data) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_content_type_enum_value(self): + data, files = get_data_from_content_type({'key': 'value'}, ContentType.XML.value) + self.assertIn('value', data) + self.assertEqual(files, {}) + + def test_xml_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw_string', 'text/xml') + self.assertEqual(data, 'raw_string') + self.assertEqual(files, {}) + + # ── HTML ────────────────────────────────────────────────────────────────── + def test_html_content_type_dict_returns_json_string(self): + data, files = get_data_from_content_type(self.DATA, ContentType.HTML.value) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_text_html_content_type(self): + data, files = get_data_from_content_type(self.DATA, 'text/html') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_html_non_dict_data_returns_str(self): + data, files = get_data_from_content_type('raw', ContentType.HTML.value) + self.assertEqual(data, 'raw') + self.assertEqual(files, {}) + + # ── None / unknown ──────────────────────────────────────────────────────── + def test_none_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, None) + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_falls_back_to_json(self): + data, files = get_data_from_content_type(self.DATA, 'application/octet-stream') + self.assertEqual(data, json.dumps(self.DATA)) + self.assertEqual(files, {}) + + def test_unknown_content_type_non_dict_returns_str(self): + data, files = get_data_from_content_type('blob', 'application/octet-stream') + self.assertEqual(data, 'blob') + self.assertEqual(files, {}) + + +class TestParseInvokeConnectionResponse(unittest.TestCase): + """Tests for parse_invoke_connection_response covering all success and error paths.""" + + def _make_response(self, status_code, content, headers=None, raise_http_error=False): + mock_resp = Mock(spec=requests.Response) + mock_resp.status_code = status_code + if isinstance(content, str): + mock_resp.content = content.encode('utf-8') + else: + mock_resp.content = content + mock_resp.headers = headers or {} + if raise_http_error: + mock_resp.raise_for_status.side_effect = requests.HTTPError() + else: + mock_resp.raise_for_status.return_value = None + return mock_resp + + # ── Success paths ───────────────────────────────────────────────────────── + def test_success_json_content_type_parses_body(self): + resp = self._make_response( + 200, + '{"result": "ok"}', + {'content-type': 'application/json', 'x-request-id': 'req-1'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'result': 'ok'}) + self.assertEqual(result.metadata.get('request_id'), 'req-1') + self.assertIsNone(result.errors) + + def test_success_plain_text_content_type_returns_string(self): + resp = self._make_response( + 200, + 'plain text response', + {'content-type': 'text/plain'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'plain text response') + + def test_success_no_content_type_tries_json_parse(self): + resp = self._make_response(200, '{"a": 1}', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'a': 1}) + + def test_success_no_content_type_invalid_json_returns_string(self): + resp = self._make_response(200, 'not json', {}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'not json') + + def test_success_no_x_request_id_metadata_is_empty(self): + resp = self._make_response(200, '{}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.metadata, {}) + + def test_success_invalid_json_with_json_content_type_returns_raw_string(self): + resp = self._make_response( + 200, + 'not-json', + {'content-type': 'application/json'} + ) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, 'not-json') + + def test_success_bytes_content_decoded(self): + resp = self._make_response(200, b'{"x": 1}', {'content-type': 'application/json'}) + result = parse_invoke_connection_response(resp) + self.assertEqual(result.data, {'x': 1}) + + # ── Error paths — standard Skyflow format ──────────────────────────────── + def test_error_standard_skyflow_format_extracts_message(self): + body = json.dumps({'error': {'message': 'bad input', 'http_code': 400, 'http_status': 'BAD_REQUEST', 'grpc_code': 3, 'details': []}}) + resp = self._make_response(400, body, {'x-request-id': 'r1'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + e = ctx.exception + self.assertEqual(e.message, 'bad input') + self.assertEqual(e.http_code, 400) + self.assertEqual(e.request_id, 'r1') + self.assertEqual(e.http_status, 'BAD_REQUEST') + self.assertEqual(e.grpc_code, 3) + + def test_error_standard_format_falls_back_to_http_code_when_missing(self): + body = json.dumps({'error': {'message': 'oops'}}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertEqual(ctx.exception.http_code, 500) + + def test_error_standard_format_falls_back_to_sdk_message_when_missing(self): + body = json.dumps({'error': {}}) + resp = self._make_response(503, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — string error value ───────────────────────────────────── + def test_error_string_error_value_used_as_message(self): + body = json.dumps({'error': 'gateway timed out'}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertEqual(ctx.exception.message, 'gateway timed out') + + def test_error_empty_string_error_value_falls_back_to_sdk_message(self): + body = json.dumps({'error': ''}) + resp = self._make_response(502, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-standard JSON ────────────────────────────────────── + def test_error_no_error_key_uses_sdk_message(self): + body = json.dumps({'message': 'something went wrong'}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_non_dict_json_body_uses_sdk_message(self): + body = json.dumps(['list', 'not', 'dict']) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + def test_error_numeric_error_value_uses_sdk_message(self): + body = json.dumps({'error': 12345}) + resp = self._make_response(500, body, {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(500) + self.assertEqual(ctx.exception.message, expected) + + # ── Error paths — non-JSON / empty body ────────────────────────────────── + def test_error_empty_body_uses_sdk_message(self): + resp = self._make_response(502, '', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + self.assertEqual(ctx.exception.http_code, 502) + + def test_error_html_body_uses_sdk_message(self): + resp = self._make_response(502, 'Bad Gateway', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(502) + self.assertEqual(ctx.exception.message, expected) + + def test_error_plain_text_body_uses_sdk_message(self): + resp = self._make_response(503, 'Service Unavailable', {}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + expected = SkyflowMessages.Error.API_ERROR.value.format(503) + self.assertEqual(ctx.exception.message, expected) + + # ── error-from-client header ────────────────────────────────────────────── + def test_error_from_client_true_appended_to_details(self): + body = json.dumps({'error': {'message': 'client error', 'http_code': 400, 'details': []}}) + resp = self._make_response(400, body, {'error-from-client': 'true', 'x-request-id': 'r2'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is True for d in ctx.exception.details)) + + def test_error_from_client_false_appended_to_details(self): + body = json.dumps({'error': {'message': 'server error', 'http_code': 500}}) + resp = self._make_response(500, body, {'error-from-client': 'false'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertTrue(any(d.get('error_from_client') is False for d in ctx.exception.details)) + + def test_error_from_client_initialises_details_when_none(self): + body = json.dumps({'error': {'message': 'err', 'http_code': 400}}) + resp = self._make_response(400, body, {'error-from-client': 'true'}, raise_http_error=True) + with self.assertRaises(SkyflowError) as ctx: + parse_invoke_connection_response(resp) + self.assertIsNotNone(ctx.exception.details) + self.assertTrue(len(ctx.exception.details) > 0) + + +class TestConstructInvokeConnectionRequest(unittest.TestCase): + """Tests for construct_invoke_connection_request covering method, body, headers, path/query params.""" + + BASE_URL = 'https://example.com/api' + LOGGER = Mock() + + def _make_request(self, method=RequestMethod.POST, body=None, headers=None, + path_params=None, query_params=None): + return InvokeConnectionRequest( + method=method, + body=body, + headers=headers, + path_params=path_params or {}, + query_params=query_params or {} + ) + + def test_post_with_json_body_prepares_request(self): + req = self._make_request(body={'k': 'v'}, headers={'Content-Type': 'application/json'}) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIn('k', prepared.body) + + def test_get_with_no_body(self): + req = self._make_request(method=RequestMethod.GET) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'GET') + + def test_urlencoded_body_is_form_encoded(self): + req = self._make_request( + body={'field': 'val'}, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('field=val', prepared.body) + + def test_formdata_body_produces_multipart_request(self): + req = self._make_request( + body={'file_field': 'data'}, + headers={'Content-Type': 'multipart/form-data'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(prepared.method, 'POST') + self.assertIsNotNone(prepared.body) + + def test_xml_body_contains_xml_tags(self): + req = self._make_request( + body={'item': 'data'}, + headers={'Content-Type': 'text/xml'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('', prepared.body) + + def test_path_params_substituted_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + path_params={'id': '123'} + ) + url_with_placeholder = 'https://example.com/api/{id}/resource' + prepared = construct_invoke_connection_request(req, url_with_placeholder, self.LOGGER) + self.assertIn('123', prepared.url) + self.assertNotIn('{id}', prepared.url) + + def test_query_params_appear_in_url(self): + req = self._make_request( + method=RequestMethod.GET, + query_params={'page': '1', 'limit': '10'} + ) + prepared = construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertIn('page=1', prepared.url) + self.assertIn('limit=10', prepared.url) + + def test_invalid_headers_raises_skyflow_error(self): + req = InvokeConnectionRequest(method=RequestMethod.POST, headers='bad-headers') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + + def test_invalid_body_raises_skyflow_error(self): + req = InvokeConnectionRequest( + method=RequestMethod.POST, + body='not-a-dict', + headers={'Content-Type': 'application/json'} + ) + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + + def test_invalid_method_raises_skyflow_error(self): + req = InvokeConnectionRequest(method='INVALID_METHOD') + with self.assertRaises(SkyflowError) as ctx: + construct_invoke_connection_request(req, self.BASE_URL, self.LOGGER) + self.assertEqual(ctx.exception.message, SkyflowMessages.Error.INVALID_REQUEST_METHOD.value) + + def test_trailing_slash_stripped_from_url(self): + req = self._make_request(method=RequestMethod.GET) + prepared = construct_invoke_connection_request(req, self.BASE_URL + '/', self.LOGGER) + self.assertNotIn('//', prepared.url.replace('https://', '')) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index c2f9a861..f0f2aa87 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch, MagicMock import base64 import os +import tempfile from skyflow.error import SkyflowError from skyflow.generated.rest import WordCharacterCount from skyflow.utils import SkyflowMessages @@ -513,16 +514,12 @@ def test_get_detect_run_in_progress_status(self, mock_validate): self.vault_client.get_detect_file_api.return_value = files_api - # Execute - with patch.object(self.detect, "_Detect__parse_deidentify_file_response") as mock_parse: - result = self.detect.get_detect_run(req) + # Execute — IN_PROGRESS is returned directly without going through the parser + result = self.detect.get_detect_run(req) - # Verify IN_PROGRESS handling - mock_parse.assert_called_once() - args = mock_parse.call_args[0][0] - self.assertIsInstance(args, DeidentifyFileResponse) - self.assertEqual(args.status, 'IN_PROGRESS') - self.assertEqual(args.run_id, run_id) + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, 'IN_PROGRESS') + self.assertEqual(result.run_id, run_id) def test_get_transformations_with_shift_dates(self): @@ -711,3 +708,140 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) + + def test_poll_for_processed_file_exception(self): + files_api = Mock() + files_api.with_raw_response = files_api + files_api.get_run.side_effect = Exception("poll error") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect._Detect__poll_for_processed_file("runid", max_wait_time=5) + + def test_save_output_directory_not_exists(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=False): + self.detect._Detect__save_deidentify_file_response_output( + response, "/nonexistent_dir", "file.txt", "file" + ) + + def test_save_output_second_non_redacted_item(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output1 = Mock() + output1.processedFile = base64.b64encode(b"data1").decode() + output1.processedFileType = "redacted_file" + output1.processedFileExtension = "txt" + output2 = Mock() + output2.processedFile = base64.b64encode(b"data2").decode() + output2.processedFileType = "entities" + output2.processedFileExtension = "json" + response = Mock() + response.output = [output1, output2] + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "original.txt", "original" + ) + + def test_save_output_path_traversal_blocked(self): + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + call_count = [0] + + def fake_realpath(p): + call_count[0] += 1 + if call_count[0] == 1: + return "/safe_dir" + return "/outside/path" + + with patch("skyflow.vault.controller._detect.os.path.exists", return_value=True), \ + patch("skyflow.vault.controller._detect.os.path.realpath", side_effect=fake_realpath): + self.detect._Detect__save_deidentify_file_response_output( + response, "/safe_dir", "file.txt", "file" + ) + + def test_save_output_write_exception(self): + with tempfile.TemporaryDirectory() as tmp_dir: + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = "txt" + response = Mock() + response.output = [output] + with patch("skyflow.vault.controller._detect.base64.b64decode", + side_effect=Exception("decode error")), \ + self.assertRaises(Exception): + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "file.txt", "file" + ) + + def test_save_output_no_file_extension_uses_original_name(self): + """Branches 113->117 and 119->124: processed_file_extension is falsy — safe_ext stays None.""" + with tempfile.TemporaryDirectory() as tmp_dir: + output = Mock() + output.processedFile = base64.b64encode(b"data").decode() + output.processedFileType = "redacted_file" + output.processedFileExtension = None + output.processed_file_extension = None + response = Mock() + response.output = [output] + self.detect._Detect__save_deidentify_file_response_output( + response, tmp_dir, "original.txt", "original" + ) + + @patch("skyflow.vault.controller._detect.time.sleep", return_value=None) + def test_poll_unknown_status_then_success(self, mock_sleep): + """Branch 80->65: status is unknown, loops back, then returns SUCCESS.""" + files_api = Mock() + files_api.with_raw_response = files_api + self.vault_client.get_detect_file_api.return_value = files_api + + call_count = {"n": 0} + + def side_effect(*args, **kwargs): + call_count["n"] += 1 + r = Mock() + if call_count["n"] == 1: + r.status = "UNKNOWN_STATUS" + else: + r.status = "SUCCESS" + return Mock(data=r) + + files_api.get_run.side_effect = side_effect + result = self.detect._Detect__poll_for_processed_file("runid", max_wait_time=10) + self.assertEqual(result.status, "SUCCESS") + + def test_get_file_from_request_no_file_no_path_returns_none(self): + """Branch 285->exit: file_input has neither file nor file_path set.""" + req = DeidentifyFileRequest(file=FileInput(file=None, file_path=None)) + result = self.detect._Detect__get_file_from_request(req) + self.assertIsNone(result) + + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") + @patch("skyflow.vault.controller._detect.base64") + def test_deidentify_file_api_error_inside_try(self, mock_base64, mock_validate): + file_content = b"test content" + file_obj = Mock() + file_obj.read.return_value = file_content + file_obj.name = "test.txt" + mock_base64.b64encode.return_value.decode.return_value = "encoded" + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) + req.entities = [] + req.token_format = None + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = None + req.wait_time = None + files_api = Mock() + files_api.with_raw_response = files_api + files_api.deidentify_text.side_effect = Exception("API error inside try") + self.vault_client.get_detect_file_api.return_value = files_api + with self.assertRaises(Exception): + self.detect.deidentify_file(req) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 4e1a0dda..5acdf779 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -722,6 +722,76 @@ def test_upload_file_with_missing_file_source(self, mock_validate): self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_without_skyflow_id_successful(self, mock_validate): + """Test upload_file succeeds when skyflow_id is None (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/test.txt", + ) + mocked_open = mock_open_func(read_data=b"test file content") + mock_api_response = Mock() + mock_api_response.data = Mock(skyflow_id="generated-id-123") + records_api = self.vault_client.get_records_api.return_value + records_api.with_raw_response.upload_file_v_2.return_value = mock_api_response + with patch('builtins.open', mocked_open): + result = self.vault.upload_file(request) + mock_validate.assert_called_once_with(self.vault_client.get_logger(), request) + self.assertIsNone(request.skyflow_id) + self.assertEqual(result.skyflow_id, "generated-id-123") + self.assertIsNone(result.errors) + + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + @patch("skyflow.vault.controller._vault.open", mock_open(read_data=b"file_content"), create=True) + def test_upload_file_file_path_with_existing_file_name(self, mock_validate): + """Branch 73->76: file_name already set when file_path is present — skips basename call.""" + request = FileUploadRequest( + table=TABLE_NAME, + column_name="file_col", + file_path="/path/to/file.txt", + file_name="already_set.txt" + ) + mock_api = self.vault_client.get_records_api.return_value.with_raw_response + mock_response = Mock() + mock_response.data.skyflow_id = "sky123" + mock_api.upload_file_v_2.return_value = mock_response + + self.vault.upload_file(request) + mock_api.upload_file_v_2.assert_called_once() + + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_file_object_without_name_attr(self, mock_validate): + """Branch 84->89: file_object has no 'name' attr — __get_file_for_file_upload returns None.""" + file_obj = Mock(spec=[]) + request = FileUploadRequest( + table=TABLE_NAME, + column_name="file_col", + file_object=file_obj + ) + mock_api = self.vault_client.get_records_api.return_value.with_raw_response + mock_response = Mock() + mock_response.data.skyflow_id = "sky123" + mock_api.upload_file_v_2.return_value = mock_response + + self.vault.upload_file(request) + mock_api.upload_file_v_2.assert_called_once() + + @patch("skyflow.vault.controller._vault.validate_file_upload_request") + def test_upload_file_no_file_source_returns_none_file(self, mock_validate): + """Branch 84->89 (elif False): all file sources None — __get_file_for_file_upload returns None.""" + request = FileUploadRequest( + table=TABLE_NAME, + column_name="file_col" + ) + mock_api = self.vault_client.get_records_api.return_value.with_raw_response + mock_response = Mock() + mock_response.data.skyflow_id = "sky123" + mock_api.upload_file_v_2.return_value = mock_response + + self.vault.upload_file(request) + mock_api.upload_file_v_2.assert_called_once() + class TestFileUploadValidation(unittest.TestCase): def setUp(self): self.logger = Mock() @@ -874,3 +944,38 @@ def test_validate_missing_file_source(self): with self.assertRaises(SkyflowError) as error: validate_file_upload_request(self.logger, request) self.assertEqual(error.exception.message, SkyflowMessages.Error.MISSING_FILE_SOURCE.value) + + def test_validate_none_skyflow_id_is_allowed(self): + """Test that skyflow_id=None passes validation (it is optional).""" + request = FileUploadRequest( + table="test_table", + column_name="file_column", + base64="dGVzdCBmaWxlIGNvbnRlbnQ=", + file_name="test.txt" + ) + self.assertIsNone(request.skyflow_id) + validate_file_upload_request(self.logger, request) + + @patch('os.path.exists') + @patch('os.path.isfile') + def test_validate_file_path_without_skyflow_id(self, mock_isfile, mock_exists): + """Test validation succeeds with file_path and no skyflow_id.""" + mock_exists.return_value = True + mock_isfile.return_value = True + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_path="/path/to/file.txt" + ) + validate_file_upload_request(self.logger, request) + + def test_validate_file_object_without_skyflow_id(self): + """Test validation succeeds with file_object and no skyflow_id.""" + mock_file = Mock() + mock_file.seek = Mock() + request = FileUploadRequest( + table="test_table", + column_name="file_column", + file_object=mock_file + ) + validate_file_upload_request(self.logger, request) diff --git a/tests/vault/data/__init__.py b/tests/vault/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/data/test_responses.py b/tests/vault/data/test_responses.py new file mode 100644 index 00000000..ea9f2be1 --- /dev/null +++ b/tests/vault/data/test_responses.py @@ -0,0 +1,108 @@ +import unittest +from skyflow.vault.data._delete_response import DeleteResponse +from skyflow.vault.data._file_upload_response import FileUploadResponse +from skyflow.vault.data._get_response import GetResponse +from skyflow.vault.data._insert_response import InsertResponse +from skyflow.vault.data._query_response import QueryResponse +from skyflow.vault.data._update_response import UpdateResponse +from skyflow.vault.data._upload_file_request import UploadFileRequest + + +class TestDeleteResponse(unittest.TestCase): + def test_repr(self): + r = DeleteResponse(deleted_ids=["id1"], errors=None) + self.assertIn("DeleteResponse", repr(r)) + self.assertIn("id1", repr(r)) + + def test_str(self): + r = DeleteResponse(deleted_ids=["id1"], errors=None) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = DeleteResponse() + self.assertIsNone(r.deleted_ids) + self.assertIsNone(r.errors) + + +class TestFileUploadResponse(unittest.TestCase): + def test_repr(self): + r = FileUploadResponse(skyflow_id="sky123", errors=None) + self.assertIn("FileUploadResponse", repr(r)) + self.assertIn("sky123", repr(r)) + + def test_str(self): + r = FileUploadResponse(skyflow_id="sky123", errors=None) + self.assertEqual(str(r), repr(r)) + + +class TestGetResponse(unittest.TestCase): + def test_repr(self): + r = GetResponse(data=[{"field": "val"}], errors=None) + self.assertIn("GetResponse", repr(r)) + + def test_str(self): + r = GetResponse(data=[{"field": "val"}], errors=None) + self.assertEqual(str(r), repr(r)) + + def test_none_data_defaults_to_empty_list(self): + r = GetResponse(data=None) + self.assertEqual(r.data, []) + + def test_empty_data_not_replaced(self): + r = GetResponse(data={}) + self.assertEqual(r.data, {}) + + +class TestInsertResponse(unittest.TestCase): + def test_repr(self): + r = InsertResponse(inserted_fields=[{"skyflow_id": "id1"}], errors=None) + self.assertIn("InsertResponse", repr(r)) + + def test_str(self): + r = InsertResponse(inserted_fields=[{"skyflow_id": "id1"}]) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = InsertResponse() + self.assertIsNone(r.inserted_fields) + self.assertIsNone(r.errors) + + +class TestQueryResponse(unittest.TestCase): + def test_repr(self): + r = QueryResponse() + self.assertIn("QueryResponse", repr(r)) + + def test_str(self): + r = QueryResponse() + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = QueryResponse() + self.assertEqual(r.fields, []) + self.assertIsNone(r.errors) + + +class TestUpdateResponse(unittest.TestCase): + def test_repr(self): + r = UpdateResponse(updated_field={"skyflow_id": "id1"}, errors=None) + self.assertIn("UpdateResponse", repr(r)) + + def test_str(self): + r = UpdateResponse(updated_field={"skyflow_id": "id1"}) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = UpdateResponse() + self.assertIsNone(r.updated_field) + self.assertIsNone(r.errors) + + +class TestUploadFileRequest(unittest.TestCase): + def test_instantiation(self): + r = UploadFileRequest() + self.assertIsNotNone(r) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/detect/__init__.py b/tests/vault/detect/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/detect/test_models.py b/tests/vault/detect/test_models.py new file mode 100644 index 00000000..bec65297 --- /dev/null +++ b/tests/vault/detect/test_models.py @@ -0,0 +1,177 @@ +import unittest +import io +from skyflow.vault.detect._deidentify_text_response import DeidentifyTextResponse +from skyflow.vault.detect._reidentify_text_response import ReidentifyTextResponse +from skyflow.vault.detect._entity_info import EntityInfo +from skyflow.vault.detect._file_input import FileInput +from skyflow.vault.detect._text_index import TextIndex +from skyflow.vault.detect._date_transformation import DateTransformation +from skyflow.vault.detect._transformations import Transformations +from skyflow.vault.detect._file import File +from skyflow.utils.enums import DetectEntities + + +class TestTextIndex(unittest.TestCase): + def test_repr(self): + t = TextIndex(start=0, end=4) + self.assertIn("TextIndex", repr(t)) + self.assertIn("0", repr(t)) + + def test_str(self): + t = TextIndex(start=0, end=4) + self.assertEqual(str(t), repr(t)) + + def test_attributes(self): + t = TextIndex(start=5, end=10) + self.assertEqual(t.start, 5) + self.assertEqual(t.end, 10) + + +class TestEntityInfo(unittest.TestCase): + def setUp(self): + self.text_index = TextIndex(0, 4) + self.processed_index = TextIndex(0, 8) + + def test_repr(self): + e = EntityInfo( + token="TOKEN_1", value="John", + text_index=self.text_index, + processed_index=self.processed_index, + entity="NAME", scores={"confidence": 0.9} + ) + self.assertIn("EntityInfo", repr(e)) + self.assertIn("John", repr(e)) + + def test_str(self): + e = EntityInfo( + token="TOKEN_1", value="John", + text_index=self.text_index, + processed_index=self.processed_index, + entity="NAME", scores={} + ) + self.assertEqual(str(e), repr(e)) + + def test_attributes(self): + e = EntityInfo( + token="T", value="v", + text_index=self.text_index, + processed_index=self.processed_index, + entity="EMAIL", scores={"s": 1.0} + ) + self.assertEqual(e.token, "T") + self.assertEqual(e.entity, "EMAIL") + + +class TestDeidentifyTextResponse(unittest.TestCase): + def test_repr(self): + r = DeidentifyTextResponse( + processed_text="[TOKEN_1]", entities=[], word_count=1, char_count=9 + ) + self.assertIn("DeidentifyTextResponse", repr(r)) + + def test_str(self): + r = DeidentifyTextResponse( + processed_text="[TOKEN_1]", entities=[], word_count=1, char_count=9 + ) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = DeidentifyTextResponse( + processed_text="text", entities=[], word_count=1, char_count=4 + ) + self.assertIsNone(r.errors) + + +class TestReidentifyTextResponse(unittest.TestCase): + def test_repr(self): + r = ReidentifyTextResponse(processed_text="John lives in NYC") + self.assertIn("ReidentifyTextResponse", repr(r)) + + def test_str(self): + r = ReidentifyTextResponse(processed_text="John") + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = ReidentifyTextResponse(processed_text="text") + self.assertIsNone(r.errors) + + +class TestFileInput(unittest.TestCase): + def test_repr_with_file(self): + bio = io.BytesIO(b"data") + bio.name = "test.txt" + fi = FileInput(file=bio) + self.assertIn("FileInput", repr(fi)) + + def test_str(self): + fi = FileInput(file_path="/some/path.pdf") + self.assertEqual(str(fi), repr(fi)) + + def test_repr_no_file(self): + fi = FileInput() + self.assertIn("FileInput", repr(fi)) + self.assertIsNone(fi.file) + self.assertIsNone(fi.file_path) + + +class TestDateTransformation(unittest.TestCase): + def test_instantiation(self): + dt = DateTransformation( + max_days=30, min_days=1, + entities=[DetectEntities.DATE] + ) + self.assertEqual(dt.max, 30) + self.assertEqual(dt.min, 1) + self.assertEqual(dt.entities, [DetectEntities.DATE]) + + +class TestTransformations(unittest.TestCase): + def test_instantiation(self): + dt = DateTransformation(max_days=30, min_days=1, entities=[DetectEntities.DATE]) + t = Transformations(shift_dates=dt) + self.assertEqual(t.shift_dates, dt) + + +class TestFile(unittest.TestCase): + def test_properties_with_file(self): + bio = io.BytesIO(b"hello") + bio.name = "test.txt" + f = File(file=bio) + self.assertEqual(f.name, "test.txt") + self.assertEqual(f.size, 5) + self.assertIsNotNone(f.type) + self.assertIsNotNone(f.last_modified) + + def test_properties_without_file(self): + f = File() + self.assertIsNone(f.name) + self.assertIsNone(f.size) + self.assertIsNone(f.type) + self.assertIsNone(f.last_modified) + + def test_seek_without_file(self): + f = File() + result = f.seek(0) + self.assertIsNone(result) + + def test_read_without_file(self): + f = File() + result = f.read() + self.assertIsNone(result) + + def test_seek_with_file(self): + bio = io.BytesIO(b"hello") + bio.name = "t.txt" + f = File(file=bio) + f.seek(0) + self.assertEqual(f.read(), b"hello") + + def test_repr(self): + bio = io.BytesIO(b"hi") + bio.name = "t.txt" + f = File(file=bio) + self.assertIn("File", repr(f)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/vault/tokens/__init__.py b/tests/vault/tokens/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/vault/tokens/test_responses.py b/tests/vault/tokens/test_responses.py new file mode 100644 index 00000000..62f217de --- /dev/null +++ b/tests/vault/tokens/test_responses.py @@ -0,0 +1,38 @@ +import unittest +from skyflow.vault.tokens._detokenize_response import DetokenizeResponse +from skyflow.vault.tokens._tokenize_response import TokenizeResponse + + +class TestDetokenizeResponse(unittest.TestCase): + def test_repr(self): + r = DetokenizeResponse(detokenized_fields=[{"token": "t1", "value": "v1"}], errors=None) + self.assertIn("DetokenizeResponse", repr(r)) + self.assertIn("t1", repr(r)) + + def test_str(self): + r = DetokenizeResponse(detokenized_fields=[{"token": "t1"}]) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = DetokenizeResponse() + self.assertIsNone(r.detokenized_fields) + self.assertIsNone(r.errors) + + +class TestTokenizeResponse(unittest.TestCase): + def test_repr(self): + r = TokenizeResponse(tokenized_fields=[{"value": "val", "token": "tok"}], errors=None) + self.assertIn("TokenizeResponse", repr(r)) + + def test_str(self): + r = TokenizeResponse(tokenized_fields=[{"token": "tok"}]) + self.assertEqual(str(r), repr(r)) + + def test_defaults(self): + r = TokenizeResponse() + self.assertIsNone(r.tokenized_fields) + self.assertIsNone(r.errors) + + +if __name__ == "__main__": + unittest.main()