diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index f522f5d5..3b0ddec7 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -383,9 +383,10 @@ impl Stream { } /// Set the stream's state to `Closed` with the given reason and initiator. - /// Notify the send and receive tasks, if they exist. + /// Notify the send, receive, and push tasks, if they exist. pub(super) fn set_reset(&mut self, reason: Reason, initiator: Initiator) { self.state.set_reset(self.id, reason, initiator); + self.notify_send(); self.notify_push(); self.notify_recv(); } diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index 141bdcc9..2086938c 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -1,6 +1,7 @@ use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; use h2_support::util::yield_once; +use tokio::sync::oneshot; // In this case, the stream & connection both have capacity, but capacity is not // explicitly requested. @@ -2563,3 +2564,109 @@ async fn goaway_ignores_data_but_returns_connection_capacity() { join(client, srv).await; } } + +/// When the library sends RST_STREAM (e.g., due to a WINDOW_UPDATE +/// overflow), `poll_capacity` and `poll_reset` must be notified. +/// Regression test for https://github.com/hyperium/h2/pull/897 +#[tokio::test] +async fn poll_capacity_woken_on_library_reset() { + h2_support::trace_init!(); + + for polling_capacity in [true, false] { + let (io, mut srv) = mock::new(); + let (client_done_tx, client_done_rx) = oneshot::channel::<()>(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + + // 2. Receive the 65535-byte initial window (4 DATA frames at default MAX_FRAME_SIZE). + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_383])).await; + + // 3. Grow stream window to 2^31-1, to set up for overflow later. + srv.send_frame(frames::window_update(0, 65535)).await; + srv.send_frame(frames::window_update(1, 2_147_483_647)) + .await; + + // 5. Receive the next 65535 bytes (connection-limited). + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_383])).await; + + // 6. Overflow: stream window 2147418112 + 65536 = 2^31 > 2^31-1. + srv.send_frame(frames::window_update(1, 65536)).await; + + // 8. Receive the RST_STREAM(FLOW_CONTROL_ERROR) sent by the library. + srv.recv_frame(frames::reset(1).flow_control()).await; + + // Wait for the client to finish. Otherwise Recv::recv_eof hides + // the missing waker. + let _ = client_done_rx.await; + }; + + let client = async move { + let (mut client, conn) = client::handshake(io).await.unwrap(); + tokio::spawn(async move { + // Separate task so the polled method won't resolve unless notify_send wakes it. + let _ = conn.await; + }); + + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (_resp, mut stream) = client.send_request(request, false).unwrap(); + + // 1. Exhaust the initial 65535-byte window. + stream.reserve_capacity(65535); + let cap = poll_fn(|cx| stream.poll_capacity(cx)) + .await + .unwrap() + .unwrap(); + assert_eq!(cap, 65535); + stream.send_data(vec![0u8; cap].into(), false).unwrap(); + + // 4. poll_capacity blocks until 3. replenishes windows, then send again. + stream.reserve_capacity(65535); + let cap = poll_fn(|cx| stream.poll_capacity(cx)) + .await + .unwrap() + .unwrap(); + assert_eq!(cap, 65535); + stream.send_data(vec![0u8; cap].into(), false).unwrap(); + + // 7. The polled method must be woken by the reset from 6. + if polling_capacity { + stream.reserve_capacity(65535); + let result = tokio::time::timeout( + Duration::from_secs(1), + poll_fn(|cx| stream.poll_capacity(cx)).wakened(), + ) + .await + .expect("poll_capacity was not woken"); + assert!(result.is_none()); + } else { + let reason = tokio::time::timeout( + Duration::from_secs(1), + poll_fn(|cx| stream.poll_reset(cx)).wakened(), + ) + .await + .expect("poll_reset was not woken") + .unwrap(); + assert_eq!(reason, Reason::FLOW_CONTROL_ERROR); + } + + let _ = client_done_tx.send(()); + }; + + join(srv, client).await; + } +}