@@ -28,13 +28,16 @@ def give_sim_work_first(W, H, sim_specs, gen_specs, alloc_specs, persis_info, li
2828 if libE_info ["sim_max_given" ] or not libE_info ["any_idle_workers" ]:
2929 return {}, persis_info
3030
31+ user = {** gen_specs , ** alloc_specs .get ("user" , {})}
3132 manage_resources = libE_info ["use_resource_sets" ]
3233 support = AllocSupport (W , manage_resources , persis_info , libE_info )
3334 Work = {}
3435 gen_count = support .count_gens ()
3536
3637 if gen_specs ["user" ].get ("single_component_at_a_time" ):
37- assert alloc_specs ["user" ]["batch_mode" ], "Must be in batch mode when using 'single_component_at_a_time'"
38+ assert alloc_specs ["user" ].get ("batch_mode" , False ) or gen_specs .get (
39+ "batch_mode" , False
40+ ), "Must be in batch mode when using 'single_component_at_a_time'"
3841 if len (H ) != persis_info ["H_len" ]:
3942 # Something new is in the history.
4043 persis_info ["need_to_give" ].update (H ["sim_id" ][persis_info ["H_len" ] :].tolist ())
@@ -119,13 +122,13 @@ def give_sim_work_first(W, H, sim_specs, gen_specs, alloc_specs, persis_info, li
119122 break
120123
121124 while len (idle_gen_workers ):
122- if gen_count < alloc_specs [ " user" ] .get ("num_active_gens" , gen_count + 1 ):
125+ if gen_count < user .get ("num_active_gens" , gen_count + 1 ):
123126 lw = persis_info ["last_worker" ]
124127
125128 last_size = persis_info .get ("last_size" )
126129 if len (H ):
127130 # Don't give gen instances in batch mode if points are unfinished
128- if alloc_specs ["user" ].get ("batch_mode" ) and not all (
131+ if ( alloc_specs ["user" ].get ("batch_mode" ) or gen_specs . get ( "batch_mode" ) ) and not all (
129132 np .logical_or (H ["sim_ended" ][last_size :], H ["paused" ][last_size :])
130133 ):
131134 break
@@ -142,7 +145,7 @@ def give_sim_work_first(W, H, sim_specs, gen_specs, alloc_specs, persis_info, li
142145 persis_info ["last_worker" ] = i
143146 persis_info ["last_size" ] = len (H )
144147
145- elif gen_count >= alloc_specs [ " user" ] .get ("num_active_gens" , gen_count + 1 ):
148+ elif gen_count >= user .get ("num_active_gens" , gen_count + 1 ):
146149 idle_gen_workers = []
147150
148151 return Work , persis_info
0 commit comments