Skip to content

Commit 8f161e9

Browse files
committed
perf: improve wind-tuic performance
1 parent 60cc5ba commit 8f161e9

4 files changed

Lines changed: 196 additions & 78 deletions

File tree

crates/wind-tuic/src/proto/mod.rs

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use std::future::Future;
1515

1616
use bytes::{Buf, BytesMut};
1717
use eyre::eyre;
18-
use tokio_util::codec::{Decoder, Encoder};
18+
use tokio_util::codec::Encoder;
1919
pub use udp_stream::*;
2020
use wind_core::{io::quinn::QuinnCompat, tcp::AbstractTcpStream, types::TargetAddr};
2121

@@ -26,24 +26,114 @@ pub type Error = eyre::Report;
2626
pub const VER: u8 = 5;
2727

2828
/// Helper function to decode header with better error reporting
29-
pub fn decode_header(buf: &mut BytesMut, context: &str) -> Result<Header, Error> {
30-
HeaderCodec
31-
.decode(buf)?
32-
.ok_or_else(|| eyre!("Incomplete header in {}", context))
29+
pub fn decode_header(buf: &mut impl Buf, context: &str) -> Result<Header, Error> {
30+
if buf.remaining() < 2 {
31+
return Err(eyre!("Incomplete header in {}", context));
32+
}
33+
let ver = buf.get_u8();
34+
if ver != VER {
35+
return Err(eyre!("Version mismatch: expected {}, got {}", VER, ver));
36+
}
37+
let cmd = CmdType::from(buf.get_u8());
38+
if matches!(cmd, CmdType::Other(_)) {
39+
return Err(eyre!("Unknown command type: {}", u8::from(cmd)));
40+
}
41+
Ok(Header::new(cmd))
3342
}
3443

3544
/// Helper function to decode command with better error reporting
36-
pub fn decode_command(cmd_type: CmdType, buf: &mut BytesMut, context: &str) -> Result<Command, Error> {
37-
CmdCodec(cmd_type)
38-
.decode(buf)?
39-
.ok_or_else(|| eyre!("Incomplete command in {}", context))
45+
pub fn decode_command(cmd_type: CmdType, buf: &mut impl Buf, context: &str) -> Result<Command, Error> {
46+
match cmd_type {
47+
CmdType::Auth => {
48+
if buf.remaining() < 16 + 32 {
49+
return Err(eyre!("Incomplete auth command in {}", context));
50+
}
51+
let mut uuid = [0; 16];
52+
buf.copy_to_slice(&mut uuid);
53+
let mut token = [0; 32];
54+
buf.copy_to_slice(&mut token);
55+
Ok(Command::Auth {
56+
uuid: uuid::Uuid::from_bytes(uuid),
57+
token,
58+
})
59+
}
60+
CmdType::Connect => Ok(Command::Connect),
61+
CmdType::Packet => {
62+
if buf.remaining() < 8 {
63+
return Err(eyre!("Incomplete packet command in {}", context));
64+
}
65+
Ok(Command::Packet {
66+
assoc_id: buf.get_u16(),
67+
pkt_id: buf.get_u16(),
68+
frag_total: buf.get_u8(),
69+
frag_id: buf.get_u8(),
70+
size: buf.get_u16(),
71+
})
72+
}
73+
CmdType::Dissociate => {
74+
if buf.remaining() < 2 {
75+
return Err(eyre!("Incomplete dissociate command in {}", context));
76+
}
77+
Ok(Command::Dissociate {
78+
assoc_id: buf.get_u16(),
79+
})
80+
}
81+
CmdType::Heartbeat => Ok(Command::Heartbeat),
82+
CmdType::Other(v) => Err(eyre!("Unknown command type: {}", v)),
83+
}
4084
}
4185

4286
/// Helper function to decode address with better error reporting
43-
pub fn decode_address(buf: &mut BytesMut, context: &str) -> Result<Address, Error> {
44-
AddressCodec
45-
.decode(buf)?
46-
.ok_or_else(|| eyre!("Incomplete address in {}", context))
87+
pub fn decode_address(buf: &mut impl Buf, context: &str) -> Result<Address, Error> {
88+
if !buf.has_remaining() {
89+
return Err(eyre!("Incomplete address in {}", context));
90+
}
91+
let addr_type = AddressType::from(buf.chunk()[0]);
92+
93+
match addr_type {
94+
AddressType::None => {
95+
buf.advance(1);
96+
Ok(Address::None)
97+
}
98+
AddressType::IPv4 => {
99+
if buf.remaining() < 1 + 4 + 2 {
100+
return Err(eyre!("Incomplete IPv4 address in {}", context));
101+
}
102+
buf.advance(1);
103+
let mut octets = [0; 4];
104+
buf.copy_to_slice(&mut octets);
105+
let ip = std::net::Ipv4Addr::from(octets);
106+
let port = buf.get_u16();
107+
Ok(Address::IPv4(ip, port))
108+
}
109+
AddressType::IPv6 => {
110+
if buf.remaining() < 1 + 16 + 2 {
111+
return Err(eyre!("Incomplete IPv6 address in {}", context));
112+
}
113+
buf.advance(1);
114+
let mut octets = [0; 16];
115+
buf.copy_to_slice(&mut octets);
116+
let ip = std::net::Ipv6Addr::from(octets);
117+
let port = buf.get_u16();
118+
Ok(Address::IPv6(ip, port))
119+
}
120+
AddressType::Domain => {
121+
if buf.remaining() < 2 {
122+
return Err(eyre!("Incomplete Domain address in {}", context));
123+
}
124+
let len = buf.chunk()[1] as usize;
125+
if buf.remaining() < 1 + 1 + len + 2 {
126+
return Err(eyre!("Incomplete Domain address in {}", context));
127+
}
128+
buf.advance(2);
129+
let mut domain = vec![0; len];
130+
buf.copy_to_slice(&mut domain);
131+
let s = String::from_utf8(domain).map_err(|_| eyre!("Invalid UTF-8 domain in {}", context))?;
132+
let port = buf.get_u16();
133+
Ok(Address::Domain(s, port))
134+
}
135+
AddressType::Other(v) => Err(eyre!("Unknown address type: {}", v)),
136+
}
47137
}
48138

49139
/// Helper function to convert Address to TargetAddr

crates/wind-tuic/src/proto/udp_stream.rs

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -83,40 +83,44 @@ impl FragmentReassemblyBuffer {
8383
let target_clone = target.clone();
8484

8585
// Get or create the fragment metadata
86-
let meta = self
87-
.fragments
88-
.entry(key)
89-
.or_insert_with(async {
90-
Arc::new(FragmentMetadata {
91-
frag_total,
92-
fragments: Cache::new(frag_total.into()),
93-
last_updated: AtomicU64::new(init_time().elapsed().as_secs()),
94-
source: ArcSwapOption::new(source.clone().map(Arc::new)),
95-
target: ArcSwap::new(Arc::new(target)),
86+
let is_complete = {
87+
let meta = self
88+
.fragments
89+
.entry(key)
90+
.or_insert_with(async {
91+
Arc::new(FragmentMetadata {
92+
frag_total,
93+
fragments: Cache::new(frag_total.into()),
94+
last_updated: AtomicU64::new(init_time().elapsed().as_secs()),
95+
source: ArcSwapOption::new(source.clone().map(Arc::new)),
96+
target: ArcSwap::new(Arc::new(target)),
97+
})
9698
})
97-
})
98-
.await;
99+
.await;
99100

100-
// If this is the first fragment (frag_id == 0) and it has a real address,
101-
// update the target address in case we received other fragments first with
102-
// placeholder addresses
103-
if frag_id == 0 && !is_placeholder_addr {
104-
meta.value().target.store(Arc::new(target_clone));
105-
}
101+
// If this is the first fragment (frag_id == 0) and it has a real address,
102+
// update the target address in case we received other fragments first with
103+
// placeholder addresses
104+
if frag_id == 0 && !is_placeholder_addr {
105+
meta.value().target.store(Arc::new(target_clone));
106+
}
106107

107-
// Update timestamp
108-
meta.value()
109-
.last_updated
110-
.store(init_time().elapsed().as_secs(), Ordering::Relaxed);
108+
// Update timestamp
109+
meta.value()
110+
.last_updated
111+
.store(init_time().elapsed().as_secs(), Ordering::Relaxed);
111112

112-
// Store this fragment
113-
meta.value().fragments.insert(frag_id, payload).await;
113+
// Store this fragment
114+
meta.value().fragments.insert(frag_id, payload).await;
114115

115-
// Ensure all pending cache operations are completed
116-
meta.value().fragments.run_pending_tasks().await;
116+
// Ensure all pending cache operations are completed
117+
meta.value().fragments.run_pending_tasks().await;
117118

118-
// Check if all fragments have been received
119-
if meta.value().fragments.entry_count() == meta.value().frag_total as u64 {
119+
// Check if all fragments have been received
120+
meta.value().fragments.entry_count() == meta.value().frag_total as u64
121+
};
122+
123+
if is_complete {
120124
// All fragments received, reassemble the packet
121125
return self.reassemble_packet(key).await;
122126
}
@@ -158,14 +162,32 @@ impl FragmentReassemblyBuffer {
158162
}
159163

160164
// Return the reassembled packet
161-
let source = meta.source.load().as_ref().map(|arc| (**arc).clone());
162-
let target = (**meta.target.load()).clone();
163-
164-
Some(UdpPacket {
165-
source,
166-
target,
167-
payload: buffer.freeze(),
168-
})
165+
let payload = buffer.freeze();
166+
match Arc::try_unwrap(meta) {
167+
Ok(m) => {
168+
let source = m
169+
.source
170+
.into_inner()
171+
.map(|arc| Arc::try_unwrap(arc).unwrap_or_else(|a| (*a).clone()));
172+
let target = Arc::try_unwrap(m.target.into_inner()).unwrap_or_else(|a| (*a).clone());
173+
174+
Some(UdpPacket {
175+
source,
176+
target,
177+
payload,
178+
})
179+
}
180+
Err(arc) => {
181+
let source = arc.source.load().as_ref().map(|a| (**a).clone());
182+
let target = (**arc.target.load()).clone();
183+
184+
Some(UdpPacket {
185+
source,
186+
target,
187+
payload,
188+
})
189+
}
190+
}
169191
} else {
170192
None
171193
}

crates/wind-tuic/src/quinn/inbound.rs

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::{
1111
time::Duration,
1212
};
1313

14-
use bytes::BytesMut;
14+
use bytes::Bytes;
1515
use eyre::{Context, ContextCompat};
1616
use quinn::{Endpoint, EndpointConfig, IdleTimeout, ServerConfig, TokioRuntime, TransportConfig, VarInt};
1717
use rustls::{
@@ -383,7 +383,7 @@ async fn handle_uni_stream<C: InboundCallback>(
383383
.read_to_end(65536)
384384
.await
385385
.map_err(|e| eyre::eyre!("Failed to read stream: {}", e))?;
386-
let mut buf = BytesMut::from(&data[..]);
386+
let mut buf = bytes::Bytes::from(data);
387387

388388
// Decode header and command using helper functions
389389
let header = crate::proto::decode_header(&mut buf, "uni stream")?;
@@ -402,7 +402,7 @@ async fn handle_uni_stream<C: InboundCallback>(
402402
} => {
403403
// Decode address (may be Address::None for non-first fragments)
404404
let addr = crate::proto::decode_address(&mut buf, "uni stream packet")?;
405-
let payload = buf.split_to(size as usize).freeze();
405+
let payload = buf.split_to(size as usize);
406406

407407
// Convert address to TargetAddr, using placeholder for non-first fragments
408408
let target_addr = match crate::proto::address_to_target(addr) {
@@ -448,18 +448,18 @@ async fn handle_bi_stream<C: InboundCallback>(
448448
}
449449

450450
// Read header and command
451-
let mut header_buf = vec![0u8; 2];
451+
let mut header_buf = [0u8; 2];
452452
recv.read_exact(&mut header_buf)
453453
.await
454454
.map_err(|e| eyre::eyre!("Failed to read header: {}", e))?;
455-
let mut buf = BytesMut::from(&header_buf[..]);
455+
let mut buf = &header_buf[..];
456456

457457
let header = crate::proto::decode_header(&mut buf, "bi stream")?;
458458

459459
match header.command {
460460
CmdType::Connect => {
461461
// Decode command (Connect has no additional fields)
462-
let _cmd = crate::proto::decode_command(CmdType::Connect, &mut BytesMut::new(), "bi stream")?;
462+
let _cmd = crate::proto::decode_command(CmdType::Connect, &mut [].as_ref(), "bi stream")?;
463463

464464
// Read exactly the address bytes, leaving the relay payload in `recv`
465465
// so that the same stream can be used for bidirectional data relay.
@@ -501,7 +501,7 @@ async fn handle_datagram<C: InboundCallback>(connection: Arc<InboundCtx>, data:
501501
}
502502
}
503503

504-
let mut buf = BytesMut::from(data.as_ref());
504+
let mut buf = data;
505505

506506
// Decode header using helper function
507507
let header = crate::proto::decode_header(&mut buf, "datagram")?;
@@ -519,7 +519,7 @@ async fn handle_datagram<C: InboundCallback>(connection: Arc<InboundCtx>, data:
519519
} = cmd
520520
{
521521
let addr = crate::proto::decode_address(&mut buf, "datagram packet")?;
522-
let payload = buf.split_to(size as usize).freeze();
522+
let payload = buf.split_to(size as usize);
523523

524524
// Convert address to TargetAddr, using placeholder for non-first fragments
525525
let target_addr = match crate::proto::address_to_target(addr) {
@@ -721,49 +721,57 @@ async fn read_address_exact(recv: &mut quinn::RecvStream) -> eyre::Result<crate:
721721
.await
722722
.map_err(|e| eyre::eyre!("Failed to read address type: {}", e))?;
723723

724-
let mut buf = BytesMut::with_capacity(20);
725-
buf.extend_from_slice(&type_byte);
726-
727724
match type_byte[0] {
728-
0xFF => {
729-
// AddressType::None — just the single type byte
730-
}
725+
0xFF => Ok(crate::proto::Address::None),
731726
0x01 => {
732727
// AddressType::IPv4 — 4-byte address + 2-byte port
733728
let mut rest = [0u8; 6];
734729
recv.read_exact(&mut rest)
735730
.await
736731
.map_err(|e| eyre::eyre!("Failed to read IPv4 address: {}", e))?;
737-
buf.extend_from_slice(&rest);
732+
let mut ip_bytes = [0u8; 4];
733+
ip_bytes.copy_from_slice(&rest[0..4]);
734+
let port = u16::from_be_bytes([rest[4], rest[5]]);
735+
Ok(crate::proto::Address::IPv4(std::net::Ipv4Addr::from(ip_bytes), port))
738736
}
739737
0x02 => {
740738
// AddressType::IPv6 — 16-byte address + 2-byte port
741739
let mut rest = [0u8; 18];
742740
recv.read_exact(&mut rest)
743741
.await
744742
.map_err(|e| eyre::eyre!("Failed to read IPv6 address: {}", e))?;
745-
buf.extend_from_slice(&rest);
743+
let mut ip_bytes = [0u8; 16];
744+
ip_bytes.copy_from_slice(&rest[0..16]);
745+
let port = u16::from_be_bytes([rest[16], rest[17]]);
746+
Ok(crate::proto::Address::IPv6(std::net::Ipv6Addr::from(ip_bytes), port))
746747
}
747748
0x00 => {
748749
// AddressType::Domain — 1-byte length + <length> bytes + 2-byte port
749750
let mut len_byte = [0u8; 1];
750751
recv.read_exact(&mut len_byte)
751752
.await
752753
.map_err(|e| eyre::eyre!("Failed to read domain length: {}", e))?;
753-
buf.extend_from_slice(&len_byte);
754754
let domain_len = len_byte[0] as usize;
755-
let mut rest = vec![0u8; domain_len + 2];
756-
recv.read_exact(&mut rest)
755+
756+
// Max domain len is 255. Use a stack buffer.
757+
let mut domain_buf = [0u8; 255];
758+
let domain_slice = &mut domain_buf[..domain_len];
759+
recv.read_exact(domain_slice)
757760
.await
758761
.map_err(|e| eyre::eyre!("Failed to read domain address: {}", e))?;
759-
buf.extend_from_slice(&rest);
760-
}
761-
t => {
762-
return Err(eyre::eyre!("Unknown address type byte 0x{:02x}", t));
762+
763+
let mut port_buf = [0u8; 2];
764+
recv.read_exact(&mut port_buf)
765+
.await
766+
.map_err(|e| eyre::eyre!("Failed to read domain port: {}", e))?;
767+
let port = u16::from_be_bytes(port_buf);
768+
769+
let domain_str = String::from_utf8(domain_slice.to_vec())
770+
.map_err(|_| eyre::eyre!("Invalid UTF-8 domain address"))?;
771+
Ok(crate::proto::Address::Domain(domain_str, port))
763772
}
773+
t => Err(eyre::eyre!("Unknown address type byte 0x{:02x}", t)),
764774
}
765-
766-
crate::proto::decode_address(&mut buf, "bi stream connect")
767775
}
768776

769777
/// Handle UDP dissociate

0 commit comments

Comments
 (0)