Skip to content

Commit 70e4357

Browse files
committed
fix(hpke): require psk_id in protected header and update test vectors
1 parent 798dcb2 commit 70e4357

5 files changed

Lines changed: 151 additions & 148 deletions

File tree

cwt/cose.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,8 +699,7 @@ def _validate_cose_message(
699699
if not isinstance(v, (bytes, bytearray)):
700700
raise ValueError("ek (-4) must be bstr.")
701701
if k == -5: # psk_id
702-
if not isinstance(v, (bytes, bytearray)):
703-
raise ValueError("psk_id (-5) must be bstr.")
702+
raise ValueError("psk_id (-5) must be placed only in the protected header.")
704703
h[k] = v
705704
if len(h) != len(p) + len(u):
706705
raise ValueError("The same keys are both in protected and unprotected headers.")

cwt/recipient_algs/hpke.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def encode(self, plaintext: bytes = b"", aad: bytes = b"") -> Tuple[List[Any], O
102102
raise ValueError("recipient_key should be set in advance.")
103103
self._kem_key = self._to_kem_key(self._recipient_key)
104104
try:
105-
psk_id = self._unprotected.get(-5, None)
105+
# psk_id MUST be in the protected header (draft-ietf-cose-hpke)
106+
psk_id = self._protected.get(-5, None) if isinstance(self._protected, dict) else None
106107
if psk_id is not None and not isinstance(psk_id, (bytes, bytearray)):
107108
raise EncodeError("psk_id (-5) must be bstr.")
108109
if self._psk is not None and psk_id is None:
@@ -147,7 +148,8 @@ def decode(
147148
if not isinstance(ek, (bytes, bytearray)):
148149
raise DecodeError("ek (-4) must be bstr.")
149150
try:
150-
psk_id = self._unprotected.get(-5, None)
151+
# psk_id MUST be in the protected header (draft-ietf-cose-hpke)
152+
psk_id = self._protected.get(-5, None) if isinstance(self._protected, dict) else None
151153
if psk_id is not None and not isinstance(psk_id, (bytes, bytearray)):
152154
raise DecodeError("psk_id (-5) must be bstr.")
153155
if self._psk is not None and psk_id is None:

tests/test_cose_hpke.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ def test_cose_hpke_encrypt0_psk_id_wrong_type_header_validation(self):
186186
sender.encode_and_encrypt(
187187
b"This is the content.",
188188
rpk,
189-
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0},
190-
unprotected={COSEHeaders.KID: b"01", COSEHeaders.PSK_ID: 123},
189+
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0, COSEHeaders.PSK_ID: 123},
190+
unprotected={COSEHeaders.KID: b"01"},
191191
)
192192
assert "psk_id (-5) must be bstr." in str(err.value)
193193

@@ -205,8 +205,8 @@ def test_cose_hpke_encrypt0_with_psk_id_roundtrip(self):
205205
encoded = sender.encode_and_encrypt(
206206
b"This is the content.",
207207
rpk,
208-
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0},
209-
unprotected={COSEHeaders.KID: b"01", COSEHeaders.PSK_ID: b"psk-01"},
208+
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0, COSEHeaders.PSK_ID: b"psk-01"},
209+
unprotected={COSEHeaders.KID: b"01"},
210210
hpke_psk=b"secret-psk",
211211
)
212212

@@ -279,8 +279,8 @@ def test_cose_hpke_encrypt0_psk_id_without_psk_should_error_on_encode(self):
279279
sender.encode_and_encrypt(
280280
b"This is the content.",
281281
rpk,
282-
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0},
283-
unprotected={COSEHeaders.KID: b"01", COSEHeaders.PSK_ID: b"psk-01"},
282+
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0, COSEHeaders.PSK_ID: b"psk-01"},
283+
unprotected={COSEHeaders.KID: b"01"},
284284
)
285285
assert "hpke_psk is required when psk_id (-5) is provided." in str(err.value)
286286

@@ -295,7 +295,8 @@ def test_cose_hpke_encrypt0_psk_id_without_psk_should_error_on_decode(self):
295295
}
296296
)
297297
sender = COSE.new()
298-
# First, produce a base-mode (no psk_id) and then inject psk_id to simulate peer mismatch
298+
# First, produce a base-mode (no psk_id) and then inject psk_id into the
299+
# protected header to simulate peer mismatch
299300
encoded = sender.encode_and_encrypt(
300301
b"This is the content.",
301302
rpk,
@@ -304,8 +305,10 @@ def test_cose_hpke_encrypt0_psk_id_without_psk_should_error_on_decode(self):
304305
)
305306
tag = cbor2.loads(encoded)
306307
p, u, c = tag.value
307-
u[-5] = b"psk-01"
308-
tampered = cbor2.dumps(cbor2.CBORTag(16, [p, u, c]))
308+
p_map = cbor2.loads(p)
309+
p_map[-5] = b"psk-01"
310+
tampered_p = cbor2.dumps(p_map)
311+
tampered = cbor2.dumps(cbor2.CBORTag(16, [tampered_p, u, c]))
309312

310313
rsk = COSEKey.from_jwk(
311314
{
@@ -676,8 +679,8 @@ def test_cose_hpke_ke_with_psk_roundtrip(self):
676679
)
677680

678681
r = Recipient.new(
679-
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0_KE},
680-
unprotected={COSEHeaders.KID: b"01", COSEHeaders.PSK_ID: b"psk-01"},
682+
protected={COSEHeaders.ALG: COSEAlgs.HPKE_0_KE, COSEHeaders.PSK_ID: b"psk-01"},
683+
unprotected={COSEHeaders.KID: b"01"},
681684
recipient_key=rpk,
682685
hpke_psk=b"secret-psk",
683686
)

0 commit comments

Comments
 (0)