Skip to content

Commit fb559f4

Browse files
committed
replace deprecated SampleMessage with AxisArray containing trigger in attrs.
1 parent 361e03d commit fb559f4

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

tests/unit/test_mlp_old.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,15 @@ def xy_gen(set: int = 0):
171171
result = []
172172
train_loss = []
173173
for sample_msg in xy_gen(set=0):
174-
# Naive closed-loop inference
175-
result.append(proc(sample_msg.sample))
174+
# Naive closed-loop inference — strip trigger attrs before inference
175+
plain_msg = replace(sample_msg, attrs={})
176+
result.append(proc(plain_msg))
176177

177178
# Collect the loss to see if it decreases with training.
178179
train_loss.append(
179180
torch.nn.MSELoss()(
180181
torch.tensor(result[-1].data),
181-
torch.tensor(sample_msg.trigger.value.reshape(-1, 1), dtype=torch.float32),
182+
torch.tensor(sample_msg.attrs["trigger"].value.reshape(-1, 1), dtype=torch.float32),
182183
).item()
183184
)
184185

0 commit comments

Comments
 (0)