|
| 1 | +""" |
| 2 | +RSA Key Utility Module |
| 3 | +
|
| 4 | +Provides utilities for managing RSA private keys used for encrypting secrets. |
| 5 | +Functions include: |
| 6 | +- Loading private keys from PEM format |
| 7 | +- Saving private keys with duplicate detection |
| 8 | +- Getting public keys from private keys |
| 9 | +- Listing existing keys in a folder |
| 10 | +- Checking for duplicate keys |
| 11 | +""" |
| 12 | + |
| 13 | +from pathlib import Path |
| 14 | +from typing import Optional, Tuple, List, Dict |
| 15 | +from cryptography.hazmat.primitives import serialization |
| 16 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 17 | +from cryptography.hazmat.backends import default_backend |
| 18 | + |
| 19 | + |
| 20 | +def load_private_key_from_pem(pem_content: str) -> Optional[rsa.RSAPrivateKey]: |
| 21 | + """ |
| 22 | + Load an RSA private key from PEM format content. |
| 23 | + |
| 24 | + Args: |
| 25 | + pem_content: String containing the PEM-encoded private key |
| 26 | + |
| 27 | + Returns: |
| 28 | + RSAPrivateKey object if successful, None otherwise |
| 29 | + """ |
| 30 | + try: |
| 31 | + private_key = serialization.load_pem_private_key( |
| 32 | + pem_content.encode('utf-8'), |
| 33 | + password=None, |
| 34 | + backend=default_backend() |
| 35 | + ) |
| 36 | + # Verify it's an RSA key |
| 37 | + if not isinstance(private_key, rsa.RSAPrivateKey): |
| 38 | + return None |
| 39 | + return private_key |
| 40 | + except Exception as e: |
| 41 | + print(f"Error loading private key from PEM: {e}") |
| 42 | + return None |
| 43 | + |
| 44 | + |
| 45 | +def get_public_key_pem(private_key: rsa.RSAPrivateKey) -> str: |
| 46 | + """ |
| 47 | + Extract the public key from a private key and return it as PEM string. |
| 48 | + |
| 49 | + Args: |
| 50 | + private_key: RSAPrivateKey object |
| 51 | + |
| 52 | + Returns: |
| 53 | + PEM-encoded public key as string |
| 54 | + """ |
| 55 | + public_key = private_key.public_key() |
| 56 | + public_key_pem = public_key.public_bytes( |
| 57 | + encoding=serialization.Encoding.PEM, |
| 58 | + format=serialization.PublicFormat.SubjectPublicKeyInfo |
| 59 | + ).decode('utf-8') |
| 60 | + return public_key_pem |
| 61 | + |
| 62 | + |
| 63 | +def check_duplicate_key(new_private_key: rsa.RSAPrivateKey, key_folder: Path) -> Optional[str]: |
| 64 | + """ |
| 65 | + Check if a private key already exists in the folder by comparing public keys. |
| 66 | + |
| 67 | + Args: |
| 68 | + new_private_key: The private key to check |
| 69 | + key_folder: Path to the folder containing existing keys |
| 70 | + |
| 71 | + Returns: |
| 72 | + Filename of the duplicate key if found, None otherwise |
| 73 | + """ |
| 74 | + if not key_folder.exists(): |
| 75 | + return None |
| 76 | + |
| 77 | + new_public_key_pem = get_public_key_pem(new_private_key) |
| 78 | + |
| 79 | + pem_files = list(key_folder.glob("*.pem")) |
| 80 | + for pem_file in pem_files: |
| 81 | + try: |
| 82 | + with open(pem_file, 'rb') as f: |
| 83 | + existing_private_key = serialization.load_pem_private_key( |
| 84 | + f.read(), |
| 85 | + password=None, |
| 86 | + backend=default_backend() |
| 87 | + ) |
| 88 | + # Verify it's an RSA key |
| 89 | + if not isinstance(existing_private_key, rsa.RSAPrivateKey): |
| 90 | + continue |
| 91 | + existing_public_key_pem = get_public_key_pem(existing_private_key) |
| 92 | + |
| 93 | + if existing_public_key_pem == new_public_key_pem: |
| 94 | + return pem_file.name |
| 95 | + except Exception as e: |
| 96 | + print(f"Warning: Could not check key {pem_file.name}: {e}") |
| 97 | + continue |
| 98 | + |
| 99 | + return None |
| 100 | + |
| 101 | + |
| 102 | +def save_private_key( |
| 103 | + private_key: rsa.RSAPrivateKey, |
| 104 | + key_folder: Path, |
| 105 | + filename: str, |
| 106 | + format_type: str = "pkcs8", |
| 107 | + check_duplicate: bool = True |
| 108 | +) -> Tuple[bool, Optional[str], Optional[Path]]: |
| 109 | + """ |
| 110 | + Save an RSA private key to a file with optional duplicate checking. |
| 111 | + |
| 112 | + Args: |
| 113 | + private_key: The RSAPrivateKey object to save |
| 114 | + key_folder: Path to the folder where the key should be saved |
| 115 | + filename: Name of the file (should end in .pem) |
| 116 | + format_type: Format for saving ("pkcs8" or "traditional_openssl") |
| 117 | + check_duplicate: Whether to check for duplicate keys before saving |
| 118 | + |
| 119 | + Returns: |
| 120 | + Tuple of (success: bool, duplicate_filename: Optional[str], saved_path: Optional[Path]) |
| 121 | + - If successful: (True, None, path_to_saved_file) |
| 122 | + - If duplicate found: (False, duplicate_filename, None) |
| 123 | + - If error: (False, None, None) |
| 124 | + """ |
| 125 | + try: |
| 126 | + # Ensure the folder exists |
| 127 | + key_folder.mkdir(parents=True, exist_ok=True) |
| 128 | + |
| 129 | + # Check for duplicates if requested |
| 130 | + if check_duplicate: |
| 131 | + duplicate = check_duplicate_key(private_key, key_folder) |
| 132 | + if duplicate: |
| 133 | + return False, duplicate, None |
| 134 | + |
| 135 | + # Determine the format |
| 136 | + if format_type == "pkcs8": |
| 137 | + format_obj = serialization.PrivateFormat.PKCS8 |
| 138 | + elif format_type == "traditional_openssl": |
| 139 | + format_obj = serialization.PrivateFormat.TraditionalOpenSSL |
| 140 | + else: |
| 141 | + print(f"Warning: Unknown format type '{format_type}', using PKCS8") |
| 142 | + format_obj = serialization.PrivateFormat.PKCS8 |
| 143 | + |
| 144 | + # Save the private key |
| 145 | + key_path = key_folder / filename |
| 146 | + with open(key_path, 'wb') as f: |
| 147 | + f.write(private_key.private_bytes( |
| 148 | + encoding=serialization.Encoding.PEM, |
| 149 | + format=format_obj, |
| 150 | + encryption_algorithm=serialization.NoEncryption() |
| 151 | + )) |
| 152 | + |
| 153 | + return True, None, key_path |
| 154 | + |
| 155 | + except Exception as e: |
| 156 | + print(f"Error saving private key: {e}") |
| 157 | + return False, None, None |
| 158 | + |
| 159 | + |
| 160 | +def list_existing_keys(key_folder: Path) -> List[Dict[str, str]]: |
| 161 | + """ |
| 162 | + List all existing RSA private keys in a folder and extract their public keys. |
| 163 | + |
| 164 | + Args: |
| 165 | + key_folder: Path to the folder containing .pem files |
| 166 | + |
| 167 | + Returns: |
| 168 | + List of dictionaries with keys: |
| 169 | + - 'filename': Name of the key file |
| 170 | + - 'path': Full path to the key file |
| 171 | + - 'public_key': PEM-encoded public key (if successfully loaded) |
| 172 | + - 'error': Error message (if failed to load) |
| 173 | + """ |
| 174 | + if not key_folder.exists(): |
| 175 | + return [] |
| 176 | + |
| 177 | + keys_info = [] |
| 178 | + pem_files = sorted(key_folder.glob("*.pem")) |
| 179 | + |
| 180 | + for pem_file in pem_files: |
| 181 | + key_info = { |
| 182 | + 'filename': pem_file.name, |
| 183 | + 'path': str(pem_file) |
| 184 | + } |
| 185 | + |
| 186 | + try: |
| 187 | + with open(pem_file, 'rb') as f: |
| 188 | + private_key = serialization.load_pem_private_key( |
| 189 | + f.read(), |
| 190 | + password=None, |
| 191 | + backend=default_backend() |
| 192 | + ) |
| 193 | + |
| 194 | + # Verify it's an RSA key |
| 195 | + if not isinstance(private_key, rsa.RSAPrivateKey): |
| 196 | + key_info['error'] = "Not an RSA private key" |
| 197 | + keys_info.append(key_info) |
| 198 | + continue |
| 199 | + |
| 200 | + # Extract public key |
| 201 | + public_key_pem = get_public_key_pem(private_key) |
| 202 | + key_info['public_key'] = public_key_pem |
| 203 | + |
| 204 | + except Exception as e: |
| 205 | + key_info['error'] = str(e) |
| 206 | + |
| 207 | + keys_info.append(key_info) |
| 208 | + |
| 209 | + return keys_info |
0 commit comments