Skip to content

Commit 187b861

Browse files
authored
examples: add a mdi-qkd simulation (#158)
This follows the presentation in https://arxiv.org/pdf/1109.1473.pdf with a slight generalization: Alice and Bob repeatedly run the protocol on a batch of qubits until they've gathered enough key material. Fixes: #131
1 parent 4c76943 commit 187b861

1 file changed

Lines changed: 223 additions & 0 deletions

File tree

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
import random
2+
from enum import StrEnum, auto
3+
from dataclasses import dataclass
4+
from qunetsim.components.host import Host
5+
from qunetsim.components.network import Network
6+
from qunetsim.objects import Qubit
7+
from qunetsim.objects import Logger
8+
9+
Logger.DISABLED = True
10+
default_key_len = 32
11+
default_batch_len = 8
12+
13+
14+
class Basis(StrEnum):
15+
rect = auto()
16+
diag = auto()
17+
18+
19+
@dataclass()
20+
class BB84State():
21+
basis: Basis
22+
value: int
23+
24+
def into_qubit(self, host: Host) -> Qubit:
25+
q = Qubit(host)
26+
27+
if self.value == 1:
28+
q.X()
29+
if self.basis == Basis.diag:
30+
q.H()
31+
32+
return q
33+
34+
35+
class _BB84StateSpace():
36+
states = [BB84State(basis=b, value=v)
37+
for b in (Basis.diag, Basis.rect)
38+
for v in (0, 1)]
39+
40+
def __len__(self):
41+
return len(self.states)
42+
43+
def __getitem__(self, i):
44+
return self.states[i]
45+
46+
47+
# Treat this as a singleton.
48+
BB84StateSpace = _BB84StateSpace()
49+
50+
51+
# Helper class for Bell states in the computational basis
52+
class RelayMessage(StrEnum):
53+
phi_minus = auto()
54+
phi_plus = auto()
55+
psi_plus = auto()
56+
psi_minus = auto()
57+
58+
@staticmethod
59+
def from_measurement(x):
60+
m = {
61+
(0, 0): RelayMessage.phi_plus,
62+
(1, 0): RelayMessage.phi_minus,
63+
(0, 1): RelayMessage.psi_plus,
64+
(1, 1): RelayMessage.psi_minus,
65+
}
66+
return m[x]
67+
68+
69+
@dataclass()
70+
class MDIRelay():
71+
node: Host
72+
alice: Host
73+
bob: Host
74+
75+
batch_len: int
76+
77+
def measure(a: Qubit, b: Qubit):
78+
"""
79+
Perform a Bell measurement on the given qubit pair and return the
80+
result, encoded as a RelayMessage.
81+
"""
82+
83+
a.cnot(b)
84+
a.H()
85+
m = (a.measure(), b.measure())
86+
return RelayMessage.from_measurement(m)
87+
88+
def protocol(self, charlie):
89+
# collect qubits, measure, broadcast results
90+
while True:
91+
a, b, m = [], [], []
92+
for k in range(self.batch_len):
93+
a.append(charlie.get_qubit(self.alice.host_id, wait=-1))
94+
b.append(charlie.get_qubit(self.bob.host_id, wait=-1))
95+
m.append(MDIRelay.measure(a[k], b[k]))
96+
97+
charlie.send_broadcast(" ".join(m))
98+
99+
def run(self):
100+
return self.node.run_protocol(self.protocol)
101+
102+
103+
@dataclass()
104+
class MDINode():
105+
node: Host
106+
peer: Host
107+
relay: Host
108+
flip: bool
109+
110+
key_len: int
111+
batch_len: int
112+
113+
def protocol(self, node):
114+
key = []
115+
116+
while True:
117+
print(f"key_{node.host_id:5} = {key[:self.key_len]}")
118+
if len(key) >= self.key_len:
119+
return
120+
121+
# Quantum communication phase. Gather a batch of qubits and send
122+
# them to the relay.
123+
states, qbits = [], []
124+
for k in range(self.batch_len):
125+
states.append(random.choice(BB84StateSpace))
126+
qbits.append(states[k].into_qubit(node))
127+
128+
for q in qbits:
129+
node.send_qubit(self.relay.host_id, q, await_ack=True)
130+
131+
# Classical communication phase.
132+
msg = node.get_next_classical(self.relay.host_id)
133+
measurements = msg.content.split()
134+
135+
# Post selection
136+
# Find states that led to successful measurements.
137+
good_states = []
138+
good_measurements = []
139+
for k, (s, m) in enumerate(zip(states, measurements)):
140+
if m in {RelayMessage.psi_plus, RelayMessage.psi_minus}:
141+
good_states.append(s)
142+
good_measurements.append(m)
143+
144+
# Send peer our basis sequence for successful measurements.
145+
node.send_classical(self.peer.host_id,
146+
" ".join([s.basis for s in good_states]))
147+
148+
# Post² selection: select the bits where node and peer used the
149+
# same basis.
150+
msg = node.get_next_classical(self.peer.host_id)
151+
peer_bases = msg.content.split()
152+
153+
for k in range(len(good_measurements)):
154+
if good_states[k].basis == peer_bases[k]:
155+
key.append(good_states[k].value)
156+
157+
if self.flip and not (
158+
good_measurements[k] == RelayMessage.psi_plus
159+
and good_states[k].basis == Basis.diag
160+
):
161+
key[-1] ^= 1
162+
163+
def run(self):
164+
return self.node.run_protocol(self.protocol)
165+
166+
167+
@dataclass(init=False)
168+
class MDINetwork():
169+
alice: MDINode
170+
bob: MDINode
171+
charlie: MDIRelay
172+
network: Network
173+
174+
def __init__(self, key_len=default_key_len,
175+
batch_len=default_batch_len):
176+
self.network = Network.get_instance()
177+
nodes = ['Alice', 'Bob', 'Charlie']
178+
self.network.start(nodes)
179+
180+
alice = Host('Alice')
181+
bob = Host('Bob')
182+
charlie = Host('Charlie')
183+
184+
alice.add_connections(['Bob', 'Charlie'])
185+
bob.add_connections(['Alice', 'Charlie'])
186+
charlie.add_connections(['Alice', 'Bob'])
187+
188+
self.network.delay = 0.1
189+
bob.start()
190+
alice.start()
191+
self.network.delay = 0.2
192+
charlie.start()
193+
194+
self.network.add_host(alice)
195+
self.network.add_host(bob)
196+
self.network.add_host(charlie)
197+
198+
self.alice = MDINode(node=alice, peer=bob, relay=charlie,
199+
flip=False, key_len=key_len, batch_len=batch_len)
200+
self.bob = MDINode(node=bob, peer=alice, relay=charlie,
201+
flip=True, key_len=key_len, batch_len=batch_len)
202+
self.charlie = MDIRelay(node=charlie,
203+
alice=alice, bob=bob, batch_len=batch_len)
204+
205+
def simulate(self):
206+
t1 = self.alice.run()
207+
t2 = self.bob.run()
208+
_ = self.charlie.run()
209+
210+
t1.join()
211+
t2.join()
212+
self.network.stop(True)
213+
214+
215+
def main():
216+
random.seed(2**2023)
217+
MDINetwork().simulate()
218+
219+
exit()
220+
221+
222+
if __name__ == '__main__':
223+
main()

0 commit comments

Comments
 (0)