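# Adaptive spin-qubit sensing with REINFORCE: a policy network chooses discrete
# measurement times to identify an unknown magnetic field value B, using a
# Bayesian belief state and an information-gain (entropy-reduction) reward.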
import numpy as np

class SpinQubitSensorEnv:
    def __init__(self, B_values, T2=1.0, max_steps=3):
        self.B_values = np.array(B_values, dtype=float)  # discrete candidate B values
        self.N = len(self.B_values)                      # number of discrete states for B
        self.T2 = T2                                     # dephasing time constant
        self.max_steps = max_steps                       # measurements per episode
        self.times = None                                # list of measurement times (actions)
        self.reset()                                     # initialize environment state

    def set_times(self, times):
        """Define the discrete measurement times corresponding to each action index."""
        self.times = np.array(times, dtype=float)

    def reset(self):
        """Start a new episode: sample a true B, reset belief to uniform."""
        self.true_idx = np.random.randint(0, self.N)   # index of the true B in the list
        self.true_B = self.B_values[self.true_idx]     # actual magnetic field for this episode
        self.belief = np.ones(self.N) / self.N         # uniform prior
        self.step_count = 0
        # The state is the current belief distribution plus the remaining step count
        return self._get_state()

    def _get_state(self):
        """Construct the state vector (belief + remaining steps)."""
        remaining_steps = self.max_steps - self.step_count
        # Append the remaining step count to the belief vector (it could also be normalized by max_steps)
        return np.concatenate([self.belief, [remaining_steps]])

    def step(self, action):
        """Simulate one measurement step with the chosen action (time index)."""
        assert self.times is not None, "Measurement times not set. Call set_times(...) first."
        tau = float(self.times[action])  # selected measurement duration
        # Quantum sensor evolution: compute outcome probabilities
        # P(+X outcome) = 0.5 * (1 + exp(-tau/T2) * cos(B * tau))
        exp_decay = np.exp(-tau / self.T2)
        cos_phase = np.cos(self.true_B * tau)
        p_plus = 0.5 * (1 + exp_decay * cos_phase)  # probability of outcome +1 (X basis)
        # Sample a measurement outcome according to this probability
        outcome = 1 if np.random.rand() < p_plus else 0  # 1 for +X outcome, 0 for -X outcome

        # Bayesian belief update:
        # likelihoods P(outcome | B_i) for each candidate B_i
        likelihoods = 0.5 * (1 + exp_decay * np.cos(self.B_values * tau))
        if outcome == 0:
            likelihoods = 1 - likelihoods  # the "-X" outcome has probability 1 - p_plus
        # Update the posterior belief via elementwise multiplication and normalization
        prior = self.belief
        unnorm_post = prior * likelihoods
        if unnorm_post.sum() == 0:
            # Numerical safety: if all probabilities are zero (unlikely), keep the prior
            post = prior
        else:
            post = unnorm_post / unnorm_post.sum()
        self.belief = post
        self.step_count += 1

        # Reward: information gain (reduction in belief entropy)
        def entropy(p_dist):
            mask = p_dist > 0
            return -np.sum(p_dist[mask] * np.log2(p_dist[mask]))
        prev_entropy = entropy(prior)
        new_entropy = entropy(post)
        reward = prev_entropy - new_entropy  # positive if uncertainty decreased

        # If this was the last measurement, add a final accuracy reward
        done = (self.step_count >= self.max_steps)
        if done:
            # Final estimate: the most likely B under the posterior
            est_idx = int(np.argmax(self.belief))
            if est_idx == self.true_idx:
                reward += 1.0  # bonus for correct identification
            else:
                reward += 0.0  # (or a small negative penalty for a wrong guess, e.g. reward -= 0.5)

        # Construct the next state (or None if done)
        next_state = self._get_state() if not done else None
        return next_state, reward, done, {"outcome": outcome, "true_B": self.true_B}

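# --- Optional sanity check (illustrative sketch, not required for training) ---
# Runs one episode with a fixed measurement time to show the Bayesian belief
# sharpening after each shot. The names demo_env / demo_info are placeholders
# introduced here for illustration only.
demo_env = SpinQubitSensorEnv([0.5, 1.0, 1.5, 2.0, 2.5], T2=1.0, max_steps=3)
demo_env.set_times([0.1, 0.5, 1.0, 2.0, 3.0])
demo_env.reset()
demo_done = False
while not demo_done:
    _, _, demo_done, demo_info = demo_env.step(2)  # always measure for tau = 1.0
    print("outcome:", demo_info["outcome"], "belief:", np.round(demo_env.belief, 3))
print("true B:", demo_info["true_B"],
      "estimated B:", demo_env.B_values[int(np.argmax(demo_env.belief))])
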
import torch
import torch.nn as nn

class PolicyNet(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, action_dim)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        logits = self.fc2(x)  # raw scores for each action
        return logits  # softmax is applied via PyTorch distributions when sampling

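# Quick shape check (illustrative only): with five candidate B values plus the
# remaining-step counter, the state has 6 entries and the network emits one
# logit per measurement-time action.
_probe = PolicyNet(state_dim=6, action_dim=5)
print(_probe(torch.zeros(6)).shape)  # expected: torch.Size([5])
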
import torch.optim as optim

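# Reproducibility (optional): fixing the NumPy and torch seeds here makes the
# sampled true fields, measurement outcomes, and network initialization repeatable.
np.random.seed(0)
torch.manual_seed(0)
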
# Initialize environment and policy
B_values = [0.5, 1.0, 1.5, 2.0, 2.5]  # possible magnetic field values (discrete)
env = SpinQubitSensorEnv(B_values, T2=1.0, max_steps=3)
env.set_times([0.1, 0.5, 1.0, 2.0, 3.0])  # define 5 possible measurement times (seconds, for example)
state_dim = env.N + 1                     # belief length + remaining-step counter
action_dim = len(env.times)               # number of discrete actions
policy = PolicyNet(state_dim, action_dim)
optimizer = optim.Adam(policy.parameters(), lr=0.01)

# Training parameters
num_episodes = 5000
gamma = 1.0  # discount factor (1 is fine for short episodes focused on the final outcome)

for episode in range(num_episodes):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32)
    log_probs = []
    rewards = []
    # Generate an episode
    done = False
    while not done:
        # Get action probabilities from the policy
        logits = policy(state)
        dist = torch.distributions.Categorical(logits=logits)  # categorical distribution over actions
        action = dist.sample()                                  # sample an action index
        log_prob = dist.log_prob(action)                        # log π(a|s)
        next_state, reward, done, info = env.step(int(action.item()))

        # Record the log-prob and reward
        log_probs.append(log_prob)
        rewards.append(reward)

        # Move to the next state
        if next_state is not None:
            state = torch.tensor(next_state, dtype=torch.float32)

    # Episode ended: compute returns and update the policy.
    # Discounted return for each step (with gamma = 1 this is the cumulative future reward from that step)
    returns = []
    R = 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns, dtype=torch.float32)
    # Optionally normalize returns for stability
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    # Policy gradient: maximize E[return * log_prob] -> minimize -(return * log_prob)
    loss = 0.0
    for log_prob, R in zip(log_probs, returns):
        loss += -log_prob * R
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # (Optional) logging
    if episode % 500 == 0:
        total_reward = sum(rewards)
        print(f"Episode {episode}: total reward = {total_reward:.3f}")

# After training, evaluate the policy
test_episodes = 1000
correct_count = 0
for _ in range(test_episodes):
    state = env.reset()
    state = torch.tensor(state, dtype=torch.float32)
    done = False
    while not done:
        logits = policy(state)
        action = torch.argmax(logits).item()  # choose the highest-probability action (greedy)
        next_state, reward, done, info = env.step(action)
        if next_state is not None:
            state = torch.tensor(next_state, dtype=torch.float32)
    # After the episode, check whether the final estimate was correct
    est_idx = int(np.argmax(env.belief))
    if est_idx == env.true_idx:
        correct_count += 1
accuracy = correct_count / test_episodes
print(f"Policy accuracy over {test_episodes} test episodes: {accuracy*100:.1f}%")

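# --- Optional baseline (sketch, reusing the same env and evaluation loop) ---
# A non-adaptive reference point: pick measurement times uniformly at random and
# measure how often the maximum-a-posteriori estimate still identifies the true field.
baseline_correct = 0
for _ in range(test_episodes):
    env.reset()
    done = False
    while not done:
        action = np.random.randint(action_dim)  # random measurement time
        _, _, done, _ = env.step(action)
    if int(np.argmax(env.belief)) == env.true_idx:
        baseline_correct += 1
baseline_accuracy = baseline_correct / test_episodes
print(f"Random-policy accuracy over {test_episodes} test episodes: {baseline_accuracy*100:.1f}%")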