diff --git a/src/builtins/shared.rs b/src/builtins/shared.rs index 148954de7..4d233c924 100644 --- a/src/builtins/shared.rs +++ b/src/builtins/shared.rs @@ -814,23 +814,36 @@ pub fn new(arg: Cow<'args, wstr>, want_newline: bool) -> Self { /// A helper type for extracting arguments from either argv or stdin. pub struct Arguments<'args, 'iter> { - /// The list of arguments passed to the string builtin. - args: &'iter [&'args wstr], - /// If using argv, index of the next argument to return. - argidx: &'iter mut usize, split_behavior: SplitBehavior, - /// Buffer to store what we read with the BufReader - /// Is only here to avoid allocating every time - buffer: Vec, - /// If not using argv, we read with a buffer - reader: Option>, + source: ArgvSource<'args, 'iter>, } -impl Drop for Arguments<'_, '_> { +/// Either the arguments from argv, or from stdin. +enum ArgvSource<'args, 'iter> { + /// Read arguments from argv. + Args { + // The list of arguments passed to the builtin. + args: &'iter [&'args wstr], + // Index of the next argument to return. + argidx: &'iter mut usize, + }, + /// Read arguments from stdin (possibly redirected). + Stdin { + /// Reused storage for reading. + buffer: Vec, + /// The reader to read from. + /// This is never None; we use Option to avoid closing stdin on drop. + reader: Option>, + }, +} + +impl Drop for ArgvSource<'_, '_> { fn drop(&mut self) { - if let Some(r) = self.reader.take() { - // we should not close stdin - std::mem::forget(r.into_inner()); + if let ArgvSource::Stdin { reader, .. } = self { + if let Some(reader) = reader.take() { + // we should not close stdin + std::mem::forget(reader.into_inner()); + } } } } @@ -842,20 +855,21 @@ pub fn new( streams: &mut IoStreams, chunk_size: usize, ) -> Self { - let reader = streams.stdin_is_directly_redirected.then(|| { + let source: ArgvSource = if !streams.stdin_is_directly_redirected { + ArgvSource::Args { args, argidx } + } else { let stdin_fd = streams.stdin_fd; assert!(stdin_fd >= 0, "should have a valid fd"); // safety: this should be a valid fd, and already open let fd = unsafe { File::from_raw_fd(stdin_fd) }; - BufReader::with_capacity(chunk_size, fd) - }); - + ArgvSource::Stdin { + buffer: Vec::new(), + reader: Some(BufReader::with_capacity(chunk_size, fd)), + } + }; Arguments { - args, - argidx, split_behavior: SplitBehavior::Newline, - buffer: Vec::new(), - reader, + source, } } @@ -864,9 +878,27 @@ pub fn with_split_behavior(mut self, split_behavior: SplitBehavior) -> Self { self } + /// Return the next argument by reading from argv ArgvSource. + fn get_arg_argv(&mut self) -> Option> { + let ArgvSource::Args { args, argidx } = &mut self.source else { + panic!("Not reading from argv") + }; + let arg = args.get(**argidx)?; + **argidx += 1; + let retval = InputValue::new(Cow::Borrowed(arg), /*want_newline=*/ true); + Some(retval) + } + + /// Return the next argument by reading from stdin ArgvSource. fn get_arg_stdin(&mut self) -> Option> { use SplitBehavior::*; - let reader = self.reader.as_mut().unwrap(); + let ArgvSource::Stdin { + reader: Some(reader), + buffer, + } = &mut self.source + else { + panic!("Not reading from stdin") + }; if self.split_behavior == InferNull { // we must determine if the first `PATH_MAX` bytes contains a null. @@ -882,9 +914,9 @@ fn get_arg_stdin(&mut self) -> Option> { // NOTE: C++ wrongly commented that read_blocked retries for EAGAIN let num_bytes: usize = match self.split_behavior { - Newline => reader.read_until(b'\n', &mut self.buffer), - Null => reader.read_until(b'\0', &mut self.buffer), - Never => reader.read_to_end(&mut self.buffer), + Newline => reader.read_until(b'\n', buffer), + Null => reader.read_until(b'\0', buffer), + Never => reader.read_to_end(buffer), _ => unreachable!(), } .ok()?; @@ -895,7 +927,7 @@ fn get_arg_stdin(&mut self) -> Option> { } // assert!(num_bytes == self.buffer.len()); - let (end, want_newline) = match (&self.split_behavior, self.buffer.last()) { + let (end, want_newline) = match (&self.split_behavior, buffer.last()) { // remove the newline — consumers do not expect it (Newline, Some(b'\n')) => (num_bytes - 1, true), // we are missing a trailing newline! @@ -909,8 +941,8 @@ fn get_arg_stdin(&mut self) -> Option> { _ => unreachable!(), }; - let parsed = bytes2wcstring(&self.buffer[..end]); - self.buffer.clear(); + let parsed = bytes2wcstring(&buffer[..end]); + buffer.clear(); Some(InputValue::new(Cow::Owned(parsed), want_newline)) } @@ -925,19 +957,10 @@ impl<'args> Iterator for Arguments<'args, '_> { type Item = InputValue<'args>; fn next(&mut self) -> Option { - if self.reader.is_some() { - return self.get_arg_stdin(); + match &mut self.source { + ArgvSource::Args { .. } => self.get_arg_argv(), + ArgvSource::Stdin { .. } => self.get_arg_stdin(), } - - if *self.argidx >= self.args.len() { - return None; - } - let retval = InputValue::new( - Cow::Borrowed(self.args[*self.argidx]), - /*want_newline=*/ true, - ); - *self.argidx += 1; - Some(retval) } }