Skip to content

Commit 22197b3

Browse files
committed
Merge pull request #19 from HugoFara/fix/blockunit-subinfo-guards
2 parents 3e0d2a1 + b96d036 commit 22197b3

4 files changed

Lines changed: 198 additions & 3 deletions

File tree

psyflow/BlockUnit.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,12 @@ def run_trial(self, func: Callable, **kwargs) -> "BlockUnit":
265265
**kwargs : dict
266266
Additional keyword arguments forwarded to ``func``.
267267
"""
268+
if self.conditions is None:
269+
raise RuntimeError(
270+
f"BlockUnit '{self.block_id}' has no conditions. "
271+
"Call generate_conditions() before run_trial()."
272+
)
273+
268274
self.meta['block_start_time'] = core.getAbsTime()
269275
self.logging_block_info()
270276

@@ -273,6 +279,16 @@ def run_trial(self, func: Callable, **kwargs) -> "BlockUnit":
273279

274280
for i, cond in enumerate(self.conditions):
275281
result = func(self.win, self.kb, self.settings, cond, **kwargs)
282+
if not isinstance(result, dict):
283+
func_name = getattr(func, "__name__", None)
284+
if func_name is None and hasattr(func, "func"):
285+
func_name = getattr(func.func, "__name__", None)
286+
if func_name is None:
287+
func_name = type(func).__name__
288+
raise TypeError(
289+
f"Trial function {func_name!r} must return a dict, "
290+
f"got {type(result).__name__!r}"
291+
)
276292
result.update({
277293
"trial_index": i,
278294
"block_id": self.block_id,
@@ -403,10 +419,14 @@ def logging_block_info(self) -> None:
403419
"""
404420
Log block metadata including ID, index, seed, trial count, and condition distribution.
405421
"""
406-
dist = {c: self.conditions.count(c) for c in set(self.conditions)} if self.conditions else {}
422+
if self.conditions is not None and len(self.conditions) > 0:
423+
conds = np.asarray(self.conditions, dtype=object)
424+
dist = {c: int(np.sum(conds == c)) for c in set(self.conditions)}
425+
else:
426+
dist = {}
407427
logging.data(f"[BlockUnit] Blockid: {self.block_id}")
408428
logging.data(f"[BlockUnit] Blockidx: {self.block_idx}")
409429
logging.data(f"[BlockUnit] Blockseed: {self.seed}")
410-
logging.data(f"[BlockUnit] Blocktrial-N: {len(self.conditions)}")
430+
logging.data(f"[BlockUnit] Blocktrial-N: {len(self.conditions) if self.conditions is not None else 0}")
411431
logging.data(f"[BlockUnit] Blockdist: {dist}")
412432
logging.data(f"[BlockUnit] Blockconditions: {self.conditions}")

psyflow/SubInfo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def validate(self, responses) -> bool:
167167
raise ValueError
168168
if digits is not None and len(str(val)) != digits:
169169
raise ValueError
170-
except:
170+
except Exception:
171171
infoDlg = gui.Dlg()
172172
infoDlg.addText(
173173
self._local("invalid_input").format(field=self._local(field['name']))

tests/test_BlockUnit.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Tests for psyflow.BlockUnit."""
2+
3+
from functools import partial
4+
import unittest
5+
from unittest.mock import MagicMock, patch
6+
from types import SimpleNamespace
7+
8+
try:
9+
import numpy # noqa: F401
10+
from psychopy import core, logging # noqa: F401
11+
_HAS_DEPS = True
12+
except ImportError:
13+
_HAS_DEPS = False
14+
15+
if _HAS_DEPS:
16+
from psyflow.BlockUnit import BlockUnit
17+
18+
19+
def _make_settings(**overrides):
20+
"""Create a minimal settings-like object."""
21+
defaults = {
22+
"trials_per_block": 3,
23+
"block_seed": [42],
24+
}
25+
defaults.update(overrides)
26+
return SimpleNamespace(**defaults)
27+
28+
29+
def _make_block(**overrides):
30+
"""Build a BlockUnit without calling __init__ (avoids PsychoPy window)."""
31+
block = BlockUnit.__new__(BlockUnit)
32+
defaults = dict(
33+
block_id="test",
34+
block_idx=0,
35+
n_trials=3,
36+
settings=_make_settings(),
37+
win=MagicMock(),
38+
kb=MagicMock(),
39+
seed=42,
40+
conditions=None,
41+
results=[],
42+
meta={},
43+
_on_start=[],
44+
_on_end=[],
45+
)
46+
defaults.update(overrides)
47+
for k, v in defaults.items():
48+
setattr(block, k, v)
49+
return block
50+
51+
52+
@unittest.skipUnless(_HAS_DEPS, "requires numpy and psychopy")
53+
class TestRunTrialGuards(unittest.TestCase):
54+
"""run_trial() should reject invalid state with clear errors."""
55+
56+
def test_conditions_none_raises_runtime_error(self):
57+
block = _make_block(conditions=None)
58+
59+
with self.assertRaises(RuntimeError) as ctx:
60+
block.run_trial(lambda win, kb, s, c: {"rt": 0.5})
61+
62+
self.assertIn("conditions", str(ctx.exception).lower())
63+
64+
def test_func_returning_none_raises_type_error(self):
65+
block = _make_block(conditions=["A"])
66+
67+
def bad_trial_func(win, kb, settings, cond):
68+
return None
69+
70+
with self.assertRaises(TypeError) as ctx:
71+
block.run_trial(bad_trial_func)
72+
73+
self.assertIn("dict", str(ctx.exception).lower())
74+
75+
def test_partial_trial_func_raises_type_error(self):
76+
block = _make_block(conditions=["A"])
77+
78+
def bad_trial_func(win, kb, settings, cond):
79+
return None
80+
81+
with self.assertRaises(TypeError) as ctx:
82+
block.run_trial(partial(bad_trial_func))
83+
84+
self.assertIn("bad_trial_func", str(ctx.exception))
85+
86+
87+
@unittest.skipUnless(_HAS_DEPS, "requires numpy and psychopy")
88+
class TestLoggingBlockInfo(unittest.TestCase):
89+
"""logging_block_info() should handle list-backed conditions."""
90+
91+
def test_counts_python_list_conditions(self):
92+
block = _make_block(conditions=["A", "B", "A"])
93+
94+
with patch("psyflow.BlockUnit.logging.data") as mock_log:
95+
block.logging_block_info()
96+
97+
messages = [call.args[0] for call in mock_log.call_args_list]
98+
self.assertTrue(
99+
any(
100+
"Blockdist:" in msg and "'A': 2" in msg and "'B': 1" in msg
101+
for msg in messages
102+
),
103+
msg=f"Unexpected log messages: {messages!r}",
104+
)
105+
106+
107+
if __name__ == "__main__":
108+
unittest.main()

