diff --git a/src/args.rs b/src/args.rs index 5c86681..7163b3c 100644 --- a/src/args.rs +++ b/src/args.rs @@ -38,12 +38,21 @@ pub struct ShowCaptures { pub expression: String, } +#[derive(Args, Debug)] +pub struct DryRun { + pub path: PathBuf, + pub edit_file: PathBuf, +} + #[derive(Subcommand, Debug)] pub enum Ast { /// Dump the S expression for a given source file DumpExpression(DumpExpression), /// Show what parts of a source file gets captured by an S expression ShowCaptures(ShowCaptures), + + /// Test your edit snippets on a sample file + DryRun(DryRun), } #[derive(Subcommand, Debug)] diff --git a/src/main.rs b/src/main.rs index d520c5e..e87f8c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,22 +26,28 @@ async fn main() -> Result<()> { println!("{}", dump_expression(&source_file.path)?); } args::Ast::ShowCaptures(show_captures) => { - let source = std::fs::read_to_string(&show_captures.path).unwrap(); - let source_bytes = source.as_bytes(); - let extension = show_captures.path.extension().unwrap().to_str().unwrap(); - let langfn = state::Refactor::get_lang(extension).unwrap(); - let mut parser = tree_sitter::Parser::new(); - parser.set_language(&langfn).unwrap(); - let tree = parser.parse(source_bytes, None).unwrap(); + let source_bytes = std::fs::read(&show_captures.path)?; + let langfn = state::lang_from_file_extension(&show_captures.path)?; + let tree = state::parse_into_tree(&source_bytes, &langfn)?; let root_node = tree.root_node(); let cooked = mutation::query( root_node, &show_captures.expression, &langfn, - source_bytes, + &source_bytes, ); println!("{:#?}", cooked); } + args::Ast::DryRun(dry_run) => { + let mutation_collection = mutation::from_path(dry_run.edit_file)?; + let source_bytes = std::fs::read(&dry_run.path)?; + let langfn = state::lang_from_file_extension(&dry_run.path)?; + let tree = state::parse_into_tree(&source_bytes, &langfn)?; + let root_node = tree.root_node(); + let cooked = + mutation::apply(langfn, &source_bytes, root_node, &mutation_collection)?; + println!("{cooked}"); + } } return Ok(()); } diff --git a/src/mutation.rs b/src/mutation.rs index 954f453..2f8d9df 100644 --- a/src/mutation.rs +++ b/src/mutation.rs @@ -179,24 +179,23 @@ pub fn query<'a>( let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap()); let mut start_pos = None; let mut end_pos = None; - // println!("matches for {name}"); + debug!("matches for {name}"); for node in nodes { - if start_pos.is_none() { - start_pos.replace(node.start_byte()); - } + start_pos.get_or_insert(node.start_byte()); end_pos.replace(node.end_byte()); - // println!("hit {node:#?}"); + debug!("hit {node:#?}"); } - if start_pos.or(end_pos).is_none() { + + let (Some(start_pos), Some(end_pos)) = (start_pos, end_pos) else { continue; - } + }; + if *name == "root" { - start = start_pos.unwrap(); - end = end_pos.unwrap(); + start = start_pos; + end = end_pos; } - let range = start_pos.unwrap()..end_pos.unwrap(); - // println!("match range for {name}: {:#?}", range); - let text_bytes = &source_bytes[range]; + + let text_bytes = &source_bytes[start_pos..end_pos]; let text = std::str::from_utf8(text_bytes).unwrap(); // println!("text: {text}"); capture_cooked.insert(name.to_string(), text.to_string()); diff --git a/src/state.rs b/src/state.rs index 8f6b822..66144e8 100644 --- a/src/state.rs +++ b/src/state.rs @@ -23,18 +23,6 @@ pub struct Refactor { } impl Refactor { - pub fn get_lang(s: &str) -> Result { - Ok(match s { - "go" => tree_sitter_go::LANGUAGE, - "c" | "h" => tree_sitter_c::LANGUAGE, - "cpp" | "hpp" => tree_sitter_cpp::LANGUAGE, - "js" | "ts" => tree_sitter_javascript::LANGUAGE, - "rs" => tree_sitter_rust::LANGUAGE, - _ => return Err(Error::UnknownLang), - } - .into()) - } - pub fn search( &self, lang: &str, @@ -42,17 +30,9 @@ impl Refactor { body: &str, top_k: usize, ) -> Result, Error> { - let langfn = Self::get_lang(lang)?; - let mut parser = Parser::new(); - parser - .set_language(&langfn) - .map_err(|_| Error::UnknownLang)?; - - let source_code = body; - let source_bytes = source_code.as_bytes(); - let tree = parser - .parse(source_code, None) - .ok_or(Error::SnippetParsing)?; + let langfn = lang_from_name(lang)?; + let source_bytes = body.as_bytes(); + let tree = parse_into_tree(source_bytes, &langfn)?; let root_node = tree.root_node(); // search for k nearest neighbors @@ -83,24 +63,45 @@ impl Refactor { } } -pub fn dump_expression(path: &Path) -> Result { +pub fn lang_from_name(s: &str) -> Result { + Ok(match s { + "go" => tree_sitter_go::LANGUAGE, + "c" | "h" => tree_sitter_c::LANGUAGE, + "cpp" | "hpp" => tree_sitter_cpp::LANGUAGE, + "js" | "ts" => tree_sitter_javascript::LANGUAGE, + "rs" => tree_sitter_rust::LANGUAGE, + _ => return Err(Error::UnknownLang), + } + .into()) +} + +pub fn lang_from_file_extension(path: &Path) -> Result { let Some(lang) = path.extension() else { return Err(Error::UnknownLang); }; let lang = lang.to_str().ok_or(Error::UnknownLang)?; - let langfn = Refactor::get_lang(lang)?; + lang_from_name(lang) +} + +// parses `body` written in the language `langfn` into tree sitter AST +pub fn parse_into_tree( + body: &[u8], + langfn: &tree_sitter::Language, +) -> Result { let mut parser = Parser::new(); parser - .set_language(&langfn) + .set_language(langfn) .map_err(|_| Error::UnknownLang)?; + let tree = parser.parse(body, None).ok_or(Error::SnippetParsing)?; + Ok(tree) +} - let source_code = std::fs::read_to_string(path).map_err(|_| Error::SnippetParsing)?; - let source_bytes = source_code.as_bytes(); - let tree = parser - .parse(source_bytes, None) - .ok_or(Error::SnippetParsing)?; - let root_node = tree.root_node(); - Ok(root_node.to_sexp().to_string()) +pub fn dump_expression(path: &Path) -> Result { + let source_bytes = std::fs::read(path).map_err(|_| Error::SnippetParsing)?; + + let tree = parse_into_tree(&source_bytes, &lang_from_file_extension(path)?)?; + + Ok(tree.root_node().to_sexp().to_string()) } pub struct Generate {