From 5308f999f95343d3d232e6e9258ea607f0a05dad Mon Sep 17 00:00:00 2001 From: Alejandro Soto Date: Tue, 28 Dec 2021 19:08:33 -0600 Subject: Reimplement Forget/BatchForget --- src/proto.rs | 71 +++++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 13 deletions(-) (limited to 'src/proto.rs') diff --git a/src/proto.rs b/src/proto.rs index f15aaff..7ef3415 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,7 +1,7 @@ // Based on libfuse/include/fuse_kernel.h use bitflags::bitflags; -use bytemuck::{self, Pod}; +use bytemuck::{self, try_cast_slice, try_from_bytes, Pod}; use bytemuck_derive::{Pod, Zeroable}; use num_enum::TryFromPrimitive; use std::{convert::TryFrom, ffi::CStr, fmt}; @@ -15,16 +15,21 @@ pub const TARGET_MINOR_VERSION: u32 = 32; pub const REQUIRED_MINOR_VERSION: u32 = 31; pub trait Structured<'o>: Sized { - fn split_from(bytes: &'o [u8], last: bool) -> FuseResult<(Self, &'o [u8])>; + fn split_from(bytes: &'o [u8], header: &InHeader, last: bool) -> FuseResult<(Self, &'o [u8])>; - fn toplevel_from(bytes: &'o [u8]) -> FuseResult { - match Self::split_from(bytes, true)? { + fn toplevel_from(bytes: &'o [u8], header: &InHeader) -> FuseResult { + match Self::split_from(bytes, header, true)? { (ok, end) if end.is_empty() => Ok(ok), _ => Err(FuseError::BadLength), } } } +pub enum OpcodeSelect { + Match(L), + Alt(R), +} + #[derive(Pod, Zeroable, Copy, Clone)] #[repr(C)] pub struct InHeader { @@ -565,13 +570,48 @@ pub struct CopyFileRangeIn { } impl<'o> Structured<'o> for () { - fn split_from(bytes: &'o [u8], _last: bool) -> FuseResult<(Self, &'o [u8])> { + fn split_from(bytes: &'o [u8], _: &InHeader, _last: bool) -> FuseResult<(Self, &'o [u8])> { Ok(((), bytes)) } } +impl<'o, T, U> Structured<'o> for (T, U) +where + T: Structured<'o>, + U: Structured<'o>, +{ + fn split_from(bytes: &'o [u8], header: &InHeader, last: bool) -> FuseResult<(Self, &'o [u8])> { + let (first, bytes) = T::split_from(bytes, header, false)?; + let (second, end) = U::split_from(bytes, header, last)?; + Ok(((first, second), end)) + } +} + +impl<'o, T: Pod> Structured<'o> for &'o T { + fn split_from(bytes: &'o [u8], _: &InHeader, _last: bool) -> FuseResult<(Self, &'o [u8])> { + let (bytes, next_bytes) = bytes.split_at(bytes.len().min(std::mem::size_of::())); + match try_from_bytes(bytes) { + Ok(t) => Ok((t, next_bytes)), + Err(_) => Err(FuseError::Truncated), + } + } +} + +impl<'o, T: Pod> Structured<'o> for &'o [T] { + fn split_from(bytes: &'o [u8], _header: &InHeader, last: bool) -> FuseResult<(Self, &'o [u8])> { + if !last { + unimplemented!(); + } + + match try_cast_slice(bytes) { + Ok(slice) => Ok((slice, &[])), + Err(_) => Err(FuseError::Truncated), + } + } +} + impl<'o> Structured<'o> for &'o CStr { - fn split_from(bytes: &'o [u8], last: bool) -> FuseResult<(Self, &'o [u8])> { + fn split_from(bytes: &'o [u8], _header: &InHeader, last: bool) -> FuseResult<(Self, &'o [u8])> { let (cstr, after_cstr): (&[u8], &[u8]) = if last { (bytes, &[]) } else { @@ -586,19 +626,24 @@ impl<'o> Structured<'o> for &'o CStr { } } -impl<'o, T: Pod> Structured<'o> for &'o T { - fn split_from(bytes: &'o [u8], _last: bool) -> FuseResult<(Self, &'o [u8])> { - let (bytes, next_bytes) = bytes.split_at(bytes.len().min(std::mem::size_of::())); - match bytemuck::try_from_bytes(bytes) { - Ok(t) => Ok((t, next_bytes)), - Err(_) => Err(FuseError::Truncated), +impl<'o, L, R, const OP: u32> Structured<'o> for OpcodeSelect +where + L: Structured<'o>, + R: Structured<'o>, +{ + fn split_from(bytes: &'o [u8], header: &InHeader, last: bool) -> FuseResult<(Self, &'o [u8])> { + if header.opcode == OP { + L::split_from(bytes, header, last).map(|(l, end)| (OpcodeSelect::Match(l), end)) + } else { + R::split_from(bytes, header, last).map(|(r, end)| (OpcodeSelect::Alt(r), end)) } } } impl InHeader { pub fn from_bytes(bytes: &[u8]) -> FuseResult<(Self, Opcode)> { - let (header, _) = <&InHeader>::split_from(bytes, false)?; + let header_bytes = &bytes[..bytes.len().min(std::mem::size_of::())]; + let header = try_from_bytes::(header_bytes).map_err(|_| FuseError::Truncated)?; if header.len as usize != bytes.len() { return Err(FuseError::BadLength); -- cgit v1.2.3