tests/test_SubInfo.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Tests for psyflow.SubInfo."""
2+
3+
import unittest
4+
from unittest.mock import MagicMock, patch
5+
6+
try:
7+
from psychopy import gui # noqa: F401
8+
_HAS_PSYCHOPY = True
9+
except ImportError:
10+
_HAS_PSYCHOPY = False
11+
12+
if _HAS_PSYCHOPY:
13+
import psyflow.SubInfo as _subinfo_mod
14+
from psyflow.SubInfo import SubInfo
15+
_gui = _subinfo_mod.gui
16+
17+
18+
def _make_subinfo(**field_map_overrides):
19+
"""Build a SubInfo without calling __init__."""
20+
info = SubInfo.__new__(SubInfo)
21+
info.fields = [
22+
{"name": "subject_id", "type": "int",
23+
"constraints": {"min": 101, "max": 999, "digits": 3}}
24+
]
25+
info.field_map = {
26+
"Participant Information": "Info",
27+
"registration_failed": "Failed",
28+
"invalid_input": "Bad: {field}",
29+
}
30+
info.field_map.update(field_map_overrides)
31+
info.subject_data = None
32+
return info
33+
34+
35+
@unittest.skipUnless(_HAS_PSYCHOPY, "requires psychopy")
36+
class TestCollect(unittest.TestCase):
37+
"""SubInfo.collect() control-flow edge cases."""
38+
39+
def test_cancel_returns_none(self):
40+
info = _make_subinfo()
41+
42+
mock_dlg = MagicMock()
43+
mock_dlg.show.return_value = None
44+
with patch.object(_gui, "Dlg", return_value=mock_dlg):
45+
result = info.collect(exit_on_cancel=False)
46+
self.assertIsNone(result)
47+
48+
49+
@unittest.skipUnless(_HAS_PSYCHOPY, "requires psychopy")
50+
class TestValidate(unittest.TestCase):
51+
"""SubInfo.validate() error handling."""
52+
53+
def test_keyboard_interrupt_propagates(self):
54+
info = _make_subinfo()
55+
56+
class ExplodingStr:
57+
def __int__(self):
58+
raise KeyboardInterrupt("simulated Ctrl+C")
59+
def __str__(self):
60+
return "boom"
61+
62+
with self.assertRaises(KeyboardInterrupt):
63+
info.validate([ExplodingStr()])
64+
65+
66+
if __name__ == "__main__":
67+
unittest.main()

0 commit comments

Comments
 (0)