Skip to content

Commit e49becf

Browse files
committed
Refactor integration tests to use a whitenoise source (from ezmsg-simbiophys) instead of counter, which has changed substantially.
1 parent 85e8dbb commit e49becf

7 files changed

Lines changed: 54 additions & 58 deletions

File tree

tests/integration/conftest.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import ezmsg.core as ez
2+
from ezmsg.baseproc import Clock, ClockSettings
3+
from ezmsg.simbiophys.noise import WhiteNoise, WhiteNoiseSettings
4+
from ezmsg.util.messages.axisarray import AxisArray
5+
6+
7+
class NoiseSrcSettings(ez.Settings):
8+
fs: float = 10.0
9+
n_time: int = 4
10+
n_ch: int = 1
11+
dispatch_rate: float | None = None
12+
13+
14+
class NoiseSrc(ez.Collection):
15+
"""Self-contained multi-channel signal source for integration tests."""
16+
17+
SETTINGS = NoiseSrcSettings
18+
19+
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
20+
21+
CLOCK = Clock()
22+
NOISE = WhiteNoise()
23+
24+
def configure(self) -> None:
25+
dispatch_rate = self.SETTINGS.dispatch_rate or (self.SETTINGS.fs / self.SETTINGS.n_time)
26+
self.CLOCK.apply_settings(ClockSettings(dispatch_rate=dispatch_rate))
27+
self.NOISE.apply_settings(
28+
WhiteNoiseSettings(
29+
fs=self.SETTINGS.fs,
30+
n_time=self.SETTINGS.n_time,
31+
n_ch=self.SETTINGS.n_ch,
32+
)
33+
)
34+
35+
def network(self) -> ez.NetworkDefinition:
36+
return (
37+
(self.CLOCK.OUTPUT_SIGNAL, self.NOISE.INPUT_CLOCK),
38+
(self.NOISE.OUTPUT_SIGNAL, self.OUTPUT_SIGNAL),
39+
)

tests/integration/test_mlp_system.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from pathlib import Path
44

55
import ezmsg.core as ez
6-
from ezmsg.simbiophys.counter import Counter, CounterSettings
76
from ezmsg.util.messagecodec import message_log
87
from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
98
from ezmsg.util.messages.axisarray import AxisArray
109
from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings
1110

1211
from ezmsg.learn.process.torch import TorchModelUnit
12+
from tests.integration.conftest import NoiseSrc, NoiseSrcSettings
1313

1414

1515
def test_torch_model_unit_system():
@@ -26,15 +26,7 @@ def test_torch_model_unit_system():
2626
ez.logger.info(f"Logging to {test_filename}")
2727

