From 3c5efb2dae36cc08b8975ea666d1bb8fccbb3c58 Mon Sep 17 00:00:00 2001 From: LIAUD Corentin Date: Sat, 13 Jul 2024 17:40:36 +0200 Subject: [PATCH] fix: fix stdin read when using `shell` function --- Cargo.toml | 1 + src/commands/shell.rs | 78 +++++++++++++++++++++++++++---------------- 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 065ec94..b636917 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ path = "examples/adb_cli.rs" byteorder = { version = "1.5.0" } chrono = { version = "0.4.38" } lazy_static = { version = "1.5.0" } +mio = { version = "1.0.0", features = ["os-ext", "os-poll"] } regex = { version = "1.10.5", features = ["perf", "std", "unicode"] } termios = { version = "0.3.3" } thiserror = { version = "1.0.61" } diff --git a/src/commands/shell.rs b/src/commands/shell.rs index 0b127f2..61a7680 100644 --- a/src/commands/shell.rs +++ b/src/commands/shell.rs @@ -1,4 +1,10 @@ -use std::io::{ErrorKind, Read, Write}; +use std::{ + io::{self, Read, Write}, + sync::mpsc, + time::Duration, +}; + +use mio::{unix::SourceFd, Events, Interest, Poll, Token}; use crate::{ adb_termios::ADBTermios, @@ -6,6 +12,19 @@ use crate::{ AdbTcpConnection, Result, RustADBError, }; +const STDIN: Token = Token(0); +const BUFFER_SIZE: usize = 512; +const POLL_DURATION: Duration = Duration::from_millis(100); + +fn setup_poll_stdin() -> std::result::Result { + let poll = Poll::new()?; + let stdin_fd = 0; + poll.registry() + .register(&mut SourceFd(&stdin_fd), STDIN, Interest::READABLE)?; + + Ok(poll) +} + impl AdbTcpConnection { /// Runs 'command' in a shell on the device, and return its output and error streams. pub fn shell_command( @@ -89,24 +108,18 @@ impl AdbTcpConnection { // let read_stream = Arc::new(self.tcp_stream); let mut read_stream = self.tcp_stream.try_clone()?; - // Writing thread - let mut write_stream = read_stream.try_clone()?; - let writer_t = std::thread::spawn(move || -> Result<()> { - let mut buf = [0; 1024]; - loop { - let size = std::io::stdin().read(&mut buf)?; + let (tx, rx) = mpsc::channel::(); - write_stream.write_all(&buf[0..size])?; - } - }); + let mut write_stream = read_stream.try_clone()?; // Reading thread - let reader_t = std::thread::spawn(move || -> Result<()> { - const BUFFER_SIZE: usize = 512; + std::thread::spawn(move || -> Result<()> { loop { let mut buffer = [0; BUFFER_SIZE]; match read_stream.read(&mut buffer) { Ok(0) => { + let _ = tx.send(true); + read_stream.shutdown(std::net::Shutdown::Both)?; return Ok(()); } Ok(size) => { @@ -120,24 +133,33 @@ impl AdbTcpConnection { } }); - if let Err(e) = reader_t.join().unwrap() { - match e { - RustADBError::IOError(e) if e.kind() == ErrorKind::BrokenPipe => {} - _ => { - return Err(e); + let mut buf = [0; BUFFER_SIZE]; + let mut events = Events::with_capacity(1); + + let mut poll = setup_poll_stdin()?; + + // Polling either by checking that reading socket hasn't been closed, and if is there is something to read on stdin. + loop { + poll.poll(&mut events, Some(POLL_DURATION))?; + match rx.try_recv() { + Ok(_) | Err(mpsc::TryRecvError::Disconnected) => return Ok(()), + Err(mpsc::TryRecvError::Empty) => (), + } + + for event in events.iter() { + match event.token() { + STDIN => { + let size = match std::io::stdin().read(&mut buf) { + Ok(0) => return Ok(()), + Ok(size) => size, + Err(_) => return Ok(()), + }; + + write_stream.write_all(&buf[0..size])?; + } + _ => unreachable!(), } } } - - if let Err(e) = writer_t.join().unwrap() { - match e { - RustADBError::IOError(e) if e.kind() == ErrorKind::BrokenPipe => {} - _ => { - return Err(e); - } - } - } - - Ok(()) } }