diff --git a/util/src/fd.rs b/util/src/fd.rs index f3b9e06..0c59db4 100644 --- a/util/src/fd.rs +++ b/util/src/fd.rs @@ -1,3 +1,4 @@ +/// Utilities for working with file descriptors use anyhow::bail; use rustix::{ fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}, @@ -8,10 +9,37 @@ use crate::{mem::Forgetting, result::OkExt}; /// Prepare a file descriptor for use in Rust code. /// - /// Checks if the file descriptor is valid and duplicates it to a new file descriptor. /// The old file descriptor is masked to avoid potential use after free (on file descriptor) /// in case the given file descriptor is still used somewhere +/// +/// # Panic and safety +/// +/// Will panic if the given file descriptor is negative of or larger than +/// the file descriptor numbers permitted by the operating system. +/// +/// # Examples +/// +/// ``` +/// use std::io::Write; +/// use std::os::fd::{IntoRawFd, AsRawFd}; +/// use tempfile::tempdir; +/// use rosenpass_util::fd::{claim_fd, FdIo}; +/// +/// // Open a file and turn it into a raw file descriptor +/// let orig = tempfile::tempfile()?.into_raw_fd(); +/// +/// // Reclaim that file and ready it for reading +/// let mut claimed = FdIo(claim_fd(orig)?); +/// +/// // A different file descriptor is used +/// assert!(orig.as_raw_fd() != claimed.0.as_raw_fd()); +/// +/// // Write some data +/// claimed.write_all(b"Hello, World!")?; +/// +/// Ok::<(), std::io::Error>(()) +/// ``` pub fn claim_fd(fd: RawFd) -> rustix::io::Result { let new = clone_fd_cloexec(unsafe { BorrowedFd::borrow_raw(fd) })?; mask_fd(fd)?; @@ -22,7 +50,32 @@ pub fn claim_fd(fd: RawFd) -> rustix::io::Result { /// /// Checks if the file descriptor is valid. /// -/// Unlike [claim_fd], this will reuse the same file descriptor identifier instead of masking it. +/// Unlike [claim_fd], this will try to reuse the same file descriptor identifier instead of masking it. +/// +/// # Panic and safety +/// +/// Will panic if the given file descriptor is negative of or larger than +/// the file descriptor numbers permitted by the operating system. +/// +/// # Examples +/// +/// ``` +/// use std::io::Write; +/// use std::os::fd::IntoRawFd; +/// use tempfile::tempdir; +/// use rosenpass_util::fd::{claim_fd_inplace, FdIo}; +/// +/// // Open a file and turn it into a raw file descriptor +/// let fd = tempfile::tempfile()?.into_raw_fd(); +/// +/// // Reclaim that file and ready it for reading +/// let mut fd = FdIo(claim_fd_inplace(fd)?); +/// +/// // Write some data +/// fd.write_all(b"Hello, World!")?; +/// +/// Ok::<(), std::io::Error>(()) +/// ``` pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result { let mut new = unsafe { OwnedFd::from_raw_fd(fd) }; let tmp = clone_fd_cloexec(&new)?; @@ -30,6 +83,13 @@ pub fn claim_fd_inplace(fd: RawFd) -> rustix::io::Result { Ok(new) } +/// Will close the given file descriptor and overwrite +/// it with a masking file descriptor (see [open_nullfd]) to prevent accidental reuse. +/// +/// # Panic and safety +/// +/// Will panic if the given file descriptor is negative of or larger than +/// the file descriptor numbers permitted by the operating system. pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> { // Safety: because the OwnedFd resulting from OwnedFd::from_raw_fd is wrapped in a Forgetting, // it never gets dropped, meaning that fd is never closed and thus outlives the OwnedFd @@ -37,11 +97,17 @@ pub fn mask_fd(fd: RawFd) -> rustix::io::Result<()> { clone_fd_to_cloexec(open_nullfd()?, &mut owned) } +/// Duplicate a file descriptor, setting the close on exec flag pub fn clone_fd_cloexec(fd: Fd) -> rustix::io::Result { - const MINFD: RawFd = 3; // Avoid stdin, stdout, and stderr + /// Avoid stdin, stdout, and stderr + const MINFD: RawFd = 3; fcntl_dupfd_cloexec(fd, MINFD) } +/// Duplicate a file descriptor, setting the close on exec flag. +/// +/// This is slightly different from [clone_fd_cloexec], as this function supports specifying an +/// explicit destination file descriptor. #[cfg(target_os = "linux")] pub fn clone_fd_to_cloexec(fd: Fd, new: &mut OwnedFd) -> rustix::io::Result<()> { use rustix::io::{dup3, DupFlags}; @@ -56,7 +122,21 @@ pub fn clone_fd_to_cloexec(fd: Fd, new: &mut OwnedFd) -> rustix::io::R } /// Open a "blocked" file descriptor. I.e. a file descriptor that is neither meant for reading nor -/// writing +/// writing. +/// +/// # Safety +/// +/// The behavior of the file descriptor when being written to or from is undefined. +/// +/// # Examples +/// +/// ``` +/// use std::{fs::File, io::Write, os::fd::IntoRawFd}; +/// use rustix::fd::FromRawFd; +/// use rosenpass_util::fd::open_nullfd; +/// +/// let nullfd = open_nullfd().unwrap(); +/// ``` pub fn open_nullfd() -> rustix::io::Result { use rustix::fs::{open, Mode, OFlags}; // TODO: Add tests showing that this will throw errors on use @@ -64,8 +144,24 @@ pub fn open_nullfd() -> rustix::io::Result { } /// Convert low level errors into std::io::Error +/// +/// # Examples +/// +/// ``` +/// use std::io::ErrorKind as EK; +/// use rustix::io::Errno; +/// use rosenpass_util::fd::IntoStdioErr; +/// +/// let e = Errno::INTR.into_stdio_err(); +/// assert!(matches!(e.kind(), EK::Interrupted)); +/// +/// let r : rustix::io::Result<()> = Err(Errno::INTR); +/// assert!(matches!(r, Err(e) if e.kind() == EK::Interrupted)); +/// ``` pub trait IntoStdioErr { + /// Target type produced (e.g. std::io:Error or std::io::Result depending on context type Target; + /// Convert low level errors to fn into_stdio_err(self) -> Self::Target; } @@ -86,6 +182,10 @@ impl IntoStdioErr for rustix::io::Result { } /// Read and write directly from a file descriptor +/// +/// # Examples +/// +/// See [claim_fd]. pub struct FdIo(pub Fd); impl std::io::Read for FdIo { @@ -104,7 +204,17 @@ impl std::io::Write for FdIo { } } +/// Helpers for accessing stat(2) information pub trait StatExt { + /// Check if the file is a socket + /// + /// # Examples + /// + /// ``` + /// use rosenpass_util::fd::StatExt; + /// assert!(rustix::fs::stat("/")?.is_socket() == false); + /// Ok::<(), rustix::io::Errno>(()) + /// ```` fn is_socket(&self) -> bool; } @@ -116,8 +226,21 @@ impl StatExt for rustix::fs::Stat { } } +/// Helpers for accessing stat(2) information on an open file descriptor pub trait TryStatExt { + /// Error type returned by operations type Error; + + /// Check if the file is a socket + /// + /// # Examples + /// + /// ``` + /// use rosenpass_util::fd::TryStatExt; + /// let fd = rustix::fs::open("/", rustix::fs::OFlags::empty(), rustix::fs::Mode::empty())?; + /// assert!(matches!(fd.is_socket(), Ok(false))); + /// Ok::<(), rustix::io::Errno>(()) + /// ```` fn is_socket(&self) -> Result; } @@ -132,13 +255,18 @@ where } } +/// Determine the type of socket a file descriptor represents pub trait GetSocketType { + /// Error type returned by operations in this trait type Error; + /// Look up the socket; see [rustix::net::sockopt::get_socket_type] fn socket_type(&self) -> Result; + /// Checks if the socket is a datagram socket fn is_datagram_socket(&self) -> Result { use rustix::net::SocketType; matches!(self.socket_type()?, SocketType::DGRAM).ok() } + /// Checks if the socket is a stream socket fn is_stream_socket(&self) -> Result { Ok(self.socket_type()? == rustix::net::SocketType::STREAM) } @@ -155,13 +283,18 @@ where } } +/// Distinguish different socket address familys; e.g. IP and unix sockets #[cfg(target_os = "linux")] pub trait GetSocketDomain { + /// Error type returned by operations in this trait type Error; + /// Retrieve the socket domain (address family) fn socket_domain(&self) -> Result; + /// Alias for [socket_domain] fn socket_address_family(&self) -> Result { self.socket_domain() } + /// Check if the underlying socket is a unix domain socket fn is_unix_socket(&self) -> Result { Ok(self.socket_domain()? == rustix::net::AddressFamily::UNIX) } @@ -179,10 +312,14 @@ where } } +/// Distinguish different types of unix sockets #[cfg(target_os = "linux")] pub trait GetUnixSocketType { + /// Error type returned by operations in this trait type Error; + /// Check if the socket is a unix stream socket fn is_unix_stream_socket(&self) -> Result; + /// Returns Ok(()) only if the underlying socket is a unix stream socket fn demand_unix_stream_socket(&self) -> anyhow::Result<()>; } @@ -210,14 +347,18 @@ where } #[cfg(target_os = "linux")] +/// Distinguish between different network socket protocols (e.g. tcp, udp) pub trait GetSocketProtocol { + /// Retrieve the socket protocol fn socket_protocol(&self) -> Result, rustix::io::Errno>; + /// Check if the socket is a udp socket fn is_udp_socket(&self) -> Result { self.socket_protocol()? .map(|p| p == rustix::net::ipproto::UDP) .unwrap_or(false) .ok() } + /// Return Ok(()) only if the socket is a udp socket fn demand_udp_socket(&self) -> anyhow::Result<()> { match self.socket_protocol() { Ok(Some(rustix::net::ipproto::UDP)) => Ok(()), @@ -243,58 +384,50 @@ where #[cfg(test)] mod tests { use super::*; - use std::fs::{read_to_string, File}; use std::io::{Read, Write}; - use std::os::fd::IntoRawFd; - use tempfile::tempdir; - - #[test] - fn test_claim_fd() { - let tmp_dir = tempdir().unwrap(); - let path = tmp_dir.path().join("test"); - let file = File::create(path.clone()).unwrap(); - let fd: RawFd = file.into_raw_fd(); - let owned_fd = claim_fd(fd).unwrap(); - let mut file = unsafe { File::from_raw_fd(owned_fd.into_raw_fd()) }; - file.write_all(b"Hello, World!").unwrap(); - - let message = read_to_string(path).unwrap(); - assert_eq!(message, "Hello, World!"); - } #[test] #[should_panic(expected = "fd != u32::MAX as RawFd")] fn test_claim_fd_invalid_neg() { - let fd: RawFd = -1; - let _ = claim_fd(fd); + let _ = claim_fd(-1); } #[test] #[should_panic(expected = "fd != u32::MAX as RawFd")] fn test_claim_fd_invalid_max() { - let fd: RawFd = i64::MAX as RawFd; - let _ = claim_fd(fd); + let _ = claim_fd(i64::MAX as RawFd); } #[test] - fn test_open_nullfd_write() { - let nullfd = open_nullfd().unwrap(); - let mut file = unsafe { File::from_raw_fd(nullfd.into_raw_fd()) }; - let res = file.write_all(b"Hello, World!"); - assert!(res.is_err()); - assert_eq!( - res.unwrap_err().to_string(), - "Bad file descriptor (os error 9)" - ); + #[should_panic] + fn test_claim_fd_inplace_invalid_neg() { + let _ = claim_fd_inplace(-1); } #[test] - fn test_open_nullfd_read() { - let nullfd = open_nullfd().unwrap(); - let mut file = unsafe { File::from_raw_fd(nullfd.into_raw_fd()) }; - let mut buffer = [0; 10]; - let res = file.read_exact(&mut buffer); - assert!(res.is_err()); - assert_eq!(res.unwrap_err().to_string(), "failed to fill whole buffer"); + #[should_panic] + fn test_claim_fd_inplace_invalid_max() { + let _ = claim_fd_inplace(i64::MAX as RawFd); + } + + #[test] + #[should_panic] + fn test_mask_fd_invalid_neg() { + let _ = mask_fd(-1); + } + + #[test] + #[should_panic] + fn test_mask_fd_invalid_max() { + let _ = mask_fd(i64::MAX as RawFd); + } + + #[test] + fn test_open_nullfd() -> anyhow::Result<()> { + let mut file = FdIo(open_nullfd()?); + let mut buf = [0; 10]; + assert!(matches!(file.read(&mut buf), Ok(0) | Err(_))); + assert!(matches!(file.write(&buf), Err(_))); + Ok(()) } }