Skip to content

Commit 12d839c

Browse files
committed
Introduce batch notify and assign actions.
1 parent 7393a00 commit 12d839c

1 file changed

Lines changed: 127 additions & 22 deletions

File tree

nativelink-scheduler/src/api_worker_scheduler.rs

Lines changed: 127 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515
use core::ops::{Deref, DerefMut};
1616
use core::sync::atomic::{AtomicU64, Ordering};
1717
use core::time::Duration;
18-
use std::collections::HashSet;
18+
use std::collections::{HashMap, HashSet};
1919
use std::sync::Arc;
2020
use std::time::{Instant, UNIX_EPOCH};
2121

22-
use async_lock::Mutex;
22+
use async_lock::RwLock;
2323
use lru::LruCache;
2424
use nativelink_config::schedulers::WorkerAllocationStrategy;
25-
use nativelink_error::{Code, Error, ResultExt, error_if, make_err, make_input_err};
25+
use nativelink_error::{error_if, make_err, make_input_err, Code, Error, ResultExt};
2626
use nativelink_metric::{
27-
MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
28-
RootMetricsComponent, group,
27+
group, MetricFieldData, MetricKind, MetricPublishKnownKindData,
28+
MetricsComponent, RootMetricsComponent,
2929
};
3030
use nativelink_util::action_messages::{OperationId, WorkerId};
3131
use nativelink_util::metrics::{WORKER_POOL_METRICS, WORKER_POOL_INSTANCE, WorkerPoolMetricAttrs};
@@ -547,6 +547,76 @@ impl ApiWorkerSchedulerImpl {
547547
}
548548
}
549549

