diff --git a/orb-jobs-agent/src/args.rs b/orb-jobs-agent/src/args.rs index 2e544ea8e..ae4df561d 100644 --- a/orb-jobs-agent/src/args.rs +++ b/orb-jobs-agent/src/args.rs @@ -36,6 +36,14 @@ pub struct Args { /// The target job-server service id to send messages to. #[clap(long, env = "TARGET_SERVICE_ID", default_value = "job-server")] pub target_service_id: Option, + #[clap( + long, + env = "DBUS_SESSION_BUS_ADDRESS", + default_value = "unix:path=/tmp/worldcoin_bus_socket" + )] + pub dbus_addr: String, + #[clap(long)] + pub run_job: Option, } fn clap_v3_styles() -> Styles { diff --git a/orb-jobs-agent/src/job_system/client.rs b/orb-jobs-agent/src/job_system/client.rs index 67f40ade6..1ab089a7a 100644 --- a/orb-jobs-agent/src/job_system/client.rs +++ b/orb-jobs-agent/src/job_system/client.rs @@ -2,45 +2,81 @@ use crate::job_system::{ orchestrator::{JobConfig, JobRegistry}, sanitize::redact_job_document, }; +use async_trait::async_trait; use color_eyre::eyre::{eyre, Result}; use orb_relay_client::{Client, QoS, SendMessage}; use orb_relay_messages::{ jobs::v1::{ - JobCancel, JobExecution, JobExecutionUpdate, JobNotify, JobRequestNext, + JobCancel, JobExecution, JobExecutionStatus, JobExecutionUpdate, JobNotify, + JobRequestNext, }, prost::{Message, Name}, prost_types::Any, relay::entity::EntityType, }; +use std::sync::{Arc, Mutex}; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; use tracing::{error, info, warn}; +pub type TransportResult = std::result::Result; + +#[derive(Debug)] +pub enum JobTransportMessage { + Notify, + Execution(JobExecution), + Cancel(JobCancel), +} + +#[async_trait] +pub trait JobTransport: Send + Sync + std::fmt::Debug { + async fn recv(&self) -> TransportResult; + + async fn request_next_job(&self, request: &JobRequestNext) -> TransportResult<()>; + + async fn send_job_update(&self, update: &JobExecutionUpdate) + -> TransportResult<()>; + + async fn reconnect(&self) -> Result<()>; +} + #[derive(Debug, Clone)] -pub struct JobClient { +pub struct RelayTransport { relay_client: Client, target_service_id: String, relay_namespace: String, - job_registry: JobRegistry, - job_config: JobConfig, } -impl JobClient { +impl RelayTransport { pub fn new( relay_client: Client, - target_service_id: &str, - relay_namespace: &str, - job_registry: JobRegistry, - job_config: JobConfig, + target_service_id: impl Into, + relay_namespace: impl Into, ) -> Self { Self { relay_client, - target_service_id: target_service_id.to_string(), - relay_namespace: relay_namespace.to_string(), - job_registry, - job_config, + target_service_id: target_service_id.into(), + relay_namespace: relay_namespace.into(), } } - pub async fn listen_for_job(&self) -> Result { + async fn send_request(&self, request: &JobRequestNext) -> TransportResult<()> { + let any = Any::from_msg(request).unwrap(); + self.relay_client + .send( + SendMessage::to(EntityType::Service) + .id(self.target_service_id.clone()) + .namespace(self.relay_namespace.clone()) + .qos(QoS::AtLeastOnce) + .payload(any.encode_to_vec()), + ) + .await + } +} + +#[async_trait] +impl JobTransport for RelayTransport { + async fn recv(&self) -> TransportResult { loop { match self.relay_client.recv().await { Ok(msg) => { @@ -55,7 +91,7 @@ impl JobClient { match JobNotify::decode(any.value.as_slice()) { Ok(job_notify) => { info!("received JobNotify: {:?}", job_notify); - let _ = self.request_next_job().await; + return Ok(JobTransportMessage::Notify); } Err(e) => { error!("error decoding JobNotify: {:?}", e); @@ -71,7 +107,7 @@ impl JobClient { should_cancel = job.should_cancel, "received JobExecution" ); - return Ok(job); + return Ok(JobTransportMessage::Execution(job)); } Err(e) => { error!("error decoding JobExecution: {:?}", e); @@ -84,21 +120,7 @@ impl JobClient { job_execution_id = %job_cancel.job_execution_id, "received JobCancel" ); - let cancelled = self - .job_registry - .cancel_job(&job_cancel.job_execution_id) - .await; - if cancelled { - info!( - job_execution_id = %job_cancel.job_execution_id, - "Successfully cancelled job" - ); - } else { - warn!( - job_execution_id = %job_cancel.job_execution_id, - "Attempted to cancel non-existent or already completed job" - ); - } + return Ok(JobTransportMessage::Cancel(job_cancel)); } Err(e) => { error!("error decoding JobCancel: {:?}", e); @@ -116,66 +138,25 @@ impl JobClient { } } - /// Requests for a next job to be run, excluding the ones that are - /// currently running (determined by `running_job_execution_ids` arg) - pub async fn request_next_job(&self) -> Result<(), orb_relay_client::Err> { - let mut running_ids = self.job_registry.get_active_job_ids().await; - let mut completed_ids = self.job_registry.get_completed_job_ids().await; - - running_ids.append(&mut completed_ids); - let job_ids_to_ignore = running_ids; - - let job_request = JobRequestNext { - ignore_job_execution_ids: job_ids_to_ignore.clone(), - }; - - let any = Any::from_msg(&job_request).unwrap(); - self.relay_client - .send( - SendMessage::to(EntityType::Service) - .id(self.target_service_id.clone()) - .namespace(self.relay_namespace.clone()) - .qos(QoS::AtLeastOnce) - .payload(any.encode_to_vec()), - ) - .await?; + async fn request_next_job( + &self, + job_request: &JobRequestNext, + ) -> TransportResult<()> { + self.send_request(job_request).await?; info!( "sent JobRequestNext ignoring {} job execution IDs: {:?}", - job_ids_to_ignore.len(), - job_ids_to_ignore + job_request.ignore_job_execution_ids.len(), + job_request.ignore_job_execution_ids ); Ok(()) } - /// Check if we should request more jobs and do so if appropriate - /// This method is used to implement parallel job execution - /// Returns `false` if no jobs were requested. - pub async fn try_request_more_jobs(&self) -> Result { - // Check if we should request more jobs based on current configuration - if !self - .job_config - .should_request_more_jobs(&self.job_registry) - .await - { - return Ok(false); - } - - // Request next job with current running job IDs - self.request_next_job() - .await - .inspect_err(|e| error!("Failed to request additional job: {:?}", e))?; - - info!("Successfully requested additional job for parallel execution"); - - Ok(true) - } - - pub async fn send_job_update( + async fn send_job_update( &self, job_update: &JobExecutionUpdate, - ) -> Result<(), orb_relay_client::Err> { + ) -> TransportResult<()> { info!( job_execution_id = %job_update.job_execution_id, job_id = %job_update.job_id, @@ -210,7 +191,7 @@ impl JobClient { Ok(()) } - pub async fn force_relay_reconnect(&self) -> Result<()> { + async fn reconnect(&self) -> Result<()> { self.relay_client .reconnect() .await @@ -218,6 +199,190 @@ impl JobClient { } } +#[derive(Debug)] +pub struct LocalTransport { + pending_job: Mutex>, + terminal_update: Mutex>, + shutdown: CancellationToken, +} + +impl LocalTransport { + pub fn new(job: JobExecution) -> Self { + Self { + pending_job: Mutex::new(Some(job)), + terminal_update: Mutex::new(None), + shutdown: CancellationToken::new(), + } + } + + pub fn terminal_update(&self) -> Option { + self.terminal_update.lock().unwrap().clone() + } + + pub fn shutdown_handle(&self) -> JoinHandle> { + let shutdown = self.shutdown.clone(); + tokio::spawn(async move { + shutdown.cancelled().await; + Ok(()) + }) + } +} + +#[async_trait] +impl JobTransport for LocalTransport { + async fn recv(&self) -> TransportResult { + let next_job = self.pending_job.lock().unwrap().take(); + + if let Some(job) = next_job { + info!( + job_id = %job.job_id, + job_execution_id = %job.job_execution_id, + job_document = %redact_job_document(&job.job_document), + should_cancel = job.should_cancel, + "received local JobExecution" + ); + + return Ok(JobTransportMessage::Execution(job)); + } + + std::future::pending::<()>().await; + unreachable!() + } + + async fn request_next_job(&self, _request: &JobRequestNext) -> TransportResult<()> { + Ok(()) + } + + async fn send_job_update( + &self, + job_update: &JobExecutionUpdate, + ) -> TransportResult<()> { + let status_name = JobExecutionStatus::try_from(job_update.status) + .map(|status| format!("{status:?}")) + .unwrap_or_else(|_| format!("Unknown({})", job_update.status)); + + println!("--- Job Update ---"); + println!("job_id: {}", job_update.job_id); + println!("job_execution_id: {}", job_update.job_execution_id); + println!("status: {status_name}"); + + if !job_update.std_out.is_empty() { + println!("stdout:\n{}", job_update.std_out); + } + + if !job_update.std_err.is_empty() { + eprintln!("stderr:\n{}", job_update.std_err); + } + + if job_update.status != JobExecutionStatus::InProgress as i32 { + *self.terminal_update.lock().unwrap() = Some(job_update.clone()); + self.shutdown.cancel(); + } + + Ok(()) + } + + async fn reconnect(&self) -> Result<()> { + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct JobClient { + transport: Arc, + job_registry: JobRegistry, + job_config: JobConfig, +} + +impl JobClient { + pub fn new( + transport: Arc, + job_registry: JobRegistry, + job_config: JobConfig, + ) -> Self { + Self { + transport, + job_registry, + job_config, + } + } + + pub async fn listen_for_job(&self) -> TransportResult { + loop { + match self.transport.recv().await? { + JobTransportMessage::Notify => { + let _ = self.request_next_job().await; + } + JobTransportMessage::Execution(job) => { + return Ok(job); + } + JobTransportMessage::Cancel(job_cancel) => { + let cancelled = self + .job_registry + .cancel_job(&job_cancel.job_execution_id) + .await; + + if cancelled { + info!( + job_execution_id = %job_cancel.job_execution_id, + "Successfully cancelled job" + ); + } else { + warn!( + job_execution_id = %job_cancel.job_execution_id, + "Attempted to cancel non-existent or already completed job" + ); + } + } + } + } + } + + pub async fn request_next_job(&self) -> TransportResult<()> { + let job_request = build_job_request(&self.job_registry).await; + self.transport.request_next_job(&job_request).await + } + + pub async fn try_request_more_jobs(&self) -> TransportResult { + if !self + .job_config + .should_request_more_jobs(&self.job_registry) + .await + { + return Ok(false); + } + + self.request_next_job() + .await + .inspect_err(|e| error!("Failed to request additional job: {:?}", e))?; + + info!("Successfully requested additional job for parallel execution"); + + Ok(true) + } + + pub async fn send_job_update( + &self, + job_update: &JobExecutionUpdate, + ) -> TransportResult<()> { + self.transport.send_job_update(job_update).await + } + + pub async fn force_relay_reconnect(&self) -> Result<()> { + self.transport.reconnect().await + } +} + +async fn build_job_request(job_registry: &JobRegistry) -> JobRequestNext { + let mut running_ids = job_registry.get_active_job_ids().await; + let mut completed_ids = job_registry.get_completed_job_ids().await; + running_ids.append(&mut completed_ids); + + JobRequestNext { + ignore_job_execution_ids: running_ids, + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/orb-jobs-agent/src/job_system/handler.rs b/orb-jobs-agent/src/job_system/handler.rs index 81e86b268..da9f9f01a 100644 --- a/orb-jobs-agent/src/job_system/handler.rs +++ b/orb-jobs-agent/src/job_system/handler.rs @@ -1,21 +1,18 @@ use super::ctx::Ctx; use crate::{ job_system::{ - client::JobClient, + client::{JobClient, JobTransport, TransportResult}, ctx::JobExecutionUpdateExt, orchestrator::{JobCompletion, JobConfig, JobRegistry, JobStartStatus}, sanitize::{redact_args, redact_job_document, should_sanitize}, }, program::Deps, - settings::Settings, }; use color_eyre::Result; -use orb_relay_client::{Client, ClientOpts}; -use orb_relay_messages::{ - jobs::v1::{JobExecution, JobExecutionStatus, JobExecutionUpdate}, - relay::entity::EntityType, +use orb_relay_messages::jobs::v1::{ + JobExecution, JobExecutionStatus, JobExecutionUpdate, }; -use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; +use std::{collections::HashMap, pin::Pin, sync::Arc}; use tokio::{sync::oneshot, task::JoinHandle}; use tokio_util::sync::CancellationToken; use tracing::{error, info, warn}; @@ -74,8 +71,13 @@ impl JobHandlerBuilder { self } - pub fn build(self, deps: Deps) -> JobHandler { - JobHandler::new(self, deps) + pub fn build( + self, + deps: Deps, + transport: Arc, + transport_handle: JoinHandle>, + ) -> JobHandler { + JobHandler::new(self, deps, transport, transport_handle) } } @@ -106,7 +108,7 @@ pub struct JobHandler { job_config: JobConfig, job_registry: JobRegistry, pub(crate) job_client: JobClient, - relay_handle: JoinHandle>, + transport_handle: JoinHandle>, handlers: HashMap, } @@ -118,51 +120,28 @@ impl JobHandler { } } - fn new(builder: JobHandlerBuilder, deps: Deps) -> Self { - let Settings { - orb_id, - relay_host, - relay_namespace, - target_service_id, - auth, - .. - } = &deps.settings; - - let opts = ClientOpts::entity(EntityType::Orb) - .id(orb_id.as_str().to_string()) - .endpoint(relay_host) - .namespace(relay_namespace) - .auth(auth.clone()) - .connection_timeout(Duration::from_secs(3)) - .connection_backoff(Duration::from_secs(2)) - .keep_alive_interval(Duration::from_secs(30)) - .keep_alive_timeout(Duration::from_secs(10)) - .ack_timeout(Duration::from_secs(5)) - .build(); - - info!("Connecting to relay: {:?}", relay_host); - let (relay_client, relay_handle) = Client::connect(opts); + fn new( + builder: JobHandlerBuilder, + deps: Deps, + transport: Arc, + transport_handle: JoinHandle>, + ) -> Self { let job_registry = JobRegistry::new(); let job_config = builder.job_config; - let job_client = JobClient::new( - relay_client.clone(), - target_service_id.as_str(), - relay_namespace, - job_registry.clone(), - job_config.clone(), - ); + let job_client = + JobClient::new(transport, job_registry.clone(), job_config.clone()); Self { state: Arc::new(deps), job_config, job_registry, job_client, - relay_handle, + transport_handle, handlers: builder.handlers.into_iter().collect(), } } - pub async fn run(mut self) { + pub async fn run(mut self) -> Result<()> { // Kickstart job requests. match self.job_client.try_request_more_jobs().await { Ok(true) => { @@ -186,16 +165,35 @@ impl JobHandler { loop { tokio::select! { - _ = &mut self.relay_handle => { + transport_result = &mut self.transport_handle => { + match transport_result { + Ok(Ok(())) => {} + Ok(Err(e)) => { + error!("Transport shutdown with error: {:?}", e); + } + Err(e) => { + error!("Transport task failed: {:?}", e); + } + } info!("Relay service shutdown detected"); break; } - Ok(job) = self.job_client.listen_for_job() => { - self = self.handle_job(job).await; + result = self.job_client.listen_for_job() => { + match result { + Ok(job) => { + self = self.handle_job(job).await; + } + Err(e) => { + error!("Failed to receive job: {:?}", e); + break; + } + } } } } + + Ok(()) } async fn handle_job(mut self, job: JobExecution) -> Self { diff --git a/orb-jobs-agent/src/main.rs b/orb-jobs-agent/src/main.rs index 9dd6fd4b0..e51c6c208 100644 --- a/orb-jobs-agent/src/main.rs +++ b/orb-jobs-agent/src/main.rs @@ -1,10 +1,22 @@ use clap::Parser; -use color_eyre::eyre::Result; -use orb_jobs_agent::args::Args; -use orb_jobs_agent::program::{self, Deps}; +use color_eyre::eyre::{eyre, Context, ContextCompat, Result}; +use orb_endpoints::{v1::Endpoints, Backend}; +use orb_info::TokenTaskHandle; use orb_jobs_agent::settings::Settings; use orb_jobs_agent::shell::Host; -use tracing::info; +use orb_jobs_agent::{ + args::Args, + job_system::client::{JobTransport, LocalTransport, RelayTransport}, + program::{self, Deps, Runtime}, +}; +use orb_relay_client::{Auth, Client, ClientOpts}; +use orb_relay_messages::{ + jobs::v1::{JobExecution, JobExecutionStatus}, + relay::entity::EntityType, +}; +use std::{sync::Arc, time::Duration}; +use tokio_util::sync::CancellationToken; +use tracing::{info, warn}; const SYSLOG_IDENTIFIER: &str = "worldcoin-jobs-agent"; @@ -24,16 +36,138 @@ async fn main() -> Result<()> { async fn run(args: &Args) -> Result<()> { info!("Starting jobs agent: {:?}", args); - let connection = zbus::Connection::session().await?; + let connection = zbus::ConnectionBuilder::address(args.dbus_addr.as_str())? + .build() + .await?; - let deps = Deps::new( - Host, - connection, - Settings::from_args(args, "/mnt/scratch").await?, - ); + let settings = Settings::from_args(args, "/mnt/scratch").await?; + let deps = Deps::new(Host, connection, settings.clone()); - program::run(deps).await?; + match &args.run_job { + Some(job_document) => run_local(deps, job_document).await?, + None => run_service(deps, args, &settings).await?, + } info!("Shutting down jobs agent completed"); Ok(()) } + +async fn run_local(deps: Deps, job_document: &str) -> Result<()> { + let job = JobExecution { + job_id: "local-job".to_string(), + job_execution_id: "local-job-execution".to_string(), + job_document: job_document.to_string(), + should_cancel: false, + }; + + let transport = Arc::new(LocalTransport::new(job)); + let runtime = Runtime { + transport: transport.clone(), + transport_handle: transport.shutdown_handle(), + watch_conn_changes: false, + }; + + program::run(deps, runtime).await?; + + let terminal_update = transport + .terminal_update() + .ok_or_else(|| eyre!("local run ended without terminal job status"))?; + + if terminal_update.status != JobExecutionStatus::Succeeded as i32 { + let status_name = JobExecutionStatus::try_from(terminal_update.status) + .map(|status| format!("{status:?}")) + .unwrap_or_else(|_| format!("Unknown({})", terminal_update.status)); + + if terminal_update.std_err.is_empty() { + return Err(eyre!("local job failed with status {status_name}")); + } + + return Err(eyre!( + "local job failed with status {status_name}: {}", + terminal_update.std_err + )); + } + + Ok(()) +} + +async fn run_service(deps: Deps, args: &Args, settings: &Settings) -> Result<()> { + let relay_host = args + .relay_host + .clone() + .or_else(|| { + Backend::from_env().ok().map(|backend| { + Endpoints::new(backend, &settings.orb_id).relay.to_string() + }) + }) + .wrap_err("could not get Backend Endpoint from env")?; + + let auth = resolve_auth(args, &deps.session_dbus).await?; + + let relay_namespace = args + .relay_namespace + .clone() + .wrap_err("relay namespace MUST be provided")?; + + let target_service_id = args + .target_service_id + .clone() + .wrap_err("target service id MUST be provided")?; + + let opts = ClientOpts::entity(EntityType::Orb) + .id(settings.orb_id.as_str().to_string()) + .endpoint(&relay_host) + .namespace(&relay_namespace) + .auth(auth) + .connection_timeout(Duration::from_secs(3)) + .connection_backoff(Duration::from_secs(2)) + .keep_alive_interval(Duration::from_secs(30)) + .keep_alive_timeout(Duration::from_secs(10)) + .ack_timeout(Duration::from_secs(5)) + .build(); + + info!("Connecting to relay: {:?}", relay_host); + let (relay_client, transport_handle) = Client::connect(opts); + let transport: Arc = Arc::new(RelayTransport::new( + relay_client, + target_service_id, + relay_namespace, + )); + let runtime = Runtime { + transport, + transport_handle, + watch_conn_changes: true, + }; + + program::run(deps, runtime).await +} + +async fn resolve_auth(args: &Args, connection: &zbus::Connection) -> Result { + match &args.orb_token { + Some(token) => Ok(Auth::Token(token.as_str().into())), + None => { + let shutdown = CancellationToken::new(); + let get_token = async || { + TokenTaskHandle::spawn(connection, &shutdown) + .await + .wrap_err("failed to get auth token!") + }; + + let token_recv = tokio::time::timeout(Duration::from_secs(60), async { + loop { + match get_token().await { + Ok(handle) => return handle.token_recv, + Err(e) => { + warn!("{e}! trying again in 5s"); + tokio::time::sleep(Duration::from_secs(5)).await; + } + } + } + }) + .await + .wrap_err("could not get auth token after 60s")?; + + Ok(Auth::TokenReceiver(token_recv)) + } + } +} diff --git a/orb-jobs-agent/src/program.rs b/orb-jobs-agent/src/program.rs index c3b6dfc3c..850ed17dc 100644 --- a/orb-jobs-agent/src/program.rs +++ b/orb-jobs-agent/src/program.rs @@ -7,12 +7,16 @@ use crate::{ update_versions, wifi_add, wifi_connect, wifi_ip, wifi_list, wifi_remove, wifi_scan, wipe_downloads, }, - job_system::handler::JobHandler, + job_system::{ + client::{JobTransport, TransportResult}, + handler::JobHandler, + }, settings::Settings, shell::Shell, }; use color_eyre::Result; -use tokio::fs; +use std::sync::Arc; +use tokio::{fs, task::JoinHandle}; /// Dependencies used by the jobs-agent. pub struct Deps { @@ -21,6 +25,12 @@ pub struct Deps { pub settings: Settings, } +pub struct Runtime { + pub transport: Arc, + pub transport_handle: JoinHandle>, + pub watch_conn_changes: bool, +} + impl Deps { pub fn new(shell: S, session_dbus: zbus::Connection, settings: Settings) -> Self where @@ -34,7 +44,7 @@ impl Deps { } } -pub async fn run(deps: Deps) -> Result<()> { +pub async fn run(deps: Deps, runtime: Runtime) -> Result<()> { fs::create_dir_all(&deps.settings.store_path).await?; let orb_id = deps.settings.orb_id.clone(); let zenoh_port = deps.settings.zenoh_port; @@ -70,13 +80,22 @@ pub async fn run(deps: Deps) -> Result<()> { .parallel_max("logs", 3, logs::handler) .sequential("reboot", reboot::handler) .sequential("slot_switch", slot_switch::handler) - .build(deps); + .build(deps, runtime.transport, runtime.transport_handle); - let _zenoh_session = - conn_change::spawn_watcher(orb_id, job_handler.job_client.clone(), zenoh_port) - .await?; + let _zenoh_session = if runtime.watch_conn_changes { + Some( + conn_change::spawn_watcher( + orb_id, + job_handler.job_client.clone(), + zenoh_port, + ) + .await?, + ) + } else { + None + }; - job_handler.run().await; + job_handler.run().await?; Ok(()) } diff --git a/orb-jobs-agent/src/settings.rs b/orb-jobs-agent/src/settings.rs index 4e0f82e65..05c25b903 100644 --- a/orb-jobs-agent/src/settings.rs +++ b/orb-jobs-agent/src/settings.rs @@ -1,32 +1,18 @@ use crate::args::Args; -use color_eyre::{ - eyre::{eyre, Context, ContextCompat}, - Result, -}; -use orb_endpoints::{v1::Endpoints, Backend}; +use color_eyre::{eyre::Context, Result}; use orb_info::{ orb_os_release::{OrbOsPlatform, OrbOsRelease}, - OrbId, TokenTaskHandle, + OrbId, }; -use orb_relay_client::Auth; use std::{ path::{Path, PathBuf}, str::FromStr, - time::Duration, }; -use tokio::time; -use tokio_util::sync::CancellationToken; -use tracing::warn; -use zbus::Connection; #[derive(Debug, Clone)] pub struct Settings { pub orb_id: OrbId, pub orb_platform: OrbOsPlatform, - pub auth: Auth, - pub relay_host: String, - pub relay_namespace: String, - pub target_service_id: String, /// Filesystem path used to persist data pub store_path: PathBuf, /// Path to the calibration file (configurable for testing) @@ -63,63 +49,6 @@ impl Settings { os_release.orb_os_platform_type }; - let relay_host = args - .relay_host - .clone() - .or_else(|| { - Backend::from_env() - .ok() - .map(|backend| Endpoints::new(backend, &orb_id).relay.to_string()) - }) - .wrap_err("could not get Backend Endpoint from env")?; - - // Get token from DBus - let auth = match &args.orb_token { - Some(t) => Auth::Token(t.as_str().into()), - None => { - let shutdown_token = CancellationToken::new(); - let get_token = async || { - let connection = Connection::session() - .await - .map_err(|e| eyre!("failed to establish zbus conn: {e}"))?; - - TokenTaskHandle::spawn(&connection, &shutdown_token) - .await - .wrap_err("failed to get auth token!") - }; - - let token_rec_fut = async { - loop { - match get_token().await { - Err(e) => { - warn!("{e}! trying again in 5s"); - time::sleep(Duration::from_secs(5)).await; - continue; - } - - Ok(t) => break t.token_recv, - } - } - }; - - let token_rec = time::timeout(Duration::from_secs(60), token_rec_fut) - .await - .wrap_err("could not get auth token after 60s")?; - - Auth::TokenReceiver(token_rec) - } - }; - - let relay_namespace = args - .relay_namespace - .clone() - .wrap_err("relay namespace MUST be provided")?; - - let target_service_id = args - .target_service_id - .clone() - .wrap_err("target service id MUST be provided")?; - let downloads_path = match orb_platform { OrbOsPlatform::Diamond => PathBuf::from("/mnt/scratch"), OrbOsPlatform::Pearl => PathBuf::from("/mnt/updates"), @@ -128,10 +57,6 @@ impl Settings { Ok(Self { orb_id, orb_platform, - auth, - relay_host, - relay_namespace, - target_service_id, store_path: store_path.as_ref().to_path_buf(), calibration_file_path: PathBuf::from("/usr/persistent/calibration.json"), os_release_path: PathBuf::from("/etc/os-release"), diff --git a/orb-jobs-agent/tests/common/fixture.rs b/orb-jobs-agent/tests/common/fixture.rs index 89d759822..96984584f 100644 --- a/orb-jobs-agent/tests/common/fixture.rs +++ b/orb-jobs-agent/tests/common/fixture.rs @@ -8,6 +8,7 @@ use dbus_launch::BusType; use orb_connd_dbus::Connd; use orb_info::OrbId; use orb_jobs_agent::{ + job_system::client::{JobTransport, RelayTransport, TransportResult}, program::{self, Deps}, settings::Settings, shell::Shell, @@ -27,6 +28,7 @@ use orb_relay_messages::{ }; use orb_relay_test_utils::{IntoRes, TestServer}; use orb_telemetry::TelemetryFlusher; +use std::sync::Arc; use std::time::Duration; use test_utils::async_bag::AsyncBag; use tokio::task::{self, JoinHandle}; @@ -42,6 +44,10 @@ pub struct JobAgentFixture { _server: TestServer<()>, client: Client, pub settings: Settings, + pub relay_host: String, + pub relay_namespace: String, + pub target_service_id: String, + pub auth: Auth, pub execution_updates: AsyncBag>, pub job_queue: JobQueue, _tempdir: TempDir, @@ -51,6 +57,32 @@ pub struct JobAgentFixture { zenoh_port: u16, } +impl JobAgentFixture { + pub fn connect_relay( + &self, + ) -> (Arc, JoinHandle>) { + let opts = ClientOpts::entity(EntityType::Orb) + .id(self.settings.orb_id.to_string()) + .namespace(self.relay_namespace.clone()) + .endpoint(self.relay_host.clone()) + .auth(self.auth.clone()) + .max_connection_attempts(Amount::Val(3)) + .connection_timeout(Duration::from_secs(1)) + .heartbeat(Duration::from_secs(u64::MAX)) + .ack_timeout(Duration::from_secs(1)) + .build(); + + let (relay_client, transport_handle) = Client::connect(opts); + let transport: Arc = Arc::new(RelayTransport::new( + relay_client, + self.target_service_id.clone(), + self.relay_namespace.clone(), + )); + + (transport, transport_handle) + } +} + #[bon] impl JobAgentFixture { pub fn init_tracing(&self) -> TelemetryFlusher { @@ -187,10 +219,6 @@ impl JobAgentFixture { let settings = Settings { orb_id: OrbId::Short(orb_id.parse().unwrap()), orb_platform: orb_info::orb_os_release::OrbOsPlatform::Diamond, - auth, - relay_host, - relay_namespace: namespace, - target_service_id: target_service_id.to_string(), store_path: tempdir.to_path_buf(), // Use non-existent paths by default for tests (can be overridden) calibration_file_path: "/nonexistent/calibration.json".into(), @@ -220,6 +248,10 @@ impl JobAgentFixture { _server: server, client, settings, + relay_host, + relay_namespace: namespace, + target_service_id, + auth, execution_updates, job_queue, _tempdir: tempdir, @@ -251,11 +283,16 @@ impl JobAgentFixture { .await .unwrap(); + let (transport, transport_handle) = self.connect_relay(); let deps = Deps::new(shell, self.dbus_conn.clone(), settings.clone()); let join_handle = task::spawn(async move { tokio::select! { - r = program::run(deps) => { + r = program::run(deps, program::Runtime { + transport, + transport_handle, + watch_conn_changes: true, + }) => { if let Err(e) = r { println!("program::run failed with {e}"); } @@ -302,7 +339,7 @@ impl JobAgentFixture { .send( SendMessage::to(EntityType::Orb) .id(self.settings.orb_id.to_string()) - .namespace(&self.settings.relay_namespace) + .namespace(&self.relay_namespace) .qos(QoS::AtLeastOnce) .payload(payload), ) @@ -323,7 +360,7 @@ impl JobAgentFixture { .send( SendMessage::to(EntityType::Orb) .id(self.settings.orb_id.to_string()) - .namespace(&self.settings.relay_namespace) + .namespace(&self.relay_namespace) .qos(QoS::AtLeastOnce) .payload(payload), ) diff --git a/orb-jobs-agent/tests/job_handler.rs b/orb-jobs-agent/tests/job_handler.rs index c99f4a26f..60a497e4f 100644 --- a/orb-jobs-agent/tests/job_handler.rs +++ b/orb-jobs-agent/tests/job_handler.rs @@ -19,6 +19,7 @@ async fn sequential_jobs_block_other_jobs_execution() { // Arrange let fx = JobAgentFixture::new().await; let deps = Deps::new(Host, fx.dbus_conn.clone(), fx.settings.clone()); + let (transport, transport_handle) = fx.connect_relay(); let wait_time = Duration::from_millis(100); @@ -29,7 +30,7 @@ async fn sequential_jobs_block_other_jobs_execution() { Ok(ctx.success().stdout("one")) }) .parallel("second", async |ctx| Ok(ctx.success().stdout("two"))) - .build(deps) + .build(deps, transport, transport_handle) .run(), ); @@ -47,6 +48,7 @@ async fn can_start_parallel_jobs_in_parallel() { // Arrange let fx = JobAgentFixture::new().await; let deps = Deps::new(Host, fx.dbus_conn.clone(), fx.settings.clone()); + let (transport, transport_handle) = fx.connect_relay(); let wait_time = Duration::from_millis(500); @@ -57,7 +59,7 @@ async fn can_start_parallel_jobs_in_parallel() { Ok(ctx.success().stdout("one")) }) .parallel("second", async |ctx| Ok(ctx.success().stdout("two"))) - .build(deps) + .build(deps, transport, transport_handle) .run(), ); @@ -81,8 +83,13 @@ async fn gracefully_handles_unsupported_cmds() { // Arrange let fx = JobAgentFixture::new().await; let deps = Deps::new(Host, fx.dbus_conn.clone(), fx.settings.clone()); + let (transport, transport_handle) = fx.connect_relay(); - task::spawn(JobHandler::builder().build(deps).run()); + task::spawn( + JobHandler::builder() + .build(deps, transport, transport_handle) + .run(), + ); // Act fx.enqueue_job("joberoni").await.wait_for_completion().await; @@ -97,6 +104,7 @@ async fn it_cancels_a_long_running_job() { // Arrange let fx = JobAgentFixture::with_namespace("cancel_long_running_job").await; let deps = Deps::new(Host, fx.dbus_conn.clone(), fx.settings.clone()); + let (transport, transport_handle) = fx.connect_relay(); let wait_time = Duration::from_millis(50); @@ -120,7 +128,7 @@ async fn it_cancels_a_long_running_job() { Ok(ctx.success().stdout("cancelled succesfully!")) }) - .build(deps) + .build(deps, transport, transport_handle) .run(), );