2828
comps = {
29-
"SRC": Counter(
30-
CounterSettings(
31-
fs=fs,
32-
n_ch=input_size,
33-
n_time=block_size,
34-
dispatch_rate=duration,
35-
mod=None,
36-
)
37-
),
29+
"SRC": NoiseSrc(NoiseSrcSettings(fs=fs, n_ch=input_size, n_time=block_size, dispatch_rate=duration)),
3830
"MODEL": TorchModelUnit(
3931
model_class="ezmsg.learn.model.mlp.MLP",
4032
model_kwargs={

tests/integration/test_refit_kalman_system.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import ezmsg.core as ez
77
import numpy as np
8-
from ezmsg.simbiophys.counter import Counter, CounterSettings
98
from ezmsg.util.messagecodec import message_log
109
from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
1110
from ezmsg.util.messages.axisarray import AxisArray
@@ -20,10 +19,11 @@
2019
RefitKalmanFilterSettings,
2120
RefitKalmanFilterUnit,
2221
)
22+
from tests.integration.conftest import NoiseSrc, NoiseSrcSettings
2323

2424

2525
class RefitKalmanSystemSettings(ez.Settings):
26-
counter_settings: CounterSettings
26+
source_settings: NoiseSrcSettings
2727
unit_settings: RefitKalmanFilterSettings
2828
log_settings: MessageLoggerSettings
2929
terminate_total: TerminateOnTotalSettings
@@ -33,14 +33,14 @@ class RefitKalmanSystemSettings(ez.Settings):
3333
class RefitKalmanSystem(ez.Collection):
3434
SETTINGS = RefitKalmanSystemSettings
3535

36-
SOURCE = Counter()
36+
SOURCE = NoiseSrc()
3737
UNIT = RefitKalmanFilterUnit()
3838
LOG = MessageLogger()
3939
TERM_TOTAL = TerminateOnTotal()
4040
TERM_TIMEOUT = TerminateOnTimeout()
4141

4242
def configure(self) -> None:
43-
self.SOURCE.apply_settings(self.SETTINGS.counter_settings)
43+
self.SOURCE.apply_settings(self.SETTINGS.source_settings)
4444
self.UNIT.apply_settings(self.SETTINGS.unit_settings)
4545
self.LOG.apply_settings(self.SETTINGS.log_settings)
4646
self.TERM_TOTAL.apply_settings(self.SETTINGS.terminate_total)
@@ -98,12 +98,7 @@ def test_refit_kalman_system():
9898
checkpoint_file = f.name
9999

100100
settings = RefitKalmanSystemSettings(
101-
counter_settings=CounterSettings(
102-
fs=fs,
103-
n_ch=1,
104-
n_time=block_size,
105-
dispatch_rate=duration,
106-
),
101+
source_settings=NoiseSrcSettings(fs=fs, n_ch=1, n_time=block_size, dispatch_rate=duration),
107102
unit_settings=RefitKalmanFilterSettings(
108103
checkpoint_path=checkpoint_file,
109104
steady_state=True,

tests/integration/test_rnn_system.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from pathlib import Path
44

55
import ezmsg.core as ez
6-
from ezmsg.simbiophys.counter import Counter, CounterSettings
76
from ezmsg.util.messagecodec import message_log
87
from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
98
from ezmsg.util.messages.axisarray import AxisArray
109
from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings
1110

1211
from ezmsg.learn.process.rnn import RNNUnit
12+
from tests.integration.conftest import NoiseSrc, NoiseSrcSettings
1313

1414

1515
def test_torch_model_unit_system():
@@ -29,15 +29,7 @@ def test_torch_model_unit_system():
2929
ez.logger.info(f"Logging to {test_filename}")
3030

3131
comps = {
32-
"SRC": Counter(
33-
CounterSettings(
34-
fs=fs,
35-
n_ch=input_size,
36-
n_time=block_size,
37-
dispatch_rate=duration,
38-
mod=None,
39-
)
40-
),
32+
"SRC": NoiseSrc(NoiseSrcSettings(fs=fs, n_ch=input_size, n_time=block_size, dispatch_rate=duration)),
4133
"MODEL": RNNUnit(
4234
single_precision=single_precision,
4335
learning_rate=1e-2,

tests/integration/test_sklearn_system.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
import ezmsg.core as ez
77
import numpy as np
88
import pandas as pd
9-
from ezmsg.simbiophys.counter import Counter
109
from ezmsg.util.messagecodec import message_log
1110
from ezmsg.util.messagelogger import MessageLogger
1211
from ezmsg.util.messages.axisarray import AxisArray
1312
from ezmsg.util.terminate import TerminateOnTotal
1413
from river.linear_model import LinearRegression
1514

1615
from ezmsg.learn.process.sklearn import SklearnModelUnit
16+
from tests.integration.conftest import NoiseSrc, NoiseSrcSettings
1717

1818

1919
def test_sklearn_model_unit_system():
@@ -43,13 +43,7 @@ def test_sklearn_model_unit_system():
4343
ez.logger.info(f"Logging to {test_filename}")
4444

4545
comps = {
46-
"SRC": Counter(
47-
fs=fs,
48-
n_ch=input_size,
49-
n_time=block_size,
50-
dispatch_rate=duration,
51-
mod=None,
52-
),
46+
"SRC": NoiseSrc(NoiseSrcSettings(fs=fs, n_ch=input_size, n_time=block_size, dispatch_rate=duration)),
5347
"MODEL": SklearnModelUnit(
5448
model_class="river.linear_model.LinearRegression",
5549
model_kwargs={},

tests/integration/test_torch_system.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from pathlib import Path
44

55
import ezmsg.core as ez
6-
from ezmsg.simbiophys.counter import Counter, CounterSettings
76
from ezmsg.util.messagecodec import message_log
87
from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
98
from ezmsg.util.messages.axisarray import AxisArray
109
from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings
1110

1211
from ezmsg.learn.process.torch import TorchModelUnit
12+
from tests.integration.conftest import NoiseSrc, NoiseSrcSettings
1313

1414

1515
def test_torch_model_unit_system():
@@ -26,15 +26,7 @@ def test_torch_model_unit_system():
2626
ez.logger.info(f"Logging to {test_filename}")
2727

2828
comps = {
29-
"SRC": Counter(
30-
CounterSettings(
31-
fs=fs,
32-
n_ch=input_size,
33-
n_time=block_size,
34-
dispatch_rate=duration,
35-
mod=None,
36-
)
37-
),
29+
"SRC": NoiseSrc(NoiseSrcSettings(fs=fs, n_ch=input_size, n_time=block_size, dispatch_rate=duration)),
3830
"MODEL": TorchModelUnit(
3931
model_class="tests.unit.test_torch.DummyModel",
4032
model_kwargs={

tests/integration/test_transformer_system.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from pathlib import Path
44

55
import ezmsg.core as ez
6-
from ezmsg.simbiophys.counter import Counter, CounterSettings
76
from ezmsg.util.messagecodec import message_log
87
from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings
98
from ezmsg.util.messages.axisarray import AxisArray
109
from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings
1110

1211
from ezmsg.learn.process.transformer import TransformerUnit
12+
from tests.integration.conftest import NoiseSrc, NoiseSrcSettings
1313

1414

1515
def test_torch_model_unit_system():
@@ -30,15 +30,7 @@ def test_torch_model_unit_system():
3030
ez.logger.info(f"Logging to {test_filename}")
3131

3232
comps = {
33-
"SRC": Counter(
34-
CounterSettings(
35-
fs=fs,
36-
n_ch=input_size,
37-
n_time=block_size,
38-
dispatch_rate=duration,
39-
mod=None,
40-
)
41-
),
33+
"SRC": NoiseSrc(NoiseSrcSettings(fs=fs, n_ch=input_size, n_time=block_size, dispatch_rate=duration)),
4234
"MODEL": TransformerUnit(
4335
single_precision=single_precision,
4436
learning_rate=1e-2,

0 commit comments

Comments
 (0)