Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/fd/delegate.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
use alloc::sync::Arc;
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
use core::ffi::c_int;
use core::mem::MaybeUninit;

use delegate::delegate;
Expand Down Expand Up @@ -141,7 +143,7 @@ impl ObjectInterface for Fd {
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
async fn setsockopt(&self, _opt: SocketOption, _optval: bool) -> io::Result<()>;
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
async fn getsockopt(&self, _opt: SocketOption) -> io::Result<bool>;
async fn getsockopt(&self, _opt: SocketOption) -> io::Result<c_int>;
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
async fn getsockname(&self) -> io::Result<Option<Endpoint>>;
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
Expand Down
15 changes: 11 additions & 4 deletions src/fd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use alloc::sync::Arc;
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
use core::ffi::c_int;
use core::future;
use core::mem::MaybeUninit;
use core::pin::pin;
use core::task::Poll::{Pending, Ready};
use core::time::Duration;

#[cfg(any(feature = "net", feature = "virtio-vsock"))]
use num_enum::TryFromPrimitive;
#[cfg(feature = "net")]
use smoltcp::wire::{IpEndpoint, IpListenEndpoint};

Expand Down Expand Up @@ -43,10 +47,13 @@ pub(crate) enum ListenEndpoint {
Vsock(socket::vsock::VsockListenEndpoint),
}

#[allow(dead_code)]
#[derive(Debug, PartialEq)]
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
#[derive(TryFromPrimitive, PartialEq, Eq, Clone, Copy, Debug)]
#[repr(i32)]
pub(crate) enum SocketOption {
TcpNoDelay,
TcpNodelay = 1,
SoSndbuf = 0x1001,
SoRcvbuf = 0x1002,
}

pub(crate) type RawFd = i32;
Expand Down Expand Up @@ -263,7 +270,7 @@ pub(crate) trait ObjectInterface: Sync + Send {

/// `getsockopt` gets options on sockets
#[cfg(any(feature = "net", feature = "virtio-vsock"))]
async fn getsockopt(&self, _opt: SocketOption) -> io::Result<bool> {
async fn getsockopt(&self, _opt: SocketOption) -> io::Result<c_int> {
Err(Errno::Notsock)
}

Expand Down
24 changes: 14 additions & 10 deletions src/fd/socket/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::collections::BTreeSet;
use alloc::sync::Arc;
use core::ffi::c_int;
use core::future;
use core::sync::atomic::{AtomicU16, Ordering};
use core::task::Poll;
Expand All @@ -24,6 +25,9 @@ pub const SHUT_WR: i32 = 1;
pub const SHUT_RDWR: i32 = 2;
/// The default queue size for incoming connections
pub const DEFAULT_BACKLOG: i32 = 128;
/// The maximum queue size for incoming connections,
/// based on the default maximum used by modern Linux.
pub const SOMAXCONN: i32 = 4096;

fn get_ephemeral_port() -> u16 {
static LOCAL_ENDPOINT: AtomicU16 = AtomicU16::new(49152);
Expand Down Expand Up @@ -404,7 +408,7 @@ impl ObjectInterface for Socket {

self.is_listen = true;

for _ in 1..backlog {
for _ in 1..backlog.min(SOMAXCONN) {
let handle = nic.create_tcp_handle().unwrap();

let s = nic.get_mut_socket::<tcp::Socket<'_>>(handle);
Expand All @@ -418,7 +422,7 @@ impl ObjectInterface for Socket {
}

async fn setsockopt(&self, opt: SocketOption, optval: bool) -> io::Result<()> {
if opt == SocketOption::TcpNoDelay {
if opt == SocketOption::TcpNodelay {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();

Expand All @@ -433,15 +437,15 @@ impl ObjectInterface for Socket {
}
}

async fn getsockopt(&self, opt: SocketOption) -> io::Result<bool> {
if opt == SocketOption::TcpNoDelay {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());
async fn getsockopt(&self, opt: SocketOption) -> io::Result<c_int> {
let mut guard = NIC.lock();
let nic = guard.as_nic_mut().unwrap();
let socket = nic.get_mut_socket::<tcp::Socket<'_>>(*self.handle.first().unwrap());

Ok(socket.nagle_enabled())
} else {
Err(Errno::Inval)
match opt {
SocketOption::TcpNodelay => Ok(socket.nagle_enabled().into()),
SocketOption::SoSndbuf => Ok(c_int::try_from(socket.send_capacity()).unwrap()),
SocketOption::SoRcvbuf => Ok(c_int::try_from(socket.recv_capacity()).unwrap()),
}
}

Expand Down
17 changes: 16 additions & 1 deletion src/fd/socket/udp.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::ffi::c_int;
use core::future;
use core::mem::MaybeUninit;
use core::task::Poll;
Expand All @@ -9,7 +10,7 @@ use smoltcp::wire::{IpEndpoint, Ipv4Address, Ipv6Address};
use crate::errno::Errno;
use crate::executor::block_on;
use crate::executor::network::{Handle, NIC, wake_network_waker};
use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent};
use crate::fd::{self, Endpoint, ListenEndpoint, ObjectInterface, PollEvent, SocketOption};
use crate::io;
use crate::syscalls::socket::Af;

Expand Down Expand Up @@ -242,6 +243,20 @@ impl ObjectInterface for Socket {
async fn getsockname(&self) -> io::Result<Option<Endpoint>> {
Ok(Some(Endpoint::Ip(self.local_endpoint)))
}

async fn getsockopt(&self, opt: SocketOption) -> io::Result<c_int> {
let mut guard = NIC.lock();
let socket = guard
.as_nic_mut()
.unwrap()
.get_mut_socket::<udp::Socket<'_>>(self.handle);

match opt {
SocketOption::TcpNodelay => Err(Errno::Inval),
SocketOption::SoSndbuf => Ok(c_int::try_from(socket.payload_send_capacity()).unwrap()),
SocketOption::SoRcvbuf => Ok(c_int::try_from(socket.payload_recv_capacity()).unwrap()),
}
}
}

impl Drop for Socket {
Expand Down
42 changes: 16 additions & 26 deletions src/syscalls/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,9 @@ pub const SO_REUSEADDR: i32 = 0x0004;
pub const SO_KEEPALIVE: i32 = 0x0008;
pub const SO_BROADCAST: i32 = 0x0020;
pub const SO_LINGER: i32 = 0x0080;
pub const SO_SNDBUF: i32 = 0x1001;
pub const SO_RCVBUF: i32 = 0x1002;
pub const SO_SNDTIMEO: i32 = 0x1005;
pub const SO_RCVTIMEO: i32 = 0x1006;
pub const SO_ERROR: i32 = 0x1007;
pub const TCP_NODELAY: i32 = 1;
pub const MSG_PEEK: i32 = 1;
pub type sa_family_t = u8;
pub type socklen_t = u32;
Expand Down Expand Up @@ -625,8 +622,7 @@ pub extern "C" fn sys_socket(domain: i32, type_: i32, protocol: i32) -> i32 {
}

#[cfg(feature = "net")]
if (domain == Af::Inet || domain == Af::Inet6) && (sock == Sock::Stream || sock == Sock::Dgram)
{
if domain == Af::Inet && matches!(sock, Sock::Stream | Sock::Dgram) {
let mut guard = NIC.lock();

let NetworkState::Initialized(nic) = &mut *guard else {
Expand Down Expand Up @@ -932,10 +928,14 @@ pub unsafe extern "C" fn sys_setsockopt(
return -i32::from(Errno::Inval);
};

debug!("sys_setsockopt: {fd}, level {level:?}, optname {optname}");
let Ok(optname) = SocketOption::try_from(optname) else {
return -i32::from(Errno::Inval);
};

debug!("sys_setsockopt: {fd}, level {level:?}, optname {optname:?}");

if level == Ipproto::Tcp
&& optname == TCP_NODELAY
&& optname == SocketOption::TcpNodelay
&& optlen == u32::try_from(size_of::<i32>()).unwrap()
{
if optval.is_null() {
Expand All @@ -948,12 +948,7 @@ pub unsafe extern "C" fn sys_setsockopt(
|e| -i32::from(e),
|v| {
block_on(
async {
v.read()
.await
.setsockopt(SocketOption::TcpNoDelay, value != 0)
.await
},
async { v.read().await.setsockopt(optname, value != 0).await },
None,
)
.map_or_else(|e| -i32::from(e), |()| 0)
Expand All @@ -973,13 +968,16 @@ pub unsafe extern "C" fn sys_getsockopt(
optval: *mut c_void,
optlen: *mut socklen_t,
) -> i32 {
let Ok(Ok(level)) = u8::try_from(level).map(Ipproto::try_from) else {
let Ok(optname) = SocketOption::try_from(optname) else {
return -i32::from(Errno::Inval);
};

debug!("sys_getsockopt: {fd}, level {level:?}, optname {optname}");
debug!("sys_getsockopt: {fd}, level {level}, optname {optname:?}");

if level == Ipproto::Tcp && optname == TCP_NODELAY {
if level == Ipproto::Tcp as i32 && optname == SocketOption::TcpNodelay
|| level == SOL_SOCKET
&& (optname == SocketOption::SoSndbuf || optname == SocketOption::SoRcvbuf)
{
if optval.is_null() || optlen.is_null() {
return -i32::from(Errno::Inval);
}
Expand All @@ -990,18 +988,10 @@ pub unsafe extern "C" fn sys_getsockopt(
obj.map_or_else(
|e| -i32::from(e),
|v| {
block_on(
async { v.read().await.getsockopt(SocketOption::TcpNoDelay).await },
None,
)
.map_or_else(
block_on(async { v.read().await.getsockopt(optname).await }, None).map_or_else(
|e| -i32::from(e),
|value| {
if value {
*optval = 1;
} else {
*optval = 0;
}
*optval = value;
*optlen = size_of::<i32>().try_into().unwrap();

0
Expand Down
Loading