summaryrefslogtreecommitdiff
path: root/src/fuse/session.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/fuse/session.rs108
1 files changed, 49 insertions, 59 deletions
diff --git a/src/fuse/session.rs b/src/fuse/session.rs
index 5045099..e947bef 100644
--- a/src/fuse/session.rs
+++ b/src/fuse/session.rs
@@ -1,5 +1,4 @@
use std::{
- convert::TryInto,
future::Future,
io,
marker::PhantomData,
@@ -10,7 +9,7 @@ use std::{
use nix::{
fcntl::{fcntl, FcntlArg, OFlag},
sys::uio::{writev, IoVec},
- unistd::{read, sysconf, SysconfVar},
+ unistd::read,
};
use tokio::{
@@ -23,7 +22,7 @@ use smallvec::SmallVec;
use crate::{
proto::{self, InHeader, Structured},
- util::{DumbFd, OutputChain},
+ util::{page_size, DumbFd, OutputChain},
Errno, FuseError, FuseResult,
};
@@ -104,12 +103,15 @@ impl Session {
self.interrupt_tx.subscribe()
}
- async fn handshake(&mut self, buffer: &mut Buffer) -> FuseResult<Handshake> {
+ async fn handshake<F>(&mut self, buffer: &mut Buffer, init: F) -> FuseResult<Handshake<F>>
+ where
+ F: FnOnce(Op<'_, ops::Init>) -> Done<'_>,
+ {
self.session_fd.readable().await?.retain_ready();
let bytes = read(*self.session_fd.get_ref(), &mut buffer.0).map_err(io::Error::from)?;
let (header, opcode) = InHeader::from_bytes(&buffer.0[..bytes])?;
- let init = match opcode {
+ let body = match opcode {
proto::Opcode::Init => <&proto::InitIn>::toplevel_from(&buffer.0[HEADER_END..bytes])?,
_ => {
@@ -119,11 +121,11 @@ impl Session {
};
use std::cmp::Ordering;
- let supported = match init.major.cmp(&proto::MAJOR_VERSION) {
+ let supported = match body.major.cmp(&proto::MAJOR_VERSION) {
Ordering::Less => false,
Ordering::Equal => {
- self.proto_minor = init.minor;
+ self.proto_minor = body.minor;
self.proto_minor >= proto::REQUIRED_MINOR_VERSION
}
@@ -131,7 +133,7 @@ impl Session {
let tail = [bytes_of(&proto::MAJOR_VERSION)];
self.ok(header.unique, OutputChain::tail(&tail))?;
- return Ok(Handshake::Restart);
+ return Ok(Handshake::Restart(init));
}
};
@@ -141,8 +143,8 @@ impl Session {
"Unsupported protocol {}.{}; this build requires \
{major}.{}..={major}.{} (or a greater version \
through compatibility)",
- init.major,
- init.minor,
+ body.major,
+ body.minor,
proto::REQUIRED_MINOR_VERSION,
proto::TARGET_MINOR_VERSION,
major = proto::MAJOR_VERSION
@@ -152,40 +154,17 @@ impl Session {
return Err(FuseError::ProtocolInit);
}
- let flags = {
- use proto::InitFlags;
-
- let kernel = InitFlags::from_bits_truncate(init.flags);
- let supported = InitFlags::PARALLEL_DIROPS
- | InitFlags::ABORT_ERROR
- | InitFlags::MAX_PAGES
- | InitFlags::CACHE_SYMLINKS;
-
- kernel & supported
- };
-
- let buffer_size = page_size() * self.buffer_pages;
-
- // See fs/fuse/dev.c in the kernel source tree for details about max_write
- let max_write = buffer_size - std::mem::size_of::<(InHeader, proto::WriteIn)>();
-
- let init_out = proto::InitOut {
- major: proto::MAJOR_VERSION,
- minor: proto::TARGET_MINOR_VERSION,
- max_readahead: 0, //TODO
- flags: flags.bits(),
- max_background: 0, //TODO
- congestion_threshold: 0, //TODO
- max_write: max_write.try_into().unwrap(),
- time_gran: 1, //TODO
- max_pages: self.buffer_pages.try_into().unwrap(),
- padding: Default::default(),
- unused: Default::default(),
+ let request = Request { header, body };
+ let reply = Reply {
+ session: self,
+ unique: header.unique,
+ tail: ops::state::Init {
+ kernel_flags: proto::InitFlags::from_bits_truncate(body.flags),
+ buffer_pages: self.buffer_pages,
+ },
};
- let tail = [bytes_of(&init_out)];
- self.ok(header.unique, OutputChain::tail(&tail))?;
-
+ let _ = init((request, reply));
Ok(Handshake::Done)
}
@@ -270,7 +249,7 @@ impl Endpoint<'_> {
}
};
- let (header, opcode) = proto::InHeader::from_bytes(&buffer[..bytes?])?;
+ let (header, opcode) = InHeader::from_bytes(&buffer[..bytes?])?;
let common = IncomingCommon {
session: self.session,
buffer: &mut self.local_buffer,
@@ -317,7 +296,10 @@ impl Endpoint<'_> {
}
impl Start {
- pub async fn start(self) -> FuseResult<Arc<Session>> {
+ pub async fn start<F>(self, mut init: F) -> FuseResult<Arc<Session>>
+ where
+ F: FnOnce(Op<'_, ops::Init>) -> Done<'_>,
+ {
let session_fd = self.session_fd.into_raw_fd();
let flags = OFlag::O_NONBLOCK | OFlag::O_LARGEFILE;
@@ -344,10 +326,13 @@ impl Start {
let mut init_buffer = session.buffers.get_mut().unwrap().pop().unwrap();
loop {
- if let Handshake::Done = session.handshake(&mut init_buffer).await? {
- session.buffers.get_mut().unwrap().push(init_buffer);
- break Ok(Arc::new(session));
- }
+ init = match session.handshake(&mut init_buffer, init).await? {
+ Handshake::Restart(init) => init,
+ Handshake::Done => {
+ session.buffers.get_mut().unwrap().push(init_buffer);
+ break Ok(Arc::new(session));
+ }
+ };
}
}
@@ -359,7 +344,10 @@ impl Start {
}
}
-impl<'o, O: Operation<'o>> Incoming<'o, O> {
+impl<'o, O: Operation<'o>> Incoming<'o, O>
+where
+ O::ReplyTail: Default,
+{
pub fn op(self) -> Result<Op<'o, O>, Done<'o>> {
try_op(
self.common.session,
@@ -397,7 +385,10 @@ impl<O: for<'o> Operation<'o>> Incoming<'_, O> {
}
}
-impl<O: for<'o> Operation<'o>> Owned<O> {
+impl<O: for<'o> Operation<'o>> Owned<O>
+where
+ for<'o> <O as Operation<'o>>::ReplyTail: Default,
+{
pub fn op(&self) -> Result<Op<'_, O>, Done<'_>> {
try_op(&self.session, &self.buffer.0, self.header)
}
@@ -419,12 +410,12 @@ const HEADER_END: usize = std::mem::size_of::<InHeader>();
struct IncomingCommon<'o> {
session: &'o Arc<Session>,
buffer: &'o mut Buffer,
- header: proto::InHeader,
+ header: InHeader,
}
-enum Handshake {
+enum Handshake<F> {
Done,
- Restart,
+ Restart(F),
}
struct Buffer(Box<[u8]>);
@@ -455,8 +446,11 @@ impl Buffer {
fn try_op<'o, O: Operation<'o>>(
session: &'o Session,
bytes: &'o [u8],
- header: proto::InHeader,
-) -> Result<Op<'o, O>, Done<'o>> {
+ header: InHeader,
+) -> Result<Op<'o, O>, Done<'o>>
+where
+ O::ReplyTail: Default,
+{
let body = match Structured::toplevel_from(&bytes[HEADER_END..header.len as usize]) {
Ok(body) => body,
Err(error) => {
@@ -480,7 +474,3 @@ fn try_op<'o, O: Operation<'o>>(
Ok((request, reply))
}
-
-fn page_size() -> usize {
- sysconf(SysconfVar::PAGE_SIZE).unwrap().unwrap() as usize
-}