|
9 | 9 | from ..const import COSE_KEY_OPERATION_VALUES |
10 | 10 | from ..cose_key_interface import COSEKeyInterface |
11 | 11 | from ..exceptions import DecodeError, EncodeError, VerifyError |
| 12 | +from .non_aead import AESCBC, AESCTR |
12 | 13 |
|
13 | 14 | _CWT_DEFAULT_KEY_SIZE_HMAC256 = 32 # bytes |
14 | 15 | _CWT_DEFAULT_KEY_SIZE_HMAC384 = 48 |
15 | 16 | _CWT_DEFAULT_KEY_SIZE_HMAC512 = 64 |
16 | 17 | _CWT_NONCE_SIZE_AESGCM = 12 |
17 | 18 | _CWT_NONCE_SIZE_CHACHA20_POLY1305 = 12 |
| 19 | +_CWT_NONCE_SIZE_AES = 16 |
18 | 20 |
|
19 | 21 |
|
20 | 22 | class SymmetricKey(COSEKeyInterface): |
@@ -353,3 +355,107 @@ def unwrap_key(self, wrapped_key: bytes) -> bytes: |
353 | 355 | return aes_key_unwrap(self._key, wrapped_key) |
354 | 356 | except Exception as err: |
355 | 357 | raise DecodeError("Failed to unwrap key.") from err |
| 358 | + |
| 359 | + |
| 360 | +class AESCTRKey(ContentEncryptionKey): |
| 361 | + """ """ |
| 362 | + |
| 363 | + def __init__(self, params: Dict[int, Any]): |
| 364 | + """ """ |
| 365 | + super().__init__(params) |
| 366 | + |
| 367 | + self._cipher: AESCTR |
| 368 | + |
| 369 | + # Validate alg. |
| 370 | + if self._alg == -65534: # A128CTR |
| 371 | + if not self._key: |
| 372 | + self._key = AESCTR.generate_key(bit_length=128) |
| 373 | + if len(self._key) != 16: |
| 374 | + raise ValueError("The length of A128CTR key should be 16 bytes.") |
| 375 | + elif self._alg == -65533: # A192CTR |
| 376 | + if not self._key: |
| 377 | + self._key = AESCTR.generate_key(bit_length=192) |
| 378 | + if len(self._key) != 24: |
| 379 | + raise ValueError("The length of A192CTR key should be 24 bytes.") |
| 380 | + elif self._alg == -65532: # A256CTR |
| 381 | + if not self._key: |
| 382 | + self._key = AESCTR.generate_key(bit_length=256) |
| 383 | + if len(self._key) != 32: |
| 384 | + raise ValueError("The length of A256CTR key should be 32 bytes.") |
| 385 | + else: |
| 386 | + raise ValueError(f"Unsupported or unknown alg(3) for AES CTR: {self._alg}.") |
| 387 | + |
| 388 | + self._cipher = AESCTR(self._key) |
| 389 | + return |
| 390 | + |
| 391 | + def generate_nonce(self): |
| 392 | + return token_bytes(_CWT_NONCE_SIZE_AES) |
| 393 | + |
| 394 | + def encrypt(self, msg: bytes, nonce: bytes, aad: Optional[bytes] = None) -> bytes: |
| 395 | + """ """ |
| 396 | + try: |
| 397 | + return self._cipher.encrypt(nonce, msg) |
| 398 | + except Exception as err: |
| 399 | + raise EncodeError("Failed to encrypt.") from err |
| 400 | + |
| 401 | + def decrypt(self, msg: bytes, nonce: bytes, aad: Optional[bytes] = None) -> bytes: |
| 402 | + """ """ |
| 403 | + try: |
| 404 | + return self._cipher.decrypt(nonce, msg) |
| 405 | + except Exception as err: |
| 406 | + raise DecodeError("Failed to decrypt.") from err |
| 407 | + |
| 408 | + |
| 409 | +class AESCBCKey(ContentEncryptionKey): |
| 410 | + """ """ |
| 411 | + |
| 412 | + def __init__(self, params: Dict[int, Any]): |
| 413 | + """ """ |
| 414 | + super().__init__(params) |
| 415 | + |
| 416 | + self._cipher: AESCBC |
| 417 | + |
| 418 | + # Validate alg. |
| 419 | + if self._alg == -65531: # A128CBC |
| 420 | + if not self._key: |
| 421 | + self._key = AESCBC.generate_key(bit_length=128) |
| 422 | + if len(self._key) != 16: |
| 423 | + raise ValueError("The length of A128CBC key should be 16 bytes.") |
| 424 | + elif self._alg == -65530: # A192CBC |
| 425 | + if not self._key: |
| 426 | + self._key = AESCBC.generate_key(bit_length=192) |
| 427 | + if len(self._key) != 24: |
| 428 | + raise ValueError("The length of A192CBC key should be 24 bytes.") |
| 429 | + elif self._alg == -65529: # A256CBC |
| 430 | + if not self._key: |
| 431 | + self._key = AESCBC.generate_key(bit_length=256) |
| 432 | + if len(self._key) != 32: |
| 433 | + raise ValueError("The length of A256CBC key should be 32 bytes.") |
| 434 | + else: |
| 435 | + raise ValueError(f"Unsupported or unknown alg(3) for AES CBC: {self._alg}.") |
| 436 | + |
| 437 | + self._cipher = AESCBC(self._key) |
| 438 | + return |
| 439 | + |
| 440 | + def generate_nonce(self): |
| 441 | + return token_bytes(_CWT_NONCE_SIZE_AES) |
| 442 | + |
| 443 | + def encrypt(self, msg: bytes, nonce: bytes, aad: Optional[bytes] = None) -> bytes: |
| 444 | + """ """ |
| 445 | + try: |
| 446 | + # Add padding (see RFC 9459 and 5652) |
| 447 | + padding_value = 16 - len(msg) % 16 |
| 448 | + padding_length = 16 if padding_value == 0 else padding_value |
| 449 | + padding = (padding_value).to_bytes(1, "big") * padding_length |
| 450 | + return self._cipher.encrypt(nonce, msg + padding) |
| 451 | + except Exception as err: |
| 452 | + raise EncodeError("Failed to encrypt.") from err |
| 453 | + |
| 454 | + def decrypt(self, msg: bytes, nonce: bytes, aad: Optional[bytes] = None) -> bytes: |
| 455 | + """ """ |
| 456 | + try: |
| 457 | + decrypted = self._cipher.decrypt(nonce, msg) |
| 458 | + # Remove padding (see RFC 9459 and 5652) |
| 459 | + return decrypted[0 : -(decrypted[-1])] |
| 460 | + except Exception as err: |
| 461 | + raise DecodeError("Failed to decrypt.") from err |
0 commit comments