diff --git a/hil/src/commands/cmd.rs b/hil/src/commands/cmd.rs index 6204e7d0c..aef482de3 100644 --- a/hil/src/commands/cmd.rs +++ b/hil/src/commands/cmd.rs @@ -23,6 +23,7 @@ use crate::OrbConfig; const PATTERN_START: &str = "hil_pattern_start-"; const PATTERN_END: &str = "-hil_pattern_end"; +const SHELL_PROMPT_PATTERNS: &[&str] = &["worldcoin@id", "root@"]; #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum CommandTransport { @@ -134,13 +135,13 @@ async fn run_inner( // Type newline to force a prompt (helps make sure we are in the state we // think we are in) type_str(&mut serial_writer, "\n").await?; - wait_for_str(&mut serial_stream, "worldcoin@id", timeout) + wait_for_prompt(&mut serial_stream, timeout) .await .wrap_err("failed while listening for prompt after newline")?; // Run cmd type_str(&mut serial_writer, &format!("stty -echo; {}\n\n", cmd)).await?; - wait_for_str(&mut serial_stream, "worldcoin@id", timeout) + wait_for_prompt(&mut serial_stream, timeout) .await .wrap_err("failed while listening for prompt after command")?; @@ -215,28 +216,73 @@ async fn type_str(mut serial_writer: impl AsyncWrite + Unpin, s: &str) -> Result .wrap_err_with(|| format!("failed to type {s}")) } -/// Returns when `pattern` is detected in the `serial_stream`. +/// Returns when a shell prompt is detected in the `serial_stream`. /// /// Includes timeouts. -async fn wait_for_str( - serial_stream: impl TryStream, - pattern: &str, +async fn wait_for_prompt( + serial_stream: impl TryStream + Unpin, timeout: Duration, ) -> Result<()> where E: std::error::Error + Send + Sync + 'static, { + let patterns = SHELL_PROMPT_PATTERNS.join(", "); tokio::time::timeout( timeout, - crate::serial::wait_for_pattern(pattern.as_bytes().to_vec(), serial_stream), + wait_for_any_pattern(serial_stream, SHELL_PROMPT_PATTERNS), ) .await - .wrap_err_with(|| format!("timeout while waiting for {pattern}"))? - .wrap_err_with(|| format!("error while waiting for {pattern}")) + .wrap_err_with(|| format!("timeout while waiting for one of {patterns}"))? + .map(|matched| { + debug!("detected shell prompt pattern: {matched:?}"); + }) + .wrap_err_with(|| format!("error while waiting for one of {patterns}")) +} + +async fn wait_for_any_pattern<'a, E>( + mut serial_stream: impl TryStream + Unpin, + patterns: &'a [&'a str], +) -> Result<&'a str, WaitErr> { + let max_pattern_len = patterns + .iter() + .map(|pattern| pattern.len()) + .max() + .expect("at least one prompt pattern"); + let mut buf = Vec::with_capacity(max_pattern_len * 2); + + loop { + let Some(chunk) = serial_stream.try_next().await? else { + break; + }; + + buf.extend_from_slice(&chunk); + if let Some(pattern) = find_matching_pattern(&buf, patterns) { + return Ok(pattern); + } + + let keep_len = max_pattern_len.saturating_sub(1); + if buf.len() > keep_len { + buf.drain(..buf.len() - keep_len); + } + } + + Err(WaitErr::StreamEnded) +} + +fn find_matching_pattern<'a>(buf: &[u8], patterns: &'a [&'a str]) -> Option<&'a str> { + patterns + .iter() + .find(|pattern| { + buf.windows(pattern.len()) + .any(|window| window == pattern.as_bytes()) + }) + .copied() } #[cfg(test)] mod test { + use std::convert::Infallible; + use super::*; fn sample_cmd() -> Cmd { @@ -262,4 +308,53 @@ mod test { }; assert!(cmd.transport.remote_transport().is_none()); } + + #[tokio::test] + async fn wait_for_any_pattern_detects_worldcoin_prompt() { + let stream = futures::stream::iter( + [Bytes::from_static(b"abc worldcoin@id-123:~$")] + .into_iter() + .map(Ok::<_, Infallible>), + ); + + assert_eq!( + wait_for_any_pattern(stream, SHELL_PROMPT_PATTERNS) + .await + .expect("prompt should be detected"), + "worldcoin@id" + ); + } + + #[tokio::test] + async fn wait_for_any_pattern_detects_root_prompt_split_across_chunks() { + let stream = futures::stream::iter( + [ + Bytes::from_static(b"boot noise ro"), + Bytes::from_static(b"ot@id-123:~#"), + ] + .into_iter() + .map(Ok::<_, Infallible>), + ); + + assert_eq!( + wait_for_any_pattern(stream, SHELL_PROMPT_PATTERNS) + .await + .expect("prompt should be detected"), + "root@" + ); + } + + #[tokio::test] + async fn wait_for_any_pattern_returns_stream_ended_without_prompt() { + let stream = futures::stream::iter( + [Bytes::from_static(b"boot noise")] + .into_iter() + .map(Ok::<_, Infallible>), + ); + + assert!(matches!( + wait_for_any_pattern(stream, SHELL_PROMPT_PATTERNS).await, + Err(WaitErr::StreamEnded) + )); + } }