550+
/// Batch notifies multiple workers to run actions in a single lock hold.
551+
/// Returns a vector of results for each notification attempt.
552+
async fn inner_batch_worker_notify_run_action(
553+
&mut self,
554+
assignments: Vec<(WorkerId, OperationId, ActionInfoWithProps)>,
555+
) -> Vec<Result<(), Error>> {
556+
let mut results = Vec::with_capacity(assignments.len());
557+
let mut workers_to_evict: Vec<(WorkerId, Error, bool)> = Vec::new();
558+
559+
for (worker_id, operation_id, action_info) in assignments {
560+
if let Some(worker) = self.workers.get_mut(&worker_id) {
561+
let notify_worker_result = worker
562+
.notify_update(WorkerUpdate::RunAction((
563+
operation_id.clone(),
564+
action_info.clone(),
565+
)))
566+
.await;
567+
568+
if let Err(notify_err) = notify_worker_result {
569+
warn!(
570+
?worker_id,
571+
?action_info,
572+
?notify_err,
573+
"Worker command failed in batch notify, will remove worker",
574+
);
575+
576+
let is_disconnect = notify_err.code == Code::Internal
577+
&& notify_err.messages.len() == 1
578+
&& notify_err.messages[0] == "Worker Disconnected";
579+
580+
let err = make_err!(
581+
Code::Internal,
582+
"Worker command failed, removing worker {worker_id} -- {notify_err:?}",
583+
);
584+
585+
workers_to_evict.push((worker_id.clone(), err.clone(), is_disconnect));
586+
results.push(Err(err));
587+
} else {
588+
results.push(Ok(()));
589+
}
590+
} else {
591+
warn!(
592+
?worker_id,
593+
%operation_id,
594+
?action_info,
595+
"Worker not found in worker map in batch_worker_notify_run_action"
596+
);
597+
// Queue the operation to be put back to queued state
598+
let update_result = self
599+
.worker_state_manager
600+
.update_operation(
601+
&operation_id,
602+
&worker_id,
603+
UpdateOperationType::UpdateWithDisconnect,
604+
)
605+
.await;
606+
results.push(update_result);
607+
}
608+
}
609+
610+
// Evict failed workers after processing all notifications
611+
for (worker_id, err, is_disconnect) in workers_to_evict {
612+
let _ = self
613+
.immediate_evict_worker(&worker_id, err, is_disconnect)
614+
.await;
615+
}
616+
617+
results
618+
}
619+
550620
/// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it.
551621
async fn immediate_evict_worker(
552622
&mut self,
@@ -585,7 +655,7 @@ impl ApiWorkerSchedulerImpl {
585655
#[derive(Debug, MetricsComponent)]
586656
pub struct ApiWorkerScheduler {
587657
#[metric]
588-
inner: Mutex<ApiWorkerSchedulerImpl>,
658+
inner: RwLock<ApiWorkerSchedulerImpl>,
589659
#[metric(group = "platform_property_manager")]
590660
platform_property_manager: Arc<PlatformPropertyManager>,
591661

@@ -614,7 +684,7 @@ impl ApiWorkerScheduler {
614684
instance_name: impl Into<String>,
615685
) -> Arc<Self> {
616686
Arc::new(Self {
617-
inner: Mutex::new(ApiWorkerSchedulerImpl {
687+
inner: RwLock::new(ApiWorkerSchedulerImpl {
618688
workers: Workers(LruCache::unbounded()),
619689
worker_state_manager: worker_state_manager.clone(),
620690
allocation_strategy,
@@ -651,7 +721,7 @@ impl ApiWorkerScheduler {
651721
self.metrics
652722
.actions_dispatched
653723
.fetch_add(1, Ordering::Relaxed);
654-
let mut inner = self.inner.lock().await;
724+
let mut inner = self.inner.write().await;
655725
let result = inner
656726
.worker_notify_run_action(worker_id, operation_id, action_info)
657727
.await;
@@ -667,6 +737,39 @@ impl ApiWorkerScheduler {
667737
result
668738
}
669739

740+
/// Batch notifies multiple workers to run actions in a single lock acquisition.
741+
/// This reduces lock contention compared to calling `worker_notify_run_action`
742+
/// for each action individually.
743+
///
744+
/// Returns a vector of results corresponding to each assignment in the input.
745+
pub async fn batch_worker_notify_run_action(
746+
&self,
747+
assignments: Vec<(WorkerId, OperationId, ActionInfoWithProps)>,
748+
) -> Vec<Result<(), Error>> {
749+
let count = assignments.len();
750+
self.metrics
751+
.actions_dispatched
752+
.fetch_add(count as u64, Ordering::Relaxed);
753+
754+
let mut inner = self.inner.write().await;
755+
let results = inner.inner_batch_worker_notify_run_action(assignments).await;
756+
757+
// Record metrics
758+
let successes = results.iter().filter(|r| r.is_ok()).count();
759+
let failures = count - successes;
760+
761+
for _ in 0..successes {
762+
self.worker_scheduler_metrics.record_action_dispatched();
763+
}
764+
for _ in 0..failures {
765+
self.worker_scheduler_metrics.record_dispatch_failure();
766+
}
767+
self.worker_scheduler_metrics
768+
.record_running_actions_count(inner.count_running_actions());
769+
770+
results
771+
}
772+
670773
/// Returns the scheduler metrics for observability.
671774
#[must_use]
672775
pub const fn get_metrics(&self) -> &Arc<SchedulerMetrics> {
@@ -687,7 +790,7 @@ impl ApiWorkerScheduler {
687790
.find_worker_calls
688791
.fetch_add(1, Ordering::Relaxed);
689792

690-
let inner = self.inner.lock().await;
793+
let inner = self.inner.read().await;
691794
let worker_count = inner.workers.len() as u64;
692795
let result = inner.inner_find_worker_for_action(platform_properties, full_worker_logging);
693796

@@ -723,13 +826,13 @@ impl ApiWorkerScheduler {
723826
&self,
724827
actions: &[&PlatformProperties],
725828
full_worker_logging: bool,
726-
) -> Vec<(usize, WorkerId)> {
829+
) -> HashMap<usize, WorkerId> {
727830
let start = Instant::now();
728831
self.metrics
729832
.find_worker_calls
730833
.fetch_add(actions.len() as u64, Ordering::Relaxed);
731834

732-
let inner = self.inner.lock().await;
835+
let inner = self.inner.read().await;
733836
let worker_count = inner.workers.len() as u64;
734837
let results =
735838
inner.inner_batch_find_workers_for_actions(actions, full_worker_logging);
@@ -759,7 +862,7 @@ impl ApiWorkerScheduler {
759862
/// Checks to see if the worker exists in the worker pool. Should only be used in unit tests.
760863
#[must_use]
761864
pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool {
762-
let inner = self.inner.lock().await;
865+
let inner = self.inner.read().await;
763866
inner.workers.contains(worker_id)
764867
}
765868

@@ -768,15 +871,15 @@ impl ApiWorkerScheduler {
768871
&self,
769872
worker_id: &WorkerId,
770873
) -> Result<(), Error> {
771-
let mut inner = self.inner.lock().await;
874+
let mut inner = self.inner.write().await;
772875
let worker = inner.workers.get_mut(worker_id).ok_or_else(|| {
773876
make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
774877
})?;
775878
worker.keep_alive()
776879
}
777880

778881
pub async fn get_workers_state(&self) -> Vec<WorkerState> {
779-
let inner = self.inner.lock().await;
882+
let inner = self.inner.read().await;
780883
inner.workers.iter().map(|(_, w)| w.to_state()).collect()
781884
}
782885
}
@@ -790,7 +893,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
790893
async fn add_worker(&self, worker: Worker) -> Result<(), Error> {
791894
let worker_id = worker.id.clone();
792895
let worker_timestamp = worker.last_update_timestamp;
793-
let mut inner = self.inner.lock().await;
896+
let mut inner = self.inner.write().await;
794897
if inner.shutting_down {
795898
warn!("Rejected worker add during shutdown: {}", worker_id);
796899
return Err(make_err!(
@@ -833,7 +936,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
833936
UpdateOperationType::UpdateWithError(_) | UpdateOperationType::UpdateWithDisconnect
834937
);
835938

836-
let mut inner = self.inner.lock().await;
939+
let mut inner = self.inner.write().await;
837940
let result = inner.update_action(worker_id, operation_id, update).await;
838941

839942
// Record action completion metric
@@ -851,7 +954,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
851954
timestamp: WorkerTimestamp,
852955
) -> Result<(), Error> {
853956
{
854-
let mut inner = self.inner.lock().await;
957+
let mut inner = self.inner.write().await;
855958
inner
856959
.refresh_lifetime(worker_id, timestamp)
857960
.err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()")?;
@@ -866,7 +969,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
866969
async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error> {
867970
self.worker_registry.remove_worker(worker_id).await;
868971

869-
let mut inner = self.inner.lock().await;
972+
let mut inner = self.inner.write().await;
870973
let result = inner
871974
.immediate_evict_worker(
872975
worker_id,
@@ -882,7 +985,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
882985
}
883986

884987
async fn shutdown(&self, shutdown_guard: ShutdownGuard) {
885-
let mut inner = self.inner.lock().await;
988+
let mut inner = self.inner.write().await;
886989
inner.shutting_down = true; // should reject further worker registration
887990
while let Some(worker_id) = inner
888991
.workers
@@ -910,8 +1013,9 @@ impl WorkerScheduler for ApiWorkerScheduler {
9101013
let now = UNIX_EPOCH + Duration::from_secs(now_timestamp);
9111014
let timeout_threshold = now_timestamp.saturating_sub(self.worker_timeout_s);
9121015

1016+
// Phase 1: Read-only collection of workers to check
9131017
let workers_to_check: Vec<(WorkerId, bool)> = {
914-
let inner = self.inner.lock().await;
1018+
let inner = self.inner.read().await;
9151019
inner
9161020
.workers
9171021
.iter()
@@ -949,7 +1053,8 @@ impl WorkerScheduler for ApiWorkerScheduler {
9491053
return Ok(());
9501054
}
9511055

952-
let mut inner = self.inner.lock().await;
1056+
// Phase 2: Write lock to remove timed out workers
1057+
let mut inner = self.inner.write().await;
9531058
let mut result = Ok(());
9541059

9551060
for worker_id in &worker_ids_to_remove {
@@ -976,7 +1081,7 @@ impl WorkerScheduler for ApiWorkerScheduler {
9761081
}
9771082

9781083
async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error> {
979-
let mut inner = self.inner.lock().await;
1084+
let mut inner = self.inner.write().await;
9801085
inner.set_drain_worker(worker_id, is_draining).await?;
9811086
self.worker_scheduler_metrics.record_worker_count(inner.workers.len());
9821087
Ok(())

0 commit comments

Comments
 (0)