Skip to content

Commit 9fac18f

Browse files
committed
Remove derive_key from COSEKey interface.
1 parent 906a6d6 commit 9fac18f

10 files changed

Lines changed: 167 additions & 286 deletions

cwt/algs/ec2.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Any, Dict, Optional, Union
22

33
import cryptography
44
from cryptography.hazmat.primitives import hashes
@@ -18,15 +18,12 @@
1818
COSE_ALGORITHMS_CKDM_KEY_AGREEMENT_ES,
1919
COSE_ALGORITHMS_HPKE,
2020
COSE_ALGORITHMS_SIG_EC2,
21-
COSE_KEY_LEN,
2221
COSE_KEY_OPERATION_VALUES,
2322
COSE_KEY_TYPES,
2423
)
25-
from ..cose_key_interface import COSEKeyInterface
2624
from ..exceptions import EncodeError, VerifyError
27-
from ..utils import i2osp, os2ip, to_cis
25+
from ..utils import i2osp, os2ip
2826
from .asymmetric import AsymmetricKey
29-
from .symmetric import AESCCMKey, AESGCMKey, ChaCha20Key, HMACKey
3027

3128

3229
class EC2Key(AsymmetricKey):
@@ -273,13 +270,7 @@ def verify(self, msg: bytes, sig: bytes):
273270
except ValueError as err:
274271
raise VerifyError("Invalid signature.") from err
275272

276-
def derive_bytes(
277-
self,
278-
length: int,
279-
material: bytes = b"",
280-
info: bytes = b"",
281-
public_key: Optional[Any] = None,
282-
) -> bytes:
273+
def derive_bytes(self, length: int, material: bytes = b"", info: bytes = b"", public_key: Optional[Any] = None) -> bytes:
283274

284275
if self._public_key:
285276
raise ValueError("Public key cannot be used for key derivation.")
@@ -304,52 +295,6 @@ def derive_bytes(
304295
except Exception as err:
305296
raise EncodeError("Failed to derive bytes.") from err
306297

307-
def derive_key(
308-
self,
309-
context: Union[List[Any], Dict[str, Any]],
310-
material: bytes = b"",
311-
public_key: Optional[COSEKeyInterface] = None,
312-
) -> COSEKeyInterface:
313-
314-
if self._public_key:
315-
raise ValueError("Public key cannot be used for key derivation.")
316-
if not public_key:
317-
raise ValueError("public_key should be set.")
318-
if not isinstance(public_key.key, EllipticCurvePublicKey):
319-
raise ValueError("public_key should be elliptic curve public key.")
320-
if self._alg not in COSE_ALGORITHMS_CKDM_KEY_AGREEMENT.values():
321-
raise ValueError(f"Invalid alg for key derivation: {self._alg}.")
322-
323-
# Validate context information.
324-
if isinstance(context, dict):
325-
context = to_cis(context, self._alg)
326-
else:
327-
self._validate_context(context)
328-
329-
# Derive key.
330-
self._key = self._private_key if self._private_key else ec.generate_private_key(self._crv_obj)
331-
shared_key = self._key.exchange(ec.ECDH(), public_key.key)
332-
hkdf = HKDF(
333-
algorithm=self._hash_alg(),
334-
length=COSE_KEY_LEN[context[0]] // 8,
335-
salt=None,
336-
info=self._dumps(context),
337-
)
338-
# return COSEKey.from_symmetric_key(hkdf.derive(shared_key), alg=context[0])
339-
cose_key = {
340-
1: 4,
341-
3: context[0],
342-
-1: hkdf.derive(shared_key),
343-
}
344-
if cose_key[3] in [1, 2, 3]:
345-
return AESGCMKey(cose_key)
346-
if cose_key[3] in [4, 5, 6, 7]:
347-
return HMACKey(cose_key)
348-
if cose_key[3] in [10, 11, 12, 13, 30, 31, 32, 33]:
349-
return AESCCMKey(cose_key)
350-
# cose_key[3] == 24:
351-
return ChaCha20Key(cose_key)
352-
353298
def _der_to_os(self, key_size: int, sig: bytes) -> bytes:
354299
num_bytes = (key_size + 7) // 8
355300
r, s = decode_dss_signature(sig)

cwt/algs/okp.py

Lines changed: 4 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Any, Dict, Optional, Union
22

33
import cryptography
44
from cryptography.hazmat.primitives import hashes
@@ -23,20 +23,16 @@
2323
PublicFormat,
2424
)
2525

