Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 104 additions & 9 deletions hil/src/commands/cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::OrbConfig;

const PATTERN_START: &str = "hil_pattern_start-";
const PATTERN_END: &str = "-hil_pattern_end";
const SHELL_PROMPT_PATTERNS: &[&str] = &["worldcoin@id", "root@"];

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Use a specific root prompt token

Adding "root@" to SHELL_PROMPT_PATTERNS makes wait_for_prompt succeed on any command output containing that substring, not just an actual shell prompt. In serial mode this can occur in normal output (for example log lines or text with root@...), so the post-command wait may return early and the code will issue echo $? while the original command is still running, producing incorrect status or interleaved input. Please match a prompt-specific token (e.g., the full Orb host prefix and/or prompt suffix) instead of a generic substring.

Useful? React with 👍 / 👎.


#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum CommandTransport {
Expand Down Expand Up @@ -134,13 +135,13 @@ async fn run_inner(
// Type newline to force a prompt (helps make sure we are in the state we
// think we are in)
type_str(&mut serial_writer, "\n").await?;
wait_for_str(&mut serial_stream, "worldcoin@id", timeout)
wait_for_prompt(&mut serial_stream, timeout)
.await
.wrap_err("failed while listening for prompt after newline")?;

// Run cmd
type_str(&mut serial_writer, &format!("stty -echo; {}\n\n", cmd)).await?;
wait_for_str(&mut serial_stream, "worldcoin@id", timeout)
wait_for_prompt(&mut serial_stream, timeout)
.await
.wrap_err("failed while listening for prompt after command")?;

Expand Down Expand Up @@ -215,28 +216,73 @@ async fn type_str(mut serial_writer: impl AsyncWrite + Unpin, s: &str) -> Result
.wrap_err_with(|| format!("failed to type {s}"))
}

/// Returns when `pattern` is detected in the `serial_stream`.
/// Returns when a shell prompt is detected in the `serial_stream`.
///
/// Includes timeouts.
async fn wait_for_str<E>(
serial_stream: impl TryStream<Ok = Bytes, Error = E>,
pattern: &str,
async fn wait_for_prompt<E>(
serial_stream: impl TryStream<Ok = Bytes, Error = E> + Unpin,
timeout: Duration,
) -> Result<()>
where
E: std::error::Error + Send + Sync + 'static,
{
let patterns = SHELL_PROMPT_PATTERNS.join(", ");
tokio::time::timeout(
timeout,
crate::serial::wait_for_pattern(pattern.as_bytes().to_vec(), serial_stream),
wait_for_any_pattern(serial_stream, SHELL_PROMPT_PATTERNS),
)
.await
.wrap_err_with(|| format!("timeout while waiting for {pattern}"))?
.wrap_err_with(|| format!("error while waiting for {pattern}"))
.wrap_err_with(|| format!("timeout while waiting for one of {patterns}"))?
.map(|matched| {
debug!("detected shell prompt pattern: {matched:?}");
})
.wrap_err_with(|| format!("error while waiting for one of {patterns}"))
}

async fn wait_for_any_pattern<'a, E>(
mut serial_stream: impl TryStream<Ok = Bytes, Error = E> + Unpin,
patterns: &'a [&'a str],
) -> Result<&'a str, WaitErr<E>> {
let max_pattern_len = patterns
.iter()
.map(|pattern| pattern.len())
.max()
.expect("at least one prompt pattern");
let mut buf = Vec::with_capacity(max_pattern_len * 2);

loop {
let Some(chunk) = serial_stream.try_next().await? else {
break;
};

buf.extend_from_slice(&chunk);
if let Some(pattern) = find_matching_pattern(&buf, patterns) {
return Ok(pattern);
}

let keep_len = max_pattern_len.saturating_sub(1);
if buf.len() > keep_len {
buf.drain(..buf.len() - keep_len);
}
}

Err(WaitErr::StreamEnded)
}

fn find_matching_pattern<'a>(buf: &[u8], patterns: &'a [&'a str]) -> Option<&'a str> {
patterns
.iter()
.find(|pattern| {
buf.windows(pattern.len())
.any(|window| window == pattern.as_bytes())
})
.copied()
}

#[cfg(test)]
mod test {
use std::convert::Infallible;

use super::*;

fn sample_cmd() -> Cmd {
Expand All @@ -262,4 +308,53 @@ mod test {
};
assert!(cmd.transport.remote_transport().is_none());
}

#[tokio::test]
async fn wait_for_any_pattern_detects_worldcoin_prompt() {
let stream = futures::stream::iter(
[Bytes::from_static(b"abc worldcoin@id-123:~$")]
.into_iter()
.map(Ok::<_, Infallible>),
);

assert_eq!(
wait_for_any_pattern(stream, SHELL_PROMPT_PATTERNS)
.await
.expect("prompt should be detected"),
"worldcoin@id"
);
}

#[tokio::test]
async fn wait_for_any_pattern_detects_root_prompt_split_across_chunks() {
let stream = futures::stream::iter(
[
Bytes::from_static(b"boot noise ro"),
Bytes::from_static(b"ot@id-123:~#"),
]
.into_iter()
.map(Ok::<_, Infallible>),
);

assert_eq!(
wait_for_any_pattern(stream, SHELL_PROMPT_PATTERNS)
.await
.expect("prompt should be detected"),
"root@"
);
}

#[tokio::test]
async fn wait_for_any_pattern_returns_stream_ended_without_prompt() {
let stream = futures::stream::iter(
[Bytes::from_static(b"boot noise")]
.into_iter()
.map(Ok::<_, Infallible>),
);

assert!(matches!(
wait_for_any_pattern(stream, SHELL_PROMPT_PATTERNS).await,
Err(WaitErr::StreamEnded)
));
}
}
Loading