diff options
| author | Alejandro Soto <alejandro@34project.org> | 2022-01-04 06:49:48 -0600 |
|---|---|---|
| committer | Alejandro Soto <alejandro@34project.org> | 2022-01-04 06:49:52 -0600 |
| commit | 70baa472b2bee69f205cc1aada304d597b858005 (patch) | |
| tree | 7a0b1a0381b68fe0e091b87d00634ff13568bf6d /src/session.rs | |
| parent | 1955ec118a32d3fa174496abe5442c82c609273a (diff) | |
Move crate::fuse::* to the top-level
Diffstat (limited to 'src/session.rs')
| -rw-r--r-- | src/session.rs | 559 |
1 files changed, 559 insertions, 0 deletions
diff --git a/src/session.rs b/src/session.rs new file mode 100644 index 0000000..e83a8d4 --- /dev/null +++ b/src/session.rs @@ -0,0 +1,559 @@ +use std::{ + future::Future, + io, + marker::PhantomData, + ops::ControlFlow, + os::unix::io::{IntoRawFd, RawFd}, + path::PathBuf, + sync::{Arc, Mutex}, +}; + +use nix::{ + fcntl::{fcntl, FcntlArg, OFlag}, + sys::uio::{writev, IoVec}, + unistd::read, +}; + +use tokio::{ + io::unix::AsyncFd, + sync::{broadcast, OwnedSemaphorePermit, Semaphore}, +}; + +use bytemuck::bytes_of; +use smallvec::SmallVec; + +use crate::{ + mount::{unmount_sync, MountError}, + proto::{self, InHeader, Structured}, + util::{page_size, DumbFd, OutputChain}, + Errno, FuseError, FuseResult, +}; + +use super::{ + ops::{self, FromRequest}, + Done, Op, Operation, Reply, Request, +}; + +pub struct Start { + session_fd: DumbFd, + mountpoint: PathBuf, +} + +pub struct Session { + session_fd: AsyncFd<RawFd>, + interrupt_tx: broadcast::Sender<u64>, + buffers: Mutex<Vec<Buffer>>, + buffer_semaphore: Arc<Semaphore>, + buffer_pages: usize, + mountpoint: Mutex<Option<PathBuf>>, +} + +pub struct Endpoint<'a> { + session: &'a Arc<Session>, + local_buffer: Buffer, +} + +pub enum Dispatch<'o> { + Lookup(Incoming<'o, ops::Lookup>), + Forget(Incoming<'o, ops::Forget>), + Getattr(Incoming<'o, ops::Getattr>), + Readlink(Incoming<'o, ops::Readlink>), + Open(Incoming<'o, ops::Open>), + Read(Incoming<'o, ops::Read>), + Write(Incoming<'o, ops::Write>), + Statfs(Incoming<'o, ops::Statfs>), + Release(Incoming<'o, ops::Release>), + Setxattr(Incoming<'o, ops::Setxattr>), + Getxattr(Incoming<'o, ops::Getxattr>), + Listxattr(Incoming<'o, ops::Listxattr>), + Removexattr(Incoming<'o, ops::Removexattr>), + Flush(Incoming<'o, ops::Flush>), + Opendir(Incoming<'o, ops::Opendir>), + Readdir(Incoming<'o, ops::Readdir>), + Releasedir(Incoming<'o, ops::Releasedir>), + Access(Incoming<'o, ops::Access>), +} + +pub struct Incoming<'o, O: Operation<'o>> { + common: IncomingCommon<'o>, + _phantom: PhantomData<O>, +} + +pub struct Owned<O> { + session: Arc<Session>, + buffer: Buffer, + header: InHeader, + _permit: OwnedSemaphorePermit, + _phantom: PhantomData<O>, +} + +impl Session { + // Does not seem like 'a can be elided here + #[allow(clippy::needless_lifetimes)] + pub fn endpoint<'a>(self: &'a Arc<Self>) -> Endpoint<'a> { + Endpoint { + session: self, + local_buffer: Buffer::new(self.buffer_pages), + } + } + + pub fn unmount_sync(&self) -> Result<(), MountError> { + let mountpoint = self.mountpoint.lock().unwrap().take(); + if let Some(mountpoint) = &mountpoint { + unmount_sync(mountpoint)?; + } + + Ok(()) + } + + pub(crate) fn ok(&self, unique: u64, output: OutputChain<'_>) -> FuseResult<()> { + self.send(unique, 0, output) + } + + pub(crate) fn fail(&self, unique: u64, mut errno: i32) -> FuseResult<()> { + if errno <= 0 { + log::warn!( + "Attempted to fail req#{} with errno {} <= 0, coercing to ENOMSG", + unique, + errno + ); + + errno = Errno::ENOMSG as i32; + } + + self.send(unique, -errno, OutputChain::empty()) + } + + pub(crate) fn interrupt_rx(&self) -> broadcast::Receiver<u64> { + self.interrupt_tx.subscribe() + } + + async fn handshake<F>(&mut self, buffer: &mut Buffer, init: F) -> FuseResult<Handshake<F>> + where + F: FnOnce(Op<'_, ops::Init>) -> Done<'_>, + { + self.session_fd.readable().await?.retain_ready(); + let bytes = read(*self.session_fd.get_ref(), &mut buffer.0).map_err(io::Error::from)?; + + let (header, opcode) = InHeader::from_bytes(&buffer.0[..bytes])?; + let body = match opcode { + proto::Opcode::Init => { + <&proto::InitIn>::toplevel_from(&buffer.0[HEADER_END..bytes], &header)? + } + + _ => { + log::error!("First message from kernel is not Init, but {:?}", opcode); + return Err(FuseError::ProtocolInit); + } + }; + + use std::cmp::Ordering; + let supported = match body.major.cmp(&proto::MAJOR_VERSION) { + Ordering::Less => false, + Ordering::Equal => body.minor >= proto::REQUIRED_MINOR_VERSION, + Ordering::Greater => { + let tail = [bytes_of(&proto::MAJOR_VERSION)]; + self.ok(header.unique, OutputChain::tail(&tail))?; + + return Ok(Handshake::Restart(init)); + } + }; + + //TODO: fake some decency by supporting a few older minor versions + if !supported { + log::error!( + "Unsupported protocol {}.{}; this build requires \ + {major}.{}..={major}.{} (or a greater version \ + through compatibility)", + body.major, + body.minor, + proto::REQUIRED_MINOR_VERSION, + proto::TARGET_MINOR_VERSION, + major = proto::MAJOR_VERSION + ); + + self.fail(header.unique, Errno::EPROTONOSUPPORT as i32)?; + return Err(FuseError::ProtocolInit); + } + + let request = Request { header, body }; + let reply = Reply { + session: self, + unique: header.unique, + tail: ops::InitState { + kernel_flags: proto::InitFlags::from_bits_truncate(body.flags), + buffer_pages: self.buffer_pages, + }, + }; + + init((request, reply)).consume(); + Ok(Handshake::Done) + } + + fn send(&self, unique: u64, error: i32, output: OutputChain<'_>) -> FuseResult<()> { + let after_header: usize = output + .iter() + .flat_map(<[_]>::iter) + .copied() + .map(<[_]>::len) + .sum(); + + let length = (std::mem::size_of::<proto::OutHeader>() + after_header) as _; + let header = proto::OutHeader { + len: length, + error, + unique, + }; + + //TODO: Full const generics any time now? Fs::EXPECTED_REQUEST_SEGMENTS + let header = [bytes_of(&header)]; + let output = output.preceded(&header); + let buffers: SmallVec<[_; 8]> = output + .iter() + .flat_map(<[_]>::iter) + .copied() + .filter(|slice| !slice.is_empty()) + .map(IoVec::from_slice) + .collect(); + + let written = writev(*self.session_fd.get_ref(), &buffers).map_err(io::Error::from)?; + if written == length as usize { + Ok(()) + } else { + Err(FuseError::ShortWrite) + } + } +} + +impl Drop for Start { + fn drop(&mut self) { + if !self.mountpoint.as_os_str().is_empty() { + let _ = unmount_sync(&self.mountpoint); + } + } +} + +impl Drop for Session { + fn drop(&mut self) { + if let Some(mountpoint) = self.mountpoint.get_mut().unwrap().take() { + let _ = unmount_sync(&mountpoint); + } + + drop(DumbFd(*self.session_fd.get_ref())); // Close + } +} + +impl<'o> Dispatch<'o> { + pub fn op(self) -> Op<'o> { + use Dispatch::*; + + let common = match self { + Lookup(incoming) => incoming.common, + Forget(incoming) => incoming.common, + Getattr(incoming) => incoming.common, + Readlink(incoming) => incoming.common, + Open(incoming) => incoming.common, + Read(incoming) => incoming.common, + Write(incoming) => incoming.common, + Statfs(incoming) => incoming.common, + Release(incoming) => incoming.common, + Setxattr(incoming) => incoming.common, + Getxattr(incoming) => incoming.common, + Listxattr(incoming) => incoming.common, + Removexattr(incoming) => incoming.common, + Flush(incoming) => incoming.common, + Opendir(incoming) => incoming.common, + Readdir(incoming) => incoming.common, + Releasedir(incoming) => incoming.common, + Access(incoming) => incoming.common, + }; + + common.into_generic_op() + } +} + +impl Endpoint<'_> { + pub async fn receive<'o, F, Fut>(&'o mut self, dispatcher: F) -> FuseResult<ControlFlow<()>> + where + F: FnOnce(Dispatch<'o>) -> Fut, + Fut: Future<Output = Done<'o>>, + { + let buffer = &mut self.local_buffer.0; + let bytes = loop { + let session_fd = &self.session.session_fd; + + let mut readable = tokio::select! { + readable = session_fd.readable() => readable?, + + _ = session_fd.writable() => { + self.session.mountpoint.lock().unwrap().take(); + return Ok(ControlFlow::Break(())); + } + }; + + let mut read = |fd: &AsyncFd<RawFd>| read(*fd.get_ref(), buffer); + let result = match readable.try_io(|fd| read(fd).map_err(io::Error::from)) { + Ok(result) => result, + Err(_) => continue, + }; + + match result { + // Interrupted + //TODO: libfuse docs say that this has some side effects + Err(error) if error.kind() == std::io::ErrorKind::NotFound => continue, + + result => break result, + } + }; + + let (header, opcode) = InHeader::from_bytes(&buffer[..bytes?])?; + let common = IncomingCommon { + session: self.session, + buffer: &mut self.local_buffer, + header, + }; + + let dispatch = { + use proto::Opcode::*; + + macro_rules! dispatch { + ($op:ident) => { + Dispatch::$op(Incoming { + common, + _phantom: PhantomData, + }) + }; + } + + match opcode { + Destroy => return Ok(ControlFlow::Break(())), + + Lookup => dispatch!(Lookup), + Forget => dispatch!(Forget), + Getattr => dispatch!(Getattr), + Readlink => dispatch!(Readlink), + Open => dispatch!(Open), + Read => dispatch!(Read), + Write => dispatch!(Write), + Statfs => dispatch!(Statfs), + Release => dispatch!(Release), + Setxattr => dispatch!(Setxattr), + Getxattr => dispatch!(Getxattr), + Listxattr => dispatch!(Listxattr), + Removexattr => dispatch!(Removexattr), + Flush => dispatch!(Flush), + Opendir => dispatch!(Opendir), + Readdir => dispatch!(Readdir), + Releasedir => dispatch!(Releasedir), + Access => dispatch!(Access), + BatchForget => dispatch!(Forget), + ReaddirPlus => dispatch!(Readdir), + + _ => { + log::warn!("Not implemented: {}", common.header); + + let (_request, reply) = common.into_generic_op(); + reply.not_implemented().consume(); + + return Ok(ControlFlow::Continue(())); + } + } + }; + + dispatcher(dispatch).await.consume(); + Ok(ControlFlow::Continue(())) + } +} + +impl Start { + pub async fn start<F>(mut self, mut init: F) -> FuseResult<Arc<Session>> + where + F: FnOnce(Op<'_, ops::Init>) -> Done<'_>, + { + let mountpoint = std::mem::take(&mut self.mountpoint); + let session_fd = self.session_fd.take().into_raw_fd(); + + let flags = OFlag::O_NONBLOCK | OFlag::O_LARGEFILE; + fcntl(session_fd, FcntlArg::F_SETFL(flags)).unwrap(); + + let (interrupt_tx, _) = broadcast::channel(INTERRUPT_BROADCAST_CAPACITY); + + let buffer_pages = proto::MIN_READ_SIZE / page_size(); //TODO + let buffer_count = SHARED_BUFFERS; //TODO + let buffers = std::iter::repeat_with(|| Buffer::new(buffer_pages)) + .take(buffer_count) + .collect(); + + let mut session = Session { + session_fd: AsyncFd::with_interest(session_fd, tokio::io::Interest::READABLE)?, + interrupt_tx, + buffers: Mutex::new(buffers), + buffer_semaphore: Arc::new(Semaphore::new(buffer_count)), + buffer_pages, + mountpoint: Mutex::new(Some(mountpoint)), + }; + + let mut init_buffer = session.buffers.get_mut().unwrap().pop().unwrap(); + + loop { + init = match session.handshake(&mut init_buffer, init).await? { + Handshake::Restart(init) => init, + Handshake::Done => { + session.buffers.get_mut().unwrap().push(init_buffer); + break Ok(Arc::new(session)); + } + }; + } + } + + pub fn unmount_sync(mut self) -> Result<(), MountError> { + // This prevents Start::drop() from unmounting a second time + let mountpoint = std::mem::take(&mut self.mountpoint); + unmount_sync(&mountpoint) + } + + pub(crate) fn new(session_fd: DumbFd, mountpoint: PathBuf) -> Self { + Start { + session_fd, + mountpoint, + } + } +} + +impl<'o, O: Operation<'o>> Incoming<'o, O> +where + O::ReplyTail: FromRequest<'o, O>, +{ + pub fn op(self) -> Result<Op<'o, O>, Done<'o>> { + try_op( + self.common.session, + &self.common.buffer.0, + self.common.header, + ) + } + + pub async fn owned(self) -> (Done<'o>, Owned<O>) { + let session = self.common.session; + + let (buffer, permit) = { + let semaphore = Arc::clone(&session.buffer_semaphore); + let permit = semaphore + .acquire_owned() + .await + .expect("Buffer semaphore error"); + + let mut buffers = session.buffers.lock().unwrap(); + let buffer = buffers.pop().expect("Buffer semaphore out of sync"); + let buffer = std::mem::replace(self.common.buffer, buffer); + + (buffer, permit) + }; + + let owned = Owned { + session: Arc::clone(session), + buffer, + header: self.common.header, + _permit: permit, + _phantom: PhantomData, + }; + + (Done::new(), owned) + } +} + +impl<O: for<'o> Operation<'o>> Owned<O> +where + for<'o> <O as Operation<'o>>::ReplyTail: FromRequest<'o, O>, +{ + pub async fn op<'o, F, Fut>(&'o self, handler: F) + where + F: FnOnce(Op<'o, O>) -> Fut, + Fut: Future<Output = Done<'o>>, + { + match try_op(&self.session, &self.buffer.0, self.header) { + Ok(op) => handler(op).await.consume(), + Err(done) => done.consume(), + } + } +} + +impl<O> Drop for Owned<O> { + fn drop(&mut self) { + if let Ok(mut buffers) = self.session.buffers.lock() { + let empty = Buffer(Vec::new().into_boxed_slice()); + buffers.push(std::mem::replace(&mut self.buffer, empty)); + } + } +} + +const INTERRUPT_BROADCAST_CAPACITY: usize = 32; +const SHARED_BUFFERS: usize = 32; +const HEADER_END: usize = std::mem::size_of::<InHeader>(); + +struct IncomingCommon<'o> { + session: &'o Arc<Session>, + buffer: &'o mut Buffer, + header: InHeader, +} + +enum Handshake<F> { + Done, + Restart(F), +} + +struct Buffer(Box<[u8]>); + +impl<'o> IncomingCommon<'o> { + fn into_generic_op(self) -> Op<'o> { + let request = Request { + header: self.header, + body: (), + }; + + let reply = Reply { + session: self.session, + unique: self.header.unique, + tail: (), + }; + + (request, reply) + } +} + +impl Buffer { + fn new(pages: usize) -> Self { + Buffer(vec![0; pages * page_size()].into_boxed_slice()) + } +} + +fn try_op<'o, O: Operation<'o>>( + session: &'o Session, + bytes: &'o [u8], + header: InHeader, +) -> Result<Op<'o, O>, Done<'o>> +where + O::ReplyTail: FromRequest<'o, O>, +{ + let body = match Structured::toplevel_from(&bytes[HEADER_END..header.len as usize], &header) { + Ok(body) => body, + Err(error) => { + log::error!("Parsing request {}: {}", header, error); + let reply = Reply::<ops::Any> { + session, + unique: header.unique, + tail: (), + }; + + return Err(reply.io_error()); + } + }; + + let request = Request { header, body }; + let reply = Reply { + session, + unique: header.unique, + tail: FromRequest::from_request(&request), + }; + + Ok((request, reply)) +} |
