Skip to content

Commit 99add1e

Browse files
committed
Crypto: Initial implementation
1 parent dfc09c2 commit 99add1e

3 files changed

Lines changed: 205 additions & 19 deletions

File tree

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,5 @@ def srtp_key() -> str:
5959
return 'RdHzuLLVGuO1aHILIEVJ1UzR7RWVioepmpy+9SRf'
6060

6161
@pytest.fixture(scope='session')
62-
def crypto_context(srtp_key: str) -> srtp_crypto.MsSrtpCrypto:
63-
return srtp_crypto.MsSrtpCrypto(srtp_key)
62+
def crypto_context(srtp_key: str) -> srtp_crypto.SrtpContext:
63+
return srtp_crypto.SrtpContext.from_base64(srtp_key)

tests/test_crypto.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
11
from xcloud.protocol import srtp_crypto
22

3-
def test_decrypt(test_data: dict, crypto_context: srtp_crypto.MsSrtpCrypto):
3+
def test_decrypt(test_data: dict, crypto_context: srtp_crypto.SrtpContext):
44
rtp_packet_raw = test_data['rtp_connection_probing.bin']
5-
plaintext = crypto_context.decrypt_raw(rtp_packet_raw)
5+
6+
rtp_header, rtp_body = rtp_packet_raw[:12], rtp_packet_raw[12:]
7+
plaintext = crypto_context.decrypt(rtp_body, aad=rtp_header)
68

79
print(plaintext)
810
assert plaintext is not None
11+
12+
def test_init_master_keys(srtp_key: str):
13+
from_base64 = srtp_crypto.SrtpMasterKeys.from_base64(srtp_key)
14+
null_keys = srtp_crypto.SrtpMasterKeys.null_keys()
15+
dummy_keys = srtp_crypto.SrtpMasterKeys.dummy_keys()
16+
17+
assert len(from_base64.key1_buf) == 0x10
18+
assert from_base64.key1_len == 0x10
19+
assert len(from_base64.key2_buf) == 0x0E
20+
assert from_base64.key2_len == 0x0E
21+
22+
assert null_keys is not None
23+
assert dummy_keys is not None

xcloud/protocol/srtp_crypto.py

Lines changed: 186 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,192 @@
11
import base64
2+
import struct
3+
from enum import Enum
4+
from typing import List, Optional
5+
from dataclasses import dataclass
6+
27
from aiortc.rtp import RtpPacket
38

4-
class MsSrtpCrypto:
5-
def __init__(self, master_key: str):
6-
try:
7-
master_key = base64.b64decode(master_key)
8-
except Exception:
9-
raise ValueError('Master key is not base64-decodable')
10-
11-
self.master_key = master_key
9+
from cryptography.hazmat.backends import default_backend
10+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
11+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
12+
13+
class TransformDirection(Enum):
14+
Encrypt = 0
15+
Decrypt = 1
16+
17+
class SrtpMasterKeys:
18+
MASTER_KEY_SIZE = 30
19+
DUMMY_KEY = (
20+
b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F'
21+
b'\x10\x11\x12\x13'
22+
)
23+
24+
def __init__(self, master_key: bytes):
25+
assert len(master_key) == SrtpMasterKeys.MASTER_KEY_SIZE
26+
self.key1_buf = master_key[:0x10]
27+
self.key1_len = len(self.key1_buf)
28+
self.key1_counter = 0
29+
30+
self.key2_buf = master_key[0x10:]
31+
self.key2_len = len(self.key2_buf)
32+
self.key2_counter = 0
33+
34+
@classmethod
35+
def from_base64(cls, master_key_b64: str):
36+
return cls(base64.b64decode(master_key_b64))
37+
38+
@classmethod
39+
def null_keys(cls):
40+
return cls(SrtpMasterKeys.MASTER_KEY_SIZE * b'\x00')
41+
42+
@classmethod
43+
def dummy_keys(cls):
44+
dummy_key = SrtpMasterKeys.DUMMY_KEY[:0x10] + SrtpMasterKeys.DUMMY_KEY[:0x0E]
45+
return cls(dummy_key)
46+
47+
@dataclass
48+
class SrtpSessionKey:
49+
buf: bytes
50+
len: int
51+
tag: bytes
52+
53+
def __init__(self, key: bytes):
54+
self.buf = key
55+
self.len = len(key)
56+
self.tag = 1
57+
58+
class SrtpSessionKeys:
59+
def __init__(self, session_keys: List[SrtpSessionKey]):
60+
assert len(session_keys) == 3
61+
self.session_key_1 = session_keys[0]
62+
self.session_key_2 = session_keys[1]
63+
self.session_key_3 = session_keys[2]
64+
65+
@property
66+
def aes_gcm_key(self) -> bytes:
67+
return self.session_key_1.buf
68+
69+
@property
70+
def nonce_key(self) -> bytes:
71+
return self.session_key_3.buf
72+
73+
class SrtpContext:
74+
_backend = default_backend()
75+
76+
def __init__(self, master_keys: SrtpMasterKeys):
77+
"""
78+
MS-SRTP context
79+
"""
80+
self.master_keys = master_keys
81+
self.session_keys = SrtpContext._derive_session_keys(
82+
self.master_keys.key1_buf, self.master_keys.key2_buf
83+
)
84+
85+
# Set-up GCM crypto instances
86+
self.decryptor_ctx = SrtpContext._init_gcm_cryptor(self.session_keys.aes_gcm_key)
87+
self.decryptor_ctx = SrtpContext._init_gcm_cryptor(self.session_keys.aes_gcm_key)
88+
89+
@classmethod
90+
def from_base64(cls, master_key_b64: str):
91+
return cls(
92+
SrtpMasterKeys.from_base64(master_key_b64)
93+
)
1294

