1515use core:: ops:: { Deref , DerefMut } ;
1616use core:: sync:: atomic:: { AtomicU64 , Ordering } ;
1717use core:: time:: Duration ;
18- use std:: collections:: HashSet ;
18+ use std:: collections:: { HashMap , HashSet } ;
1919use std:: sync:: Arc ;
2020use std:: time:: { Instant , UNIX_EPOCH } ;
2121
22- use async_lock:: Mutex ;
22+ use async_lock:: RwLock ;
2323use lru:: LruCache ;
2424use nativelink_config:: schedulers:: WorkerAllocationStrategy ;
25- use nativelink_error:: { Code , Error , ResultExt , error_if , make_err , make_input_err } ;
25+ use nativelink_error:: { error_if , make_err , make_input_err , Code , Error , ResultExt } ;
2626use nativelink_metric:: {
27- MetricFieldData , MetricKind , MetricPublishKnownKindData , MetricsComponent ,
28- RootMetricsComponent , group ,
27+ group , MetricFieldData , MetricKind , MetricPublishKnownKindData ,
28+ MetricsComponent , RootMetricsComponent ,
2929} ;
3030use nativelink_util:: action_messages:: { OperationId , WorkerId } ;
3131use nativelink_util:: metrics:: { WORKER_POOL_METRICS , WORKER_POOL_INSTANCE , WorkerPoolMetricAttrs } ;
@@ -547,6 +547,76 @@ impl ApiWorkerSchedulerImpl {
547547 }
548548 }
549549
550+ /// Batch notifies multiple workers to run actions in a single lock hold.
551+ /// Returns a vector of results for each notification attempt.
552+ async fn inner_batch_worker_notify_run_action (
553+ & mut self ,
554+ assignments : Vec < ( WorkerId , OperationId , ActionInfoWithProps ) > ,
555+ ) -> Vec < Result < ( ) , Error > > {
556+ let mut results = Vec :: with_capacity ( assignments. len ( ) ) ;
557+ let mut workers_to_evict: Vec < ( WorkerId , Error , bool ) > = Vec :: new ( ) ;
558+
559+ for ( worker_id, operation_id, action_info) in assignments {
560+ if let Some ( worker) = self . workers . get_mut ( & worker_id) {
561+ let notify_worker_result = worker
562+ . notify_update ( WorkerUpdate :: RunAction ( (
563+ operation_id. clone ( ) ,
564+ action_info. clone ( ) ,
565+ ) ) )
566+ . await ;
567+
568+ if let Err ( notify_err) = notify_worker_result {
569+ warn ! (
570+ ?worker_id,
571+ ?action_info,
572+ ?notify_err,
573+ "Worker command failed in batch notify, will remove worker" ,
574+ ) ;
575+
576+ let is_disconnect = notify_err. code == Code :: Internal
577+ && notify_err. messages . len ( ) == 1
578+ && notify_err. messages [ 0 ] == "Worker Disconnected" ;
579+
580+ let err = make_err ! (
581+ Code :: Internal ,
582+ "Worker command failed, removing worker {worker_id} -- {notify_err:?}" ,
583+ ) ;
584+
585+ workers_to_evict. push ( ( worker_id. clone ( ) , err. clone ( ) , is_disconnect) ) ;
586+ results. push ( Err ( err) ) ;
587+ } else {
588+ results. push ( Ok ( ( ) ) ) ;
589+ }
590+ } else {
591+ warn ! (
592+ ?worker_id,
593+ %operation_id,
594+ ?action_info,
595+ "Worker not found in worker map in batch_worker_notify_run_action"
596+ ) ;
597+ // Queue the operation to be put back to queued state
598+ let update_result = self
599+ . worker_state_manager
600+ . update_operation (
601+ & operation_id,
602+ & worker_id,
603+ UpdateOperationType :: UpdateWithDisconnect ,
604+ )
605+ . await ;
606+ results. push ( update_result) ;
607+ }
608+ }
609+
610+ // Evict failed workers after processing all notifications
611+ for ( worker_id, err, is_disconnect) in workers_to_evict {
612+ let _ = self
613+ . immediate_evict_worker ( & worker_id, err, is_disconnect)
614+ . await ;
615+ }
616+
617+ results
618+ }
619+
550620 /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it.
551621 async fn immediate_evict_worker (
552622 & mut self ,
@@ -585,7 +655,7 @@ impl ApiWorkerSchedulerImpl {
585655#[ derive( Debug , MetricsComponent ) ]
586656pub struct ApiWorkerScheduler {
587657 #[ metric]
588- inner : Mutex < ApiWorkerSchedulerImpl > ,
658+ inner : RwLock < ApiWorkerSchedulerImpl > ,
589659 #[ metric( group = "platform_property_manager" ) ]
590660 platform_property_manager : Arc < PlatformPropertyManager > ,
591661
@@ -614,7 +684,7 @@ impl ApiWorkerScheduler {
614684 instance_name : impl Into < String > ,
615685 ) -> Arc < Self > {
616686 Arc :: new ( Self {
617- inner : Mutex :: new ( ApiWorkerSchedulerImpl {
687+ inner : RwLock :: new ( ApiWorkerSchedulerImpl {
618688 workers : Workers ( LruCache :: unbounded ( ) ) ,
619689 worker_state_manager : worker_state_manager. clone ( ) ,
620690 allocation_strategy,
@@ -651,7 +721,7 @@ impl ApiWorkerScheduler {
651721 self . metrics
652722 . actions_dispatched
653723 . fetch_add ( 1 , Ordering :: Relaxed ) ;
654- let mut inner = self . inner . lock ( ) . await ;
724+ let mut inner = self . inner . write ( ) . await ;
655725 let result = inner
656726 . worker_notify_run_action ( worker_id, operation_id, action_info)
657727 . await ;
@@ -667,6 +737,39 @@ impl ApiWorkerScheduler {
667737 result
668738 }
669739
740+ /// Batch notifies multiple workers to run actions in a single lock acquisition.
741+ /// This reduces lock contention compared to calling `worker_notify_run_action`
742+ /// for each action individually.
743+ ///
744+ /// Returns a vector of results corresponding to each assignment in the input.
745+ pub async fn batch_worker_notify_run_action (
746+ & self ,
747+ assignments : Vec < ( WorkerId , OperationId , ActionInfoWithProps ) > ,
748+ ) -> Vec < Result < ( ) , Error > > {
749+ let count = assignments. len ( ) ;
750+ self . metrics
751+ . actions_dispatched
752+ . fetch_add ( count as u64 , Ordering :: Relaxed ) ;
753+
754+ let mut inner = self . inner . write ( ) . await ;
755+ let results = inner. inner_batch_worker_notify_run_action ( assignments) . await ;
756+
757+ // Record metrics
758+ let successes = results. iter ( ) . filter ( |r| r. is_ok ( ) ) . count ( ) ;
759+ let failures = count - successes;
760+
761+ for _ in 0 ..successes {
762+ self . worker_scheduler_metrics . record_action_dispatched ( ) ;
763+ }
764+ for _ in 0 ..failures {
765+ self . worker_scheduler_metrics . record_dispatch_failure ( ) ;
766+ }
767+ self . worker_scheduler_metrics
768+ . record_running_actions_count ( inner. count_running_actions ( ) ) ;
769+
770+ results
771+ }
772+
670773 /// Returns the scheduler metrics for observability.
671774 #[ must_use]
672775 pub const fn get_metrics ( & self ) -> & Arc < SchedulerMetrics > {
@@ -687,7 +790,7 @@ impl ApiWorkerScheduler {
687790 . find_worker_calls
688791 . fetch_add ( 1 , Ordering :: Relaxed ) ;
689792
690- let inner = self . inner . lock ( ) . await ;
793+ let inner = self . inner . read ( ) . await ;
691794 let worker_count = inner. workers . len ( ) as u64 ;
692795 let result = inner. inner_find_worker_for_action ( platform_properties, full_worker_logging) ;
693796
@@ -723,13 +826,13 @@ impl ApiWorkerScheduler {
723826 & self ,
724827 actions : & [ & PlatformProperties ] ,
725828 full_worker_logging : bool ,
726- ) -> Vec < ( usize , WorkerId ) > {
829+ ) -> HashMap < usize , WorkerId > {
727830 let start = Instant :: now ( ) ;
728831 self . metrics
729832 . find_worker_calls
730833 . fetch_add ( actions. len ( ) as u64 , Ordering :: Relaxed ) ;
731834
732- let inner = self . inner . lock ( ) . await ;
835+ let inner = self . inner . read ( ) . await ;
733836 let worker_count = inner. workers . len ( ) as u64 ;
734837 let results =
735838 inner. inner_batch_find_workers_for_actions ( actions, full_worker_logging) ;
@@ -759,7 +862,7 @@ impl ApiWorkerScheduler {
759862 /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests.
760863 #[ must_use]
761864 pub async fn contains_worker_for_test ( & self , worker_id : & WorkerId ) -> bool {
762- let inner = self . inner . lock ( ) . await ;
865+ let inner = self . inner . read ( ) . await ;
763866 inner. workers . contains ( worker_id)
764867 }
765868
@@ -768,15 +871,15 @@ impl ApiWorkerScheduler {
768871 & self ,
769872 worker_id : & WorkerId ,
770873 ) -> Result < ( ) , Error > {
771- let mut inner = self . inner . lock ( ) . await ;
874+ let mut inner = self . inner . write ( ) . await ;
772875 let worker = inner. workers . get_mut ( worker_id) . ok_or_else ( || {
773876 make_input_err ! ( "WorkerId '{}' does not exist in workers map" , worker_id)
774877 } ) ?;
775878 worker. keep_alive ( )
776879 }
777880
778881 pub async fn get_workers_state ( & self ) -> Vec < WorkerState > {
779- let inner = self . inner . lock ( ) . await ;
882+ let inner = self . inner . read ( ) . await ;
780883 inner. workers . iter ( ) . map ( |( _, w) | w. to_state ( ) ) . collect ( )
781884 }
782885}
@@ -790,7 +893,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
790893 async fn add_worker ( & self , worker : Worker ) -> Result < ( ) , Error > {
791894 let worker_id = worker. id . clone ( ) ;
792895 let worker_timestamp = worker. last_update_timestamp ;
793- let mut inner = self . inner . lock ( ) . await ;
896+ let mut inner = self . inner . write ( ) . await ;
794897 if inner. shutting_down {
795898 warn ! ( "Rejected worker add during shutdown: {}" , worker_id) ;
796899 return Err ( make_err ! (
@@ -833,7 +936,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
833936 UpdateOperationType :: UpdateWithError ( _) | UpdateOperationType :: UpdateWithDisconnect
834937 ) ;
835938
836- let mut inner = self . inner . lock ( ) . await ;
939+ let mut inner = self . inner . write ( ) . await ;
837940 let result = inner. update_action ( worker_id, operation_id, update) . await ;
838941
839942 // Record action completion metric
@@ -851,7 +954,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
851954 timestamp : WorkerTimestamp ,
852955 ) -> Result < ( ) , Error > {
853956 {
854- let mut inner = self . inner . lock ( ) . await ;
957+ let mut inner = self . inner . write ( ) . await ;
855958 inner
856959 . refresh_lifetime ( worker_id, timestamp)
857960 . err_tip ( || "Error refreshing lifetime in worker_keep_alive_received()" ) ?;
@@ -866,7 +969,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
866969 async fn remove_worker ( & self , worker_id : & WorkerId ) -> Result < ( ) , Error > {
867970 self . worker_registry . remove_worker ( worker_id) . await ;
868971
869- let mut inner = self . inner . lock ( ) . await ;
972+ let mut inner = self . inner . write ( ) . await ;
870973 let result = inner
871974 . immediate_evict_worker (
872975 worker_id,
@@ -882,7 +985,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
882985 }
883986
884987 async fn shutdown ( & self , shutdown_guard : ShutdownGuard ) {
885- let mut inner = self . inner . lock ( ) . await ;
988+ let mut inner = self . inner . write ( ) . await ;
886989 inner. shutting_down = true ; // should reject further worker registration
887990 while let Some ( worker_id) = inner
888991 . workers
@@ -910,8 +1013,9 @@ impl WorkerScheduler for ApiWorkerScheduler {
9101013 let now = UNIX_EPOCH + Duration :: from_secs ( now_timestamp) ;
9111014 let timeout_threshold = now_timestamp. saturating_sub ( self . worker_timeout_s ) ;
9121015
1016+ // Phase 1: Read-only collection of workers to check
9131017 let workers_to_check: Vec < ( WorkerId , bool ) > = {
914- let inner = self . inner . lock ( ) . await ;
1018+ let inner = self . inner . read ( ) . await ;
9151019 inner
9161020 . workers
9171021 . iter ( )
@@ -949,7 +1053,8 @@ impl WorkerScheduler for ApiWorkerScheduler {
9491053 return Ok ( ( ) ) ;
9501054 }
9511055
952- let mut inner = self . inner . lock ( ) . await ;
1056+ // Phase 2: Write lock to remove timed out workers
1057+ let mut inner = self . inner . write ( ) . await ;
9531058 let mut result = Ok ( ( ) ) ;
9541059
9551060 for worker_id in & worker_ids_to_remove {
@@ -976,7 +1081,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
9761081 }
9771082
9781083 async fn set_drain_worker ( & self , worker_id : & WorkerId , is_draining : bool ) -> Result < ( ) , Error > {
979- let mut inner = self . inner . lock ( ) . await ;
1084+ let mut inner = self . inner . write ( ) . await ;
9801085 inner. set_drain_worker ( worker_id, is_draining) . await ?;
9811086 self . worker_scheduler_metrics . record_worker_count ( inner. workers . len ( ) ) ;
9821087 Ok ( ( ) )
0 commit comments