26-
from ..const import ( # COSE_KEY_LEN,
26+
from ..const import (
2727
COSE_ALGORITHMS_CKDM_KEY_AGREEMENT,
2828
COSE_ALGORITHMS_CKDM_KEY_AGREEMENT_ES,
2929
COSE_ALGORITHMS_HPKE,
3030
COSE_ALGORITHMS_SIG_OKP,
31-
COSE_KEY_LEN,
3231
COSE_KEY_OPERATION_VALUES,
3332
COSE_KEY_TYPES,
3433
)
35-
from ..cose_key_interface import COSEKeyInterface
3634
from ..exceptions import EncodeError, VerifyError
37-
from ..utils import to_cis
3835
from .asymmetric import AsymmetricKey
39-
from .symmetric import AESCCMKey, AESGCMKey, ChaCha20Key, HMACKey
4036

4137

4238
class OKPKey(AsymmetricKey):
@@ -273,13 +269,7 @@ def verify(self, msg: bytes, sig: bytes):
273269
except cryptography.exceptions.InvalidSignature as err:
274270
raise VerifyError("Failed to verify.") from err
275271

276-
def derive_bytes(
277-
self,
278-
length: int,
279-
material: bytes = b"",
280-
info: bytes = b"",
281-
public_key: Optional[Any] = None,
282-
) -> bytes:
272+
def derive_bytes(self, length: int, material: bytes = b"", info: bytes = b"", public_key: Optional[Any] = None) -> bytes:
283273

284274
if self._public_key:
285275
raise ValueError("Public key cannot be used for key derivation.")
@@ -297,60 +287,7 @@ def derive_bytes(
297287
else:
298288
self._key = X25519PrivateKey.generate() if self._crv == 4 else X448PrivateKey.generate()
299289
shared_key = self._key.exchange(public_key.key)
300-
hkdf = HKDF(
301-
algorithm=self._hash_alg(),
302-
length=length,
303-
salt=None,
304-
info=info,
305-
)
290+
hkdf = HKDF(algorithm=self._hash_alg(), length=length, salt=None, info=info)
306291
return hkdf.derive(shared_key)
307292
except Exception as err:
308293
raise EncodeError("Failed to derive bytes.") from err
309-
310-
def derive_key(
311-
self,
312-
context: Union[List[Any], Dict[str, Any]],
313-
material: bytes = b"",
314-
public_key: Optional[COSEKeyInterface] = None,
315-
) -> COSEKeyInterface:
316-
317-
if self._public_key:
318-
raise ValueError("Public key cannot be used for key derivation.")
319-
if not public_key:
320-
raise ValueError("public_key should be set.")
321-
if not isinstance(public_key.key, X25519PublicKey) and not isinstance(public_key.key, X448PublicKey):
322-
raise ValueError("public_key should be x25519/x448 public key.")
323-
# if self._alg not in COSE_ALGORITHMS_CKDM_KEY_AGREEMENT.values():
324-
# raise ValueError(f"Invalid alg for key derivation: {self._alg}.")
325-
326-
# Validate context information.
327-
if isinstance(context, dict):
328-
context = to_cis(context, self._alg)
329-
else:
330-
self._validate_context(context)
331-
332-
# Derive key.
333-
if self._private_key:
334-
self._key = self._private_key
335-
else:
336-
self._key = X25519PrivateKey.generate() if self._crv == 4 else X448PrivateKey.generate()
337-
shared_key = self._key.exchange(public_key.key)
338-
hkdf = HKDF(
339-
algorithm=self._hash_alg(),
340-
length=COSE_KEY_LEN[context[0]] // 8,
341-
salt=None,
342-
info=self._dumps(context),
343-
)
344-
cose_key = {
345-
1: 4,
346-
3: context[0],
347-
-1: hkdf.derive(shared_key),
348-
}
349-
if cose_key[3] in [1, 2, 3]:
350-
return AESGCMKey(cose_key)
351-
if cose_key[3] in [4, 5, 6, 7]:
352-
return HMACKey(cose_key)
353-
if cose_key[3] in [10, 11, 12, 13, 30, 31, 32, 33]:
354-
return AESCCMKey(cose_key)
355-
# cose_key[3] == 24:
356-
return ChaCha20Key(cose_key)

cwt/cose_key_interface.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def derive_bytes(
277277
public_key: Optional[Any] = None,
278278
) -> bytes:
279279
"""
280-
Derives a key with a key material or key exchange.
280+
Derives a byte string with a key material or key exchange.
281281
282282
Args:
283283
length (int): The length of derived byte string.
@@ -292,26 +292,3 @@ def derive_bytes(
292292
EncodeError: Failed to derive key.
293293
"""
294294
raise NotImplementedError
295-
296-
def derive_key(
297-
self,
298-
context: Union[List[Any], Dict[str, Any]],
299-
material: bytes = b"",
300-
public_key: Optional[Any] = None,
301-
) -> Any:
302-
"""
303-
Derives a key with a key material or key exchange.
304-
305-
Args:
306-
context (Union[List[Any], Dict[str, Any]]): Context information structure for
307-
key derivation functions.
308-
material (bytes): A key material as bytes.
309-
public_key: A public key for key derivation with key exchange.
310-
Returns:
311-
COSEKeyInterface: A COSE key derived.
312-
Raises:
313-
NotImplementedError: Not implemented.
314-
ValueError: Invalid arguments.
315-
EncodeError: Failed to derive key.
316-
"""
317-
raise NotImplementedError

cwt/recipient_algs/direct_hkdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def encode(self, plaintext: bytes = b"", aad: bytes = b"") -> Tuple[List[Any], O
8080
info=self._dumps(self._context),
8181
)
8282
derived = hkdf.derive(plaintext)
83-
return self.to_list(), COSEKey.from_symmetric_key(derived, self._context[0], self._kid)
83+
return self.to_list(), COSEKey.from_symmetric_key(derived, self._context[0], kid=self._kid)
8484
except Exception as err:
8585
raise EncodeError("Failed to derive key.") from err
8686

cwt/recipient_algs/ecdh_aes_key_wrap.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from cryptography.hazmat.primitives.keywrap import aes_key_unwrap, aes_key_wrap
44

55
from ..algs.ec2 import EC2Key
6-
from ..const import COSE_KEY_OPERATION_VALUES
6+
from ..const import COSE_KEY_LEN, COSE_KEY_OPERATION_VALUES
77
from ..cose_key import COSEKey
88
from ..cose_key_interface import COSEKeyInterface
99
from ..exceptions import DecodeError, EncodeError
@@ -56,7 +56,13 @@ def encode(self, plaintext: bytes = b"", aad: bytes = b"") -> Tuple[List[Any], O
5656
# ECDH-SS (alg=-32, -33, -34)
5757
if not self._sender_key:
5858
raise ValueError("sender_key should be set in advance.")
59-
wrapping_key = self._sender_key.derive_key(self._context, public_key=self._recipient_key)
59+
wrapping_bytes = self._sender_key.derive_bytes(
60+
COSE_KEY_LEN[self._context[0]] // 8,
61+
info=self._dumps(self._context),
62+
public_key=self._recipient_key,
63+
)
64+
wrapping_key = COSEKey.from_symmetric_key(wrapping_bytes, alg=self._context[0])
65+
# wrapping_key = self._sender_key.derive_key(self._context, public_key=self._recipient_key)
6066
if self._alg in [-29, -30, -31]:
6167
# ECDH-ES
6268
self._unprotected[-1] = self._to_cose_key(self._sender_key.key.public_key())
@@ -78,8 +84,14 @@ def decode(
7884
raise ValueError("sender_public_key should be set.")
7985

8086
try:
81-
derived = key.derive_key(self._context, public_key=self._sender_public_key)
82-
derived_bytes = aes_key_unwrap(derived.key, self._ciphertext)
87+
wrapping_key_bytes = key.derive_bytes(
88+
COSE_KEY_LEN[self._context[0]] // 8,
89+
info=self._dumps(self._context),
90+
public_key=self._sender_public_key,
91+
)
92+
wrapping_key = COSEKey.from_symmetric_key(wrapping_key_bytes, alg=self._context[0])
93+
# derived = key.derive_key(self._context, public_key=self._sender_public_key)
94+
derived_bytes = aes_key_unwrap(wrapping_key.key, self._ciphertext)
8395
except Exception as err:
8496
raise DecodeError("Failed to decode key.") from err
8597
if not as_cose_key:

cwt/recipient_algs/ecdh_direct_hkdf.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ def encode(self, plaintext: bytes = b"", aad: bytes = b"") -> Tuple[List[Any], O
7878
if not self._sender_key:
7979
raise ValueError("sender_key should be set in advance.")
8080

81-
derived_key = self._sender_key.derive_key(self._context, public_key=self._recipient_key)
81+
derived_bytes = self._sender_key.derive_bytes(
82+
COSE_KEY_LEN[self._context[0]] // 8,
83+
info=self._dumps(self._context),
84+
public_key=self._recipient_key,
85+
)
86+
derived_key = COSEKey.from_symmetric_key(derived_bytes, alg=self._context[0])
8287
if self._alg in [-25, -26]:
8388
# ECDH-ES
8489
self._unprotected[-1] = self._to_cose_key(self._sender_key.key.public_key())
@@ -97,12 +102,13 @@ def decode(
97102
if not self._sender_public_key:
98103
raise ValueError("sender_public_key should be set.")
99104
try:
100-
if not as_cose_key:
101-
return key.derive_bytes(
102-
COSE_KEY_LEN[self._context[0]] // 8,
103-
info=self._dumps(self._context),
104-
public_key=self._sender_public_key,
105-
)
106-
return key.derive_key(self._context, public_key=self._sender_public_key)
105+
derived_bytes = key.derive_bytes(
106+
COSE_KEY_LEN[self._context[0]] // 8,
107+
info=self._dumps(self._context),
108+
public_key=self._sender_public_key,
109+
)
107110
except Exception as err:
108111
raise DecodeError("Failed to decode.") from err
112+
if not as_cose_key:
113+
return derived_bytes
114+
return COSEKey.from_symmetric_key(derived_bytes, alg=self._context[0], kid=self._kid)

0 commit comments

Comments
 (0)