13-
def decrypt(self, rtp_data: RtpPacket) -> RtpPacket:
14-
raise NotImplementedError('Decryption not implemented')
95+
@classmethod
96+
def from_bytes(cls, master_key: bytes):
97+
return cls(
98+
SrtpMasterKeys(master_key)
99+
)
100+
101+
@staticmethod
102+
def _derive_single_key(input_key: bytes, bitmask: int = 0) -> bytes:
103+
keysize = len(input_key)
104+
keyout = bytearray(b'\x00' * 16)
105+
106+
if keysize >= 14:
107+
keysize = 14
108+
109+
if keysize:
110+
keyout[13] = input_key[keysize - 1]
111+
if keysize != 1:
112+
keyout[12] = input_key[keysize - 2]
113+
if keysize >= 3:
114+
key_index = 0
115+
for _ in range(2, keysize):
116+
keyout[key_index + 11] = input_key[key_index + keysize - 3]
117+
key_index = key_index - 1
118+
119+
if keysize <= 13:
120+
null_count = 14 - keysize
121+
for i in range(0, null_count):
122+
keyout[i] = 0
123+
124+
for index in range(14, 16):
125+
keyout[index] = 0
126+
127+
if bitmask:
128+
len_before_xor = len(keyout)
129+
value_to_xor = struct.unpack_from('<I', keyout, 4)[0]
130+
value_to_xor ^= bitmask
131+
keyout = keyout[:4] + struct.pack('<I', value_to_xor) + keyout[8:]
132+
assert len(keyout) == len_before_xor
133+
return keyout
134+
135+
@staticmethod
136+
def _crypt_ctr_oneshot(key: bytes, iv: bytes, plaintext: bytes, max_bytes: Optional[int] = None):
137+
"""
138+
Encrypt data with AES-CTR (one-shot)
139+
"""
140+
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
141+
encryptor = cipher.encryptor()
142+
cipher_out = encryptor.update(plaintext) + encryptor.finalize()
143+
if max_bytes:
144+
# Trim to desired output
145+
cipher_out = cipher_out[:max_bytes]
146+
return cipher_out
147+
148+
@staticmethod
149+
def _derive_session_keys(key1: bytes, key2: bytes) -> SrtpSessionKeys:
150+
session1 = SrtpContext._derive_single_key(key2)
151+
session2 = SrtpContext._derive_single_key(key2, 0x1000000)
152+
session3 = SrtpContext._derive_single_key(key2, 0x2000000)
153+
154+
session1 = SrtpContext._crypt_ctr_oneshot(key1, session1, b'\x00' * 16)
155+
session2 = SrtpContext._crypt_ctr_oneshot(key1, session2, b'\x00' * 16)
156+
session3 = SrtpContext._crypt_ctr_oneshot(key1, session3, b'\x00' * 16, max_bytes=14)
157+
158+
return SrtpSessionKeys([
159+
SrtpSessionKey(session1),
160+
SrtpSessionKey(session2),
161+
SrtpSessionKey(session3)
162+
])
163+
164+
@staticmethod
165+
def _init_gcm_cryptor(key: bytes) -> AESGCM:
166+
return AESGCM(key)
167+
168+
@staticmethod
169+
def _decrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes:
170+
return ctx.decrypt(nonce, data, aad)
171+
172+
@staticmethod
173+
def _encrypt(ctx: AESGCM, nonce: bytes, data: bytes, aad: bytes) -> bytes:
174+
return ctx.encrypt(nonce, data, aad)
175+
176+
def _get_transformed_nonce(self, transform_direction: TransformDirection) -> bytes:
177+
# Skip first 2 bytes of Nonce key
178+
nonce = bytearray(self.session_keys.nonce_key[2:])
179+
# TODO: Implement transform logic
180+
# FIXME: Just tranforming the Nonce to a known value for
181+
# our single test packet
182+
nonce[-1] = nonce[-1] + 1
183+
184+
return nonce
15185

16-
def decrypt_raw(self, data: bytes) -> RtpPacket:
17-
packet = RtpPacket.parse(data)
18-
return self.decrypt(packet)
186+
def decrypt(self, data: bytes, aad: bytes) -> bytes:
187+
nonce = self._get_transformed_nonce(TransformDirection.Decrypt)
188+
return SrtpContext._decrypt(self.decryptor_ctx, nonce, data, aad)
19189

20-
def encrypt(self, rtp_data: RtpPacket) -> RtpPacket:
21-
raise NotImplementedError('Encryption not implemented')
190+
def encrypt(self, data: bytes, aad: bytes) -> RtpPacket:
191+
nonce = self._get_transformed_nonce(TransformDirection.Encrypt)
192+
return SrtpContext._encrypt(self.decryptor_ctx, nonce, data, aad)

0 commit comments

Comments
 (0)