Skip to content

Commit d8a0816

Browse files
committed
New test cases, minor refactor, more logging
1 parent 6b3f058 commit d8a0816

14 files changed

Lines changed: 248 additions & 138 deletions

File tree

client/client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from common import utils
1212
from common.logging import LOGGER
1313
from common.user_key import UserKey
14+
from common.utils import key_id_str_to_uuid
1415
from .peer_hub import PeerHub
1516

1617

@@ -125,15 +126,12 @@ async def etsi_get_key(
125126
}
126127

127128
async def etsi_get_key_with_key_ids(
128-
self, master_sae_id: str, slave_sae_id: str, key_id: str
129+
self, master_sae_id: str, slave_sae_id: str, key_id_str: str
129130
):
130131
"""
131132
ETSI QKD 014 V1.1.1 Get key with key IDs API.
132133
"""
133-
try:
134-
key_id = UUID(key_id)
135-
except ValueError as exc:
136-
raise exceptions.InvalidKeyIDError(key_id) from exc
134+
key_id = key_id_str_to_uuid(key_id_str)
137135
key = await self.gather_key_from_peer_hubs(master_sae_id, slave_sae_id, key_id)
138136
return {
139137
"keys": [

client/peer_hub.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from common.block import APIBlock, Block
1010
from common.encryption_key import EncryptionKey
1111
from common.logging import LOGGER
12+
from common.owner import Owner
1213
from common.pool import Pool
1314
from common.registration_api import (
1415
APIPutRegistrationRequest,
@@ -48,8 +49,8 @@ def __init__(self, client, base_url):
4849
self._base_url = self._base_url[:-1]
4950
self._registered = False
5051
hub_name = base_url.split("/")[-1]
51-
self._local_pool = Pool(hub_name, Pool.Owner.LOCAL)
52-
self._peer_pool = Pool(hub_name, Pool.Owner.PEER)
52+
self._local_pool = Pool(hub_name, Owner.LOCAL)
53+
self._peer_pool = Pool(hub_name, Owner.PEER)
5354
self._register_task = None
5455
self._local_pool_request_psrd_task = None
5556
self._peer_pool_request_psrd_task = None
@@ -165,9 +166,9 @@ async def request_psrd_task(self, pool: Pool) -> None:
165166
LOGGER.info(f"Finish {task_name}")
166167
finally:
167168
match pool.owner:
168-
case Pool.Owner.LOCAL:
169+
case Owner.LOCAL:
169170
self._local_pool_request_psrd_task = None
170-
case Pool.Owner.PEER:
171+
case Owner.PEER:
171172
self._peer_pool_request_psrd_task = None
172173

173174
async def attempt_request_psrd(self, pool: Pool) -> bool:
@@ -177,11 +178,7 @@ async def attempt_request_psrd(self, pool: Pool) -> bool:
177178
"""
178179
assert self._registered
179180
url = f"{self._base_url}/dske/oob/v1/psrd"
180-
match pool.owner:
181-
case Pool.Owner.LOCAL:
182-
owner_str = "client"
183-
case Pool.Owner.PEER:
184-
owner_str = "hub"
181+
owner_str = pool.owner.to_str(local_name="client", peer_name="hub")
185182
params = {
186183
"client_name": self._client.name,
187184
"owner": owner_str,

common/block.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pydantic
99
from bitarray import bitarray
1010
from common.fragment import Fragment
11+
from common.logging import LOGGER
1112
from common.utils import bytes_to_str, str_to_bytes
1213
from common.exceptions import (
1314
InvalidBlockUUIDError,
@@ -147,9 +148,14 @@ def take_data(self, start: int, size: int) -> bytes:
147148
Take data from the block at the specified start byte and size.
148149
"""
149150
end = start + size
150-
if start < 0 or end > self._size:
151+
if start < 0:
152+
LOGGER.error(f"Take data from blocK: invalid start index {start}")
151153
raise InvalidPSRDIndex(self._block_uuid, start)
154+
if end > self._size:
155+
LOGGER.error(f"Take data from block: invalid end index {end}")
156+
raise InvalidPSRDIndex(self._block_uuid, end)
152157
if self._used[start:end].any():
158+
LOGGER.error(f"Take data from block: already in use {start}:{end}")
153159
raise PSRDDataAlreadyUsedError(self._block_uuid, start, size)
154160
self._used[start:end] = True
155161
data = self._data[start:end]
@@ -185,10 +191,12 @@ def from_api(cls, api_block: APIBlock) -> "Block":
185191
try:
186192
block_uuid = UUID(api_block.block_uuid)
187193
except ValueError as exc:
194+
LOGGER.error(f"Invalid block UUID in API block: {api_block.block_uuid}")
188195
raise InvalidBlockUUIDError(api_block.block_uuid) from exc
189196
try:
190197
data = str_to_bytes(api_block.data)
191198
except Exception as exc:
199+
LOGGER.error(f"Invalid block data in API block: {api_block.data}")
192200
raise InvalidPSRDDataError from exc
193201
return Block(block_uuid, data)
194202

common/exceptions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -353,15 +353,15 @@ def __init__(self, encoded_signing_key: str):
353353
)
354354

355355

356-
class EncryptorNotRegisteredForClientError(DSKEException):
356+
class EncryptorNotConnectedToClientError(DSKEException):
357357
"""
358358
Exception raised when an encryptor is not registered for a client.
359359
"""
360360

361361
def __init__(self, client_name: str, encryptor_name: str):
362362
super().__init__(
363363
status_code=status.HTTP_400_BAD_REQUEST,
364-
message="Encryptor is not registered for client.",
364+
message="Encryptor is not connected to client.",
365365
details={
366366
"client_name": client_name,
367367
"encryptor_name": encryptor_name,
@@ -390,13 +390,13 @@ class WrongMasterSAEIDError(DSKEException):
390390
the master SAE ID used in the original Get key request.
391391
"""
392392

393-
def __init__(self, client_name: str, master_sae_id: str, key_id: str):
393+
def __init__(self, requested_master_sae_id: str, share_master_sae_id, key_id: str):
394394
super().__init__(
395395
status_code=status.HTTP_400_BAD_REQUEST,
396396
message="Master SAE ID does not match the one used in the original Get key request.",
397397
details={
398-
"client_name": client_name,
399-
"master_sae_id": master_sae_id,
398+
"requested_master_sae_id": requested_master_sae_id,
399+
"share_master_sae_id": share_master_sae_id,
400400
"key_id": key_id,
401401
},
402402
)
@@ -408,13 +408,13 @@ class WrongSlaveSAEIDError(DSKEException):
408408
the slave SAE ID used in the original Get key request.
409409
"""
410410

411-
def __init__(self, client_name: str, slave_sae_id: str, key_id: str):
411+
def __init__(self, requested_master_sae_id: str, share_master_sae_id, key_id: str):
412412
super().__init__(
413413
status_code=status.HTTP_400_BAD_REQUEST,
414414
message="Slave SAE ID does not match the one used in the original Get key request.",
415415
details={
416-
"client_name": client_name,
417-
"slave_sae_id": slave_sae_id,
416+
"requested_master_sae_id": requested_master_sae_id,
417+
"share_master_sae_id": share_master_sae_id,
418418
"key_id": key_id,
419419
},
420420
)

common/fragment.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from uuid import UUID
66
import pydantic
77
from common.exceptions import InvalidBlockUUIDError, InvalidEncodedFragmentError
8+
from common.logging import LOGGER
89
from . import utils
910

1011

@@ -106,6 +107,9 @@ def from_api(
106107
try:
107108
block_uuid = UUID(api_fragment.block_uuid)
108109
except ValueError as exc:
110+
LOGGER.error(
111+
"Invalid block UUID in API fragment: %s", api_fragment.block_uuid
112+
)
109113
raise InvalidBlockUUIDError(block_uuid=api_fragment.block_uuid) from exc
110114
block = pool.get_block(block_uuid)
111115
data = block.take_data(api_fragment.start, api_fragment.size)
@@ -135,20 +139,35 @@ def from_enc_str(
135139
"""
136140
parts = enc_str.split(":")
137141
if len(parts) != 3:
142+
LOGGER.error(
143+
f"Invalid encoded fragment {enc_str} "
144+
f"(expected three parts separated by :)"
145+
)
138146
raise InvalidEncodedFragmentError(encoded_fragment=enc_str)
139147
block_uuid_str, start_byte_str, size_str = parts
140148
try:
141149
block_uuid = UUID(block_uuid_str)
142150
except ValueError as exc:
151+
LOGGER.error(
152+
f"Invalid encoded fragment {enc_str} "
153+
f"(invalid block UUID {block_uuid_str})"
154+
)
143155
raise InvalidBlockUUIDError(block_uuid=block_uuid_str) from exc
144156
block = pool.get_block(block_uuid)
145157
try:
146158
start = int(start_byte_str)
147159
except ValueError as exc:
160+
LOGGER.error(
161+
f"Invalid encoded fragment {enc_str} "
162+
f"(invalid start byte {start_byte_str})"
163+
)
148164
raise InvalidEncodedFragmentError(encoded_fragment=enc_str) from exc
149165
try:
150166
size = int(size_str)
151167
except ValueError as exc:
168+
LOGGER.error(
169+
f"Invalid encoded fragment {enc_str} " f"(invalid size {size_str})"
170+
)
152171
raise InvalidEncodedFragmentError(encoded_fragment=enc_str) from exc
153172
data = block.take_data(start, size)
154173
return Fragment(block=block, start=start, size=size, data=data)

common/owner.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
2+
Enumeration for pool ownership.
3+
"""
4+
5+
import enum
6+
from .exceptions import InvalidOwnerError
7+
8+
9+
class Owner(enum.Enum):
10+
"""
11+
Who owns a pool? The client node or the hub node? Only the owner is allowed to make
12+
allocations out of the pool. The non-owner only takes data out of the pool, but the peer
13+
decides which data is taken (i.e. the peer does the allocation).
14+
"""
15+
16+
LOCAL = 1
17+
PEER = 2
18+
19+
def __str__(self):
20+
return self.to_str()
21+
22+
def to_str(self, local_name: str = "local", peer_name: str = "peer") -> str:
23+
"""
24+
Convert the Owner to a string.
25+
"""
26+
match self:
27+
case Owner.LOCAL:
28+
return local_name
29+
case Owner.PEER:
30+
return peer_name
31+
32+
@staticmethod
33+
def from_str(
34+
owner_str: str, local_name: str = "local", peer_name: str = "peer"
35+
) -> "Owner":
36+
"""
37+
Create an Owner from a string.
38+
"""
39+
lower_owner_str = owner_str.lower()
40+
if lower_owner_str == local_name:
41+
return Owner.LOCAL
42+
if lower_owner_str == peer_name:
43+
return Owner.PEER
44+
raise InvalidOwnerError(owner_str)

common/pool.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,20 @@
22
A pool of blocks.
33
"""
44

5-
import enum
65
from uuid import UUID
76
from pydantic import PositiveInt
87
from .allocation import Allocation
98
from .block import Block
10-
from .logging import LOGGER
119
from .exceptions import OutOfPreSharedRandomDataError, InvalidBlockUUIDError
10+
from .logging import LOGGER
11+
from .owner import Owner
1212

1313

1414
class Pool:
1515
"""
1616
A pool of blocks.
1717
"""
1818

19-
class Owner(enum.Enum):
20-
"""
21-
Who owns the pool? The client node or the hub node? Only the owner is allowed to make
22-
allocations out of the pool. The non-owner only takes data out of the pool, but the peer
23-
decides which data is taken (i.e. the peer does the allocation).
24-
"""
25-
26-
LOCAL = 1
27-
PEER = 2
28-
29-
def __str__(self):
30-
return self.name.lower()
31-
3219
_name: str
3320
_blocks: list[Block]
3421
_owner: Owner
@@ -81,6 +68,7 @@ def get_block(self, block_uuid: UUID) -> Block:
8168
for block in self._blocks:
8269
if block.uuid == block_uuid:
8370
return block
71+
LOGGER.error(f"Block UUID not found in {str(self.owner)} pool: {block_uuid}")
8472
raise InvalidBlockUUIDError(block_uuid=str(block_uuid))
8573

8674
def allocate(self, size: PositiveInt, purpose: str) -> Allocation:

common/share.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""
44

55
from uuid import UUID
6+
from .exceptions import WrongMasterSAEIDError, WrongSlaveSAEIDError
7+
from .logging import LOGGER
68
from .utils import bytes_to_str
79

810

@@ -98,3 +100,29 @@ def to_mgmt(self):
98100
"share_index": self._share_index,
99101
"value": bytes_to_str(self._value, truncate=True),
100102
}
103+
104+
def check_master_sae(self, master_sae_id: str):
105+
"""
106+
Check if the given master SAE ID matches the stored master SAE ID.
107+
Raise an exception if not.
108+
"""
109+
if self.master_sae_id != master_sae_id:
110+
key_id_str = str(self.user_key_id)
111+
LOGGER.warning(
112+
f"Requested master SAE ID {master_sae_id} does not match master SAE ID "
113+
f"{self.master_sae_id} for key ID {key_id_str}"
114+
)
115+
raise WrongMasterSAEIDError(master_sae_id, self.master_sae_id, key_id_str)
116+
117+
def check_slave_sae(self, slave_sae_id: str):
118+
"""
119+
Check if the given slave SAE ID matches the stored slave SAE ID.
120+
Raise an exception if not.
121+
"""
122+
if self.slave_sae_id != slave_sae_id:
123+
key_id_str = str(self.user_key_id)
124+
LOGGER.warning(
125+
f"Requested slave SAE ID {slave_sae_id} does not match slave SAE ID "
126+
f"{self.slave_sae_id} for key ID {key_id_str}"
127+
)
128+
raise WrongSlaveSAEIDError(slave_sae_id, self.slave_sae_id, key_id_str)

common/tests/test_pool.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
from common.pool import Pool
88
from common.exceptions import InvalidBlockUUIDError, OutOfPreSharedRandomDataError
9+
from common.owner import Owner
910
from common.utils import bytes_to_str
1011
from .unit_test_common import create_test_pool_and_blocks
1112

@@ -14,15 +15,15 @@ def test_init():
1415
"""
1516
Initialize a pool.
1617
"""
17-
_pool = Pool(name="test_pool", owner=Pool.Owner.LOCAL)
18+
_pool = Pool(name="test_pool", owner=Owner.LOCAL)
1819

1920

2021
def test_properties():
2122
"""
2223
Properties of the pool.
2324
"""
24-
pool = Pool(name="test_pool", owner=Pool.Owner.LOCAL)
25-
assert pool.owner == Pool.Owner.LOCAL
25+
pool = Pool(name="test_pool", owner=Owner.LOCAL)
26+
assert pool.owner == Owner.LOCAL
2627

2728

2829
def test_nr_used_and_unused_bytes():

common/tests/unit_test_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from uuid import uuid4
66
from typing import List
77
from common.block import Block
8+
from common.owner import Owner
89
from common.pool import Pool
910

1011

@@ -29,7 +30,7 @@ def create_test_pool_and_blocks(block_sizes: List[int]):
2930
"""
3031
Create a test pool with blocks of the given sizes.
3132
"""
32-
pool = Pool(name="test_pool", owner=Pool.Owner.LOCAL)
33+
pool = Pool(name="test_pool", owner=Owner.LOCAL)
3334
blocks = []
3435
for block_size in block_sizes:
3536
block = create_test_block(block_size)

0 commit comments

Comments
 (0)