|
1 | 1 | import base64 |
| 2 | +import struct |
| 3 | +from enum import Enum |
| 4 | +from typing import List, Optional |
| 5 | +from dataclasses import dataclass |
| 6 | + |
2 | 7 | from aiortc.rtp import RtpPacket |
3 | 8 |
|
4 | | -class MsSrtpCrypto: |
5 | | - def __init__(self, master_key: str): |
6 | | - try: |
7 | | - master_key = base64.b64decode(master_key) |
8 | | - except Exception: |
9 | | - raise ValueError('Master key is not base64-decodable') |
10 | | - |
11 | | - self.master_key = master_key |
| 9 | +from cryptography.hazmat.backends import default_backend |
| 10 | +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes |
| 11 | +from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
| 12 | + |
| 13 | +class TransformDirection(Enum): |
| 14 | + Encrypt = 0 |
| 15 | + Decrypt = 1 |
| 16 | + |
| 17 | +class SrtpMasterKeys: |
| 18 | + MASTER_KEY_SIZE = 30 |
| 19 | + DUMMY_KEY = ( |
| 20 | + b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F' |
| 21 | + b'\x10\x11\x12\x13' |
| 22 | + ) |
| 23 | + |
| 24 | + def __init__(self, master_key: bytes): |
| 25 | + assert len(master_key) == SrtpMasterKeys.MASTER_KEY_SIZE |
| 26 | + self.key1_buf = master_key[:0x10] |
| 27 | + self.key1_len = len(self.key1_buf) |
| 28 | + self.key1_counter = 0 |
| 29 | + |
| 30 | + self.key2_buf = master_key[0x10:] |
| 31 | + self.key2_len = len(self.key2_buf) |
| 32 | + self.key2_counter = 0 |
| 33 | + |
| 34 | + @classmethod |
| 35 | + def from_base64(cls, master_key_b64: str): |
| 36 | + return cls(base64.b64decode(master_key_b64)) |
| 37 | + |
| 38 | + @classmethod |
| 39 | + def null_keys(cls): |
| 40 | + return cls(SrtpMasterKeys.MASTER_KEY_SIZE * b'\x00') |
| 41 | + |
| 42 | + @classmethod |
| 43 | + def dummy_keys(cls): |
| 44 | + dummy_key = SrtpMasterKeys.DUMMY_KEY[:0x10] + SrtpMasterKeys.DUMMY_KEY[:0x0E] |
| 45 | + return cls(dummy_key) |
| 46 | + |
| 47 | +@dataclass |
| 48 | +class SrtpSessionKey: |
| 49 | + buf: bytes |
| 50 | + len: int |
| 51 | + tag: bytes |
| 52 | + |
| 53 | + def __init__(self, key: bytes): |
| 54 | + self.buf = key |
| 55 | + self.len = len(key) |
| 56 | + self.tag = 1 |
| 57 | + |
| 58 | +class SrtpSessionKeys: |
| 59 | + def __init__(self, session_keys: List[SrtpSessionKey]): |
| 60 | + assert len(session_keys) == 3 |
| 61 | + self.session_key_1 = session_keys[0] |
| 62 | + self.session_key_2 = session_keys[1] |
| 63 | + self.session_key_3 = session_keys[2] |
| 64 | + |
| 65 | + @property |
| 66 | + def aes_gcm_key(self) -> bytes: |
| 67 | + return self.session_key_1.buf |
| 68 | + |
| 69 | + @property |
| 70 | + def nonce_key(self) -> bytes: |
| 71 | + return self.session_key_3.buf |
| 72 | + |
| 73 | +class SrtpContext: |
| 74 | + _backend = default_backend() |
| 75 | + |
| 76 | + def __init__(self, master_keys: SrtpMasterKeys): |
| 77 | + """ |
| 78 | + MS-SRTP context |
| 79 | + """ |
| 80 | + self.master_keys = master_keys |
| 81 | + self.session_keys = SrtpContext._derive_session_keys( |
| 82 | + self.master_keys.key1_buf, self.master_keys.key2_buf |
| 83 | + ) |
| 84 | + |
| 85 | + # Set-up GCM crypto instances |
| 86 | + self.decryptor_ctx = SrtpContext._init_gcm_cryptor(self.session_keys.aes_gcm_key) |
| 87 | + self.decryptor_ctx = SrtpContext._init_gcm_cryptor(self.session_keys.aes_gcm_key) |
| 88 | + |
| 89 | + @classmethod |
| 90 | + def from_base64(cls, master_key_b64: str): |
| 91 | + return cls( |
| 92 | + SrtpMasterKeys.from_base64(master_key_b64) |
| 93 | + ) |
12 | 94 |
|
13 | | - def decrypt(self, rtp_data: RtpPacket) -> RtpPacket: |
14 | | - raise NotImplementedError('Decryption not implemented') |
| 95 | + @classmethod |
| 96 | + def from_bytes(cls, master_key: bytes): |
| 97 | + return cls( |
| 98 | + SrtpMasterKeys(master_key) |
| 99 | + ) |
| 100 | + |
| 101 | + @staticmethod |
| 102 | + def _derive_single_key(input_key: bytes, bitmask: int = 0) -> bytes: |
| 103 | + keysize = len(input_key) |
| 104 | + keyout = bytearray(b'\x00' * 16) |
| 105 | + |
| 106 | + if keysize >= 14: |
| 107 | + keysize = 14 |
| 108 | + |
| 109 | + if keysize: |
| 110 | + keyout[13] = input_key[keysize - 1] |
| 111 | + if keysize != 1: |
| 112 | + keyout[12] = input_key[keysize - 2] |
| 113 | + if keysize >= 3: |
| 114 | + key_index = 0 |
| 115 | + for _ in range(2, keysize): |
| 116 | + keyout[key_index + 11] = input_key[key_index + keysize - 3] |
| 117 | + key_index = key_index - 1 |
| 118 | + |
| 119 | + if keysize <= 13: |
| 120 | + null_count = 14 - keysize |
| 121 | + for i in range(0, null_count): |
| 122 | + keyout[i] = 0 |
| 123 | + |
| 124 | + for index in range(14, 16): |
| 125 | + keyout[index] = 0 |
| 126 | + |
| 127 | + if bitmask: |
| 128 | + len_before_xor = len(keyout) |
| 129 | + value_to_xor = struct.unpack_from('<I', keyout, 4)[0] |
| 130 | + value_to_xor ^= bitmask |
| 131 | + keyout = keyout[:4] + struct.pack('<I', value_to_xor) + keyout[8:] |
| 132 | + assert len(keyout) == len_before_xor |
| 133 | + return keyout |
| 134 | + |
| 135 | + @staticmethod |
| 136 | + def _crypt_ctr_oneshot(key: bytes, iv: bytes, plaintext: bytes, max_bytes: Optional[int] = None): |
| 137 | + """ |
| 138 | + Encrypt data with AES-CTR (one-shot) |
| 139 | + """ |
| 140 | + cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) |
| 141 | + encryptor = cipher.encryptor() |
| 142 | + cipher_out = encryptor.update(plaintext) + encryptor.finalize() |
| 143 | + if max_bytes: |
| 144 | + # Trim to desired output |
| 145 | + cipher_out = cipher_out[:max_bytes] |
| 146 | + return cipher_out |
| 147 | + |
| 148 | + @staticmethod |
| 149 | + def _derive_session_keys(key1: bytes, key2: bytes) -> SrtpSessionKeys: |
| 150 | + session1 = SrtpContext._derive_single_key(key2) |
| 151 | + session2 = SrtpContext._derive_single_key(key2, 0x1000000) |
| 152 | + session3 = SrtpContext._derive_single_key(key2, 0x2000000) |
| 153 | + |
| 154 | + session1 = SrtpContext._crypt_ctr_oneshot(key1, session1, b'\x00' * 16) |
| 155 | + session2 = SrtpContext._crypt_ctr_oneshot(key1, session2, b'\x00' * 16) |
| 156 | + session3 = SrtpContext._crypt_ctr_oneshot(key1, session3, b'\x00' * 16, max_bytes=14) |
| 157 | + |
| 158 | + return SrtpSessionKeys([ |
| 159 | + SrtpSessionKey(session1), |
| 160 | + SrtpSessionKey(session2), |
| 161 | + SrtpSessionKey(session3) |
| 162 | + ]) |
| 163 | + |
| 164 | + @staticmethod |
| 165 | + def _init_gcm_cryptor(key: bytes) -> AESGCM: |
| 166 | + return AESGCM(key) |
| 167 | + |
| 168 | + @staticmethod |
| 169 | + def _decrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes: |
| 170 | + return ctx.decrypt(nonce, data, aad) |
| 171 | + |
| 172 | + @staticmethod |
| 173 | + def _encrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes: |
| 174 | + return ctx.encrypt(nonce, data, aad) |
| 175 | + |
| 176 | + def _get_transformed_nonce(self, transform_direction: TransformDirection) -> bytes: |
| 177 | + # Skip first 2 bytes of Nonce key |
| 178 | + nonce = bytearray(self.session_keys.nonce_key[2:]) |
| 179 | + # TODO: Implement transform logic |
| 180 | + # FIXME: Just tranforming the Nonce to a known value for |
| 181 | + # our single test packet |
| 182 | + nonce[-1] = nonce[-1] + 1 |
| 183 | + |
| 184 | + return nonce |
15 | 185 |
|
16 | | - def decrypt_raw(self, data: bytes) -> RtpPacket: |
17 | | - packet = RtpPacket.parse(data) |
18 | | - return self.decrypt(packet) |
| 186 | + def decrypt(self, data: bytes, aad: bytes) -> bytes: |
| 187 | + nonce = self._get_transformed_nonce(TransformDirection.Decrypt) |
| 188 | + return SrtpContext._decrypt(self.decryptor_ctx, nonce, data, aad) |
19 | 189 |
|
20 | | - def encrypt(self, rtp_data: RtpPacket) -> RtpPacket: |
21 | | - raise NotImplementedError('Encryption not implemented') |
| 190 | + def encrypt(self, data: bytes, aad: bytes) -> RtpPacket: |
| 191 | + nonce = self._get_transformed_nonce(TransformDirection.Encrypt) |
| 192 | + return SrtpContext._encrypt(self.decryptor_ctx, nonce, data, aad) |
0 commit comments