From 7b0f818d3848fa8cfaeb1c6160007e013060a4d4 Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Tue, 26 Aug 2025 10:41:54 +0530 Subject: [PATCH] feat: parse capture groups with `+` or `*` wildcards --- .helix/languages.toml | 9 ++++-- snippets/refactor/go/base64.kdl | 13 +++++++++ src/args.rs | 6 ++-- src/lsp.rs | 10 +++++-- src/main.rs | 30 +++++++++++++++----- src/mutation.rs | 50 ++++++++++++++++++++++----------- src/sources.rs | 29 ++++++++++++------- src/state.rs | 4 +-- 8 files changed, 109 insertions(+), 42 deletions(-) create mode 100644 snippets/refactor/go/base64.kdl diff --git a/.helix/languages.toml b/.helix/languages.toml index 2a1f216..f61e667 100644 --- a/.helix/languages.toml +++ b/.helix/languages.toml @@ -1,6 +1,11 @@ [language-server.silos] -command = "silos" +command = "./target/debug/silos" +args = ["lsp"] [[language]] name = "go" -language-servers = [ { name = "silos" } ] +language-servers = [ { name = "silos" }, "gopls" ] + +[[language]] +name = "rust" +language-servers = [ ] diff --git a/snippets/refactor/go/base64.kdl b/snippets/refactor/go/base64.kdl new file mode 100644 index 0000000..4b56ce2 --- /dev/null +++ b/snippets/refactor/go/base64.kdl @@ -0,0 +1,13 @@ +description "base64 import" +mutation { + expression "import_spec_list ((import_spec)* @imports)" + substitute { + literal "(" + literal "\n" + capture "imports" + literal "\n" + literal #""base64""# + literal "\n" + literal ")" + } +} diff --git a/src/args.rs b/src/args.rs index 64c2aa1..5c86681 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,4 +1,4 @@ -use clap::{Parser, Subcommand, Args}; +use clap::{Args, Parser, Subcommand}; use std::path::PathBuf; #[derive(Parser, Debug)] @@ -41,7 +41,7 @@ pub struct ShowCaptures { #[derive(Subcommand, Debug)] pub enum Ast { /// Dump the S expression for a given source file - DumpExpression (DumpExpression), + DumpExpression(DumpExpression), /// Show what parts of a source file gets captured by an S expression ShowCaptures(ShowCaptures), } @@ -52,7 +52,7 @@ pub enum Command { #[command(subcommand)] Ast(Ast), /// spawn a language server for use with a text editor - Lsp(Lsp) + Lsp(Lsp), } impl Lsp { diff --git a/src/lsp.rs b/src/lsp.rs index 7a31e64..d6e029a 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -60,12 +60,18 @@ impl LanguageServer for Backend { } async fn did_open(&self, params: DidOpenTextDocumentParams) { - self.body.lock().await.insert(params.text_document.uri, params.text_document.text); + self.body + .lock() + .await + .insert(params.text_document.uri, params.text_document.text); } async fn did_change(&self, params: DidChangeTextDocumentParams) { if let Some(body) = params.content_changes.into_iter().next() { - self.body.lock().await.insert(params.text_document.uri, body.text); + self.body + .lock() + .await + .insert(params.text_document.uri, body.text); } } diff --git a/src/main.rs b/src/main.rs index 87d1a47..d520c5e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,8 +13,8 @@ mod args; mod embed; mod lsp; mod mutation; -mod state; mod sources; +mod state; #[tokio::main] async fn main() -> Result<()> { @@ -26,17 +26,30 @@ async fn main() -> Result<()> { println!("{}", dump_expression(&source_file.path)?); } args::Ast::ShowCaptures(show_captures) => { - println!("{:?}", show_captures) + let source = std::fs::read_to_string(&show_captures.path).unwrap(); + let source_bytes = source.as_bytes(); + let extension = show_captures.path.extension().unwrap().to_str().unwrap(); + let langfn = state::Refactor::get_lang(extension).unwrap(); + let mut parser = tree_sitter::Parser::new(); + parser.set_language(&langfn).unwrap(); + let tree = parser.parse(source_bytes, None).unwrap(); + let root_node = tree.root_node(); + let cooked = mutation::query( + root_node, + &show_captures.expression, + &langfn, + source_bytes, + ); + println!("{:#?}", cooked); } - } return Ok(()); - }, + } args::Command::Lsp(lsp) => lsp, }; - + let (model_id, revision) = args.resolve_model_and_revision(); - + let embed = embed::Embed::new(args.gpu, &model_id, &revision)?; let mut dict = HashMap::default(); let dimensions = embed.hidden_size; @@ -80,7 +93,10 @@ async fn main() -> Result<()> { .or_insert_with(|| HNSWIndex::new(dimensions, &Default::default())); current_lang_index - .add(&embed.embed(&mutations.description)?, mutations_collection.len()) + .add( + &embed.embed(&mutations.description)?, + mutations_collection.len(), + ) .map_err(E::msg)?; mutations_collection.push(mutations); } diff --git a/src/mutation.rs b/src/mutation.rs index 6d91941..503584f 100644 --- a/src/mutation.rs +++ b/src/mutation.rs @@ -72,8 +72,6 @@ pub fn from_path>(path: P) -> Result { substitute.push(substitutor); } - let expression = format!("({expression}) @root"); - mutations.push(Mutation { expression, substitute, @@ -127,7 +125,7 @@ pub fn apply( } #[derive(Debug)] -struct QueryCooked { +pub struct QueryCooked { captures: HashMap, end: usize, start: usize, @@ -152,18 +150,20 @@ fn split_at_indices<'a>(c: &'a [u8], idx: &[usize]) -> SplitMap<'a> { SplitMap { values, indices } } -fn query<'a>( +pub fn query<'a>( node: Node<'a>, expr: &'a str, lang: &Language, source_bytes: &[u8], ) -> Vec { - let query = Query::new(lang, expr).unwrap(); + let expr = format!("({expr}) @root"); + let query = Query::new(lang, &expr).unwrap(); let mut qc = QueryCursor::new(); let mut query_matches = qc.matches(&query, node, source_bytes); let capture_names = query.capture_names(); + // println!("names: {capture_names:#?}"); let mut cooked = vec![]; @@ -171,19 +171,37 @@ fn query<'a>( let mut capture_cooked = HashMap::new(); let mut start = 0; let mut end = 0; - for cap in matcha.captures { - let Some(name) = capture_names.get(cap.index as usize) else { - continue; - }; - if *name == "root" { - start = cap.node.start_byte(); - end = cap.node.end_byte(); + if matcha.captures.is_empty() { + continue; + } + // println!("match {:#?}", matcha.id()); + + for (ix, name) in capture_names.iter().enumerate() { + let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap()); + let mut start_pos = None; + let mut end_pos = None; + // println!("matches for {name}"); + for node in nodes { + if start_pos.is_none() { + start_pos.replace(node.start_byte()); + } + end_pos.replace(node.end_byte()); + // println!("hit {node:#?}"); + } + if start_pos.or(end_pos).is_none() { continue; } - capture_cooked.insert( - name.to_string(), - cap.node.utf8_text(source_bytes).unwrap().to_string(), - ); + if *name == "root" { + start = start_pos.unwrap(); + end = end_pos.unwrap(); + continue; + } + let range = start_pos.unwrap()..end_pos.unwrap(); + // println!("match range for {name}: {:#?}", range); + let text_bytes = &source_bytes[range]; + let text = std::str::from_utf8(text_bytes).unwrap(); + // println!("text: {text}"); + capture_cooked.insert(name.to_string(), text.to_string()); } cooked.push(QueryCooked { start, diff --git a/src/sources.rs b/src/sources.rs index e6f782e..149f3eb 100644 --- a/src/sources.rs +++ b/src/sources.rs @@ -1,23 +1,32 @@ -use std::{fs, io, path::{Path, PathBuf}, collections::HashMap}; +use std::{ + collections::HashMap, + fs, io, + path::{Path, PathBuf}, +}; pub fn rule_files>(path: P) -> io::Result>> { let per_language_dirs: Vec<_> = fs::read_dir(path)? - .filter_map(|res| res.ok()) - .map(|direntry| direntry.path()) - .filter(|dir| dir.is_dir()).collect(); + .filter_map(|res| res.ok()) + .map(|direntry| direntry.path()) + .filter(|dir| dir.is_dir()) + .collect(); let mut basename_to_paths = HashMap::new(); for language_dir in per_language_dirs { - let Some(dirname) = language_dir.file_stem().and_then(|v|v.to_str()).map(|v| v.to_string()) else { + let Some(dirname) = language_dir + .file_stem() + .and_then(|v| v.to_str()) + .map(|v| v.to_string()) + else { continue; }; let rule_file_paths: Vec<_> = fs::read_dir(&language_dir)? - .filter_map(|res| res.ok()) - .map(|entry| entry.path()) - .filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl")) - .map(|path| path.to_path_buf()) - .collect(); + .filter_map(|res| res.ok()) + .map(|entry| entry.path()) + .filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl")) + .map(|path| path.to_path_buf()) + .collect(); basename_to_paths.insert(dirname, rule_file_paths); } Ok(basename_to_paths) diff --git a/src/state.rs b/src/state.rs index e5dad70..8f6b822 100644 --- a/src/state.rs +++ b/src/state.rs @@ -4,8 +4,8 @@ use derive_more::Error; use hora::core::ann_index::ANNIndex; use hora::index::hnsw_idx::HNSWIndex; use std::collections::HashMap; -use tree_sitter::Parser; use std::path::Path; +use tree_sitter::Parser; #[derive(Debug, Display, Error)] pub enum Error { @@ -23,7 +23,7 @@ pub struct Refactor { } impl Refactor { - fn get_lang(s: &str) -> Result { + pub fn get_lang(s: &str) -> Result { Ok(match s { "go" => tree_sitter_go::LANGUAGE, "c" | "h" => tree_sitter_c::LANGUAGE,