Skip to content

Commit 6b26eed

Browse files
committed
added number of shots for fullqst
1 parent a5bebe2 commit 6b26eed

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

tests/test_tomography.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,18 @@ def setUp(self):
3838

3939
def test_full_qst(self):
4040
backend = Aer.get_backend("statevector_simulator")
41-
full_qst = FullQST(self.ansatz, backend)
42-
sol = full_qst.get_relative_amplitude_sign(self.parameters)
41+
full_qst = FullQST(self.ansatz, backend, shots=10000)
42+
_ = full_qst.get_relative_amplitude_sign(self.parameters)
4343
# assert np.allclose(self.ref, sol) or np.allclose(self.ref, -sol)
4444

4545
def test_htree_qst(self):
4646
sampler = Sampler()
4747
htree_qst = HTreeQST(self.ansatz, sampler)
4848
sol = htree_qst.get_relative_amplitude_sign(self.parameters)
49-
# assert np.allclose(self.ref, sol) or np.allclose(self.ref, -sol)
49+
assert np.allclose(self.ref, sol) or np.allclose(self.ref, -sol)
5050

5151
def test_shadow_qst(self):
5252
sampler = Sampler()
5353
shadow_qst = ShadowQST(self.ansatz, sampler, 10000)
5454
sol = shadow_qst.get_relative_amplitude_sign(self.parameters)
55-
# assert np.allclose(self.ref, sol) or np.allclose(self.ref, -sol)
55+
assert np.allclose(self.ref, sol) or np.allclose(self.ref, -sol)

vqls_prototype/tomography/qst.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44

55
class FullQST:
6-
def __init__(self, circuit, backend):
6+
def __init__(self, circuit, backend, shots=1000):
77
self.backend = backend
88
self.circuit = circuit
9+
self.shots = shots
910

1011
def get_relative_amplitude_sign(self, parameters):
1112
"""_summary_
@@ -50,7 +51,7 @@ def get_density_matrix(self, parameters):
5051
_type_: _description_
5152
"""
5253
qstexp1 = StateTomography(self.circuit.bind_parameters(parameters))
53-
qstdata1 = qstexp1.run(self.backend).block_for_results()
54+
qstdata1 = qstexp1.run(self.backend, shots=self.shots).block_for_results()
5455
return qstdata1.analysis_results("state").value.data.real
5556

5657
def get_statevector(self, parameters):

0 commit comments

Comments
 (0)