Skip to content

Commit 2e15db8

Browse files
committed
Refactor state transition logic
1 parent 41ad757 commit 2e15db8

1 file changed

Lines changed: 64 additions & 97 deletions

File tree

pybpodapi/bpod/bpod_base.py

Lines changed: 64 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,46 @@ def __initialize_input_command_handler(self):
497497
self.stdin = NonBlockingStreamReader(
498498
sys.stdin) if settings.PYBPOD_API_ACCEPT_STDIN else None
499499

500+
def __transition_to_new_state(self, sma, event_id, transition_matrix,
501+
current_trial, state_change_indexes,
502+
is_state_timer_matrix=False,
503+
debug_message=None):
504+
new_state_set = False
505+
506+
def set_sma_current_state(new_state):
507+
previous_state = sma.current_state
508+
if sma.use_255_back_signal and new_state == 255:
509+
sma.current_state = current_trial.states[-2]
510+
else:
511+
sma.current_state = new_state
512+
logger.debug(('Transition occured: '
513+
f'state {previous_state} -> {sma.current_state}'))
514+
if not math.isnan(sma.current_state):
515+
if debug_message is not None:
516+
logger.debug(debug_message)
517+
current_trial.states.append(sma.current_state)
518+
state_change_indexes.append(
519+
len(current_trial.events_occurrences) - 1)
520+
current_state = sma.current_state
521+
if is_state_timer_matrix:
522+
this_state_timer_state = transition_matrix[current_state]
523+
is_event_id_tup = event_id == sma.hardware.channels.events_positions.Tup
524+
if is_event_id_tup and this_state_timer_state != current_state:
525+
set_sma_current_state(this_state_timer_state)
526+
new_state_set = True
527+
else:
528+
for transition_event_code, transition_state in transition_matrix[
529+
current_state]:
530+
if transition_event_code == event_id:
531+
set_sma_current_state(transition_state)
532+
new_state_set = True
533+
else:
534+
logger.debug((f'Event {transition_event_code} required '
535+
f'for transition: state '
536+
f'{sma.current_state} -> '
537+
f'{transition_state}'))
538+
return new_state_set
539+
500540
def __process_opcode(self, sma, opcode, data, state_change_indexes):
501541
"""
502542
Process data from bpod board given an opcode
@@ -538,103 +578,30 @@ def __process_opcode(self, sma, opcode, data, state_change_indexes):
538578
)
539579
self.trial_timestamps.append(event_timestamp)
540580

541-
# input matrix
542-
if not transition_event_found:
543-
logger.debug("transition event not found")
544-
logger.debug("Current state: %s", sma.current_state)
545-
for transition in sma.input_matrix[sma.current_state]:
546-
logger.debug("Transition: %s", transition)
547-
if transition[0] == event_id:
548-
if sma.use_255_back_signal and transition[1] == 255:
549-
sma.current_state = current_trial.states[-2]
550-
else:
551-
sma.current_state = transition[1]
552-
553-
if not math.isnan(sma.current_state):
554-
logger.debug("adding states input matrix")
555-
current_trial.states.append(sma.current_state)
556-
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
557-
558-
transition_event_found = True
559-
560-
# state timer matrix
561-
if not transition_event_found:
562-
this_state_timer_transition = sma.state_timer_matrix[sma.current_state]
563-
if event_id == sma.hardware.channels.events_positions.Tup:
564-
if not (this_state_timer_transition == sma.current_state):
565-
if sma.use_255_back_signal and this_state_timer_transition == 255:
566-
sma.current_state = current_trial.states[-2]
567-
else:
568-
sma.current_state = this_state_timer_transition
569-
570-
if not math.isnan(sma.current_state):
571-
logger.debug("adding states state timer matrix")
572-
current_trial.states.append(sma.current_state)
573-
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
574-
transition_event_found = True
575-
576-
# global timers start matrix
577-
if not transition_event_found:
578-
for transition in sma.global_timers.start_matrix[sma.current_state]:
579-
if transition[0] == event_id:
580-
if sma.use_255_back_signal and transition[1] == 255:
581-
sma.current_state = current_trial.states[-2]
582-
else:
583-
sma.current_state = transition[1]
584-
585-
if not math.isnan(sma.current_state):
586-
logger.debug("adding states global timers start matrix")
587-
current_trial.states.append(sma.current_state)
588-
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
589-
transition_event_found = True
590-
591-
# global timers end matrix
592-
if not transition_event_found:
593-
for transition in sma.global_timers.end_matrix[sma.current_state]:
594-
if transition[0] == event_id:
595-
596-
if sma.use_255_back_signal and transition[1] == 255:
597-
sma.current_state = current_trial.states[-2]
598-
else:
599-
sma.current_state = transition[1]
600-
601-
if not math.isnan(sma.current_state):
602-
logger.debug("adding states global timers end matrix")
603-
current_trial.states.append(sma.current_state)
604-
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
605-
transition_event_found = True
606-
607-
# global counters matrix
608-
if not transition_event_found:
609-
for transition in sma.global_counters.matrix[sma.current_state]:
610-
if transition[0] == event_id:
611-
612-
if sma.use_255_back_signal and transition[1] == 255:
613-
sma.current_state = current_trial.states[-2]
614-
else:
615-
sma.current_state = transition[1]
616-
617-
if not math.isnan(sma.current_state):
618-
logger.debug("adding states global timers end matrix")
619-
current_trial.states.append(sma.current_state)
620-
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
621-
transition_event_found = True
622-
623-
# conditions matrix
624-
if not transition_event_found:
625-
for transition in sma.conditions.matrix[sma.current_state]:
626-
if transition[0] == event_id:
627-
628-
if sma.use_255_back_signal and transition[1] == 255:
629-
sma.current_state = current_trial.states[-2]
630-
else:
631-
sma.current_state = transition[1]
632-
633-
if not math.isnan(sma.current_state):
634-
logger.debug("adding states global timers end matrix")
635-
current_trial.states.append(sma.current_state)
636-
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
637-
transition_event_found = True
581+
logger.debug("Current state: %s", sma.current_state)
582+
transition_matrices = {
583+
'input': sma.input_matrix,
584+
'state_timer': sma.state_timer_matrix,
585+
'global_timers_start': sma.global_timers.start_matrix,
586+
'global_timers_end': sma.global_timers.end_matrix,
587+
'global_counters': sma.global_counters.matrix,
588+
'conditions': sma.conditions.matrix,
589+
}
590+
for transition_matrix_name, transition_matrix in \
591+
transition_matrices.items():
592+
transition_event_found = \
593+
self.__transition_to_new_state(
594+
sma,
595+
event_id,
596+
transition_matrix,
597+
current_trial,
598+
state_change_indexes,
599+
is_state_timer_matrix=(transition_matrix_name ==
600+
'state_timer'),
601+
debug_message="Adding {} matrix states".format(
602+
transition_matrix_name))
603+
if transition_event_found:
604+
break
638605

639606
logger.debug("States indexes: %s", current_trial.states)
640607
if self._emulator is not None:

0 commit comments

Comments
 (0)