feat: parse capture groups with + or * wildcards

This commit is contained in:
Himadri Bhattacharjee
2025-08-26 10:41:54 +05:30
parent e5602c688c
commit 7b0f818d38
8 changed files with 109 additions and 42 deletions

View File

@@ -1,6 +1,11 @@
[language-server.silos]
command = "silos"
command = "./target/debug/silos"
args = ["lsp"]
[[language]]
name = "go"
language-servers = [ { name = "silos" } ]
language-servers = [ { name = "silos" }, "gopls" ]
[[language]]
name = "rust"
language-servers = [ ]

View File

@@ -0,0 +1,13 @@
description "base64 import"
mutation {
expression "import_spec_list ((import_spec)* @imports)"
substitute {
literal "("
literal "\n"
capture "imports"
literal "\n"
literal #""base64""#
literal "\n"
literal ")"
}
}

View File

@@ -1,4 +1,4 @@
use clap::{Parser, Subcommand, Args};
use clap::{Args, Parser, Subcommand};
use std::path::PathBuf;
#[derive(Parser, Debug)]
@@ -41,7 +41,7 @@ pub struct ShowCaptures {
#[derive(Subcommand, Debug)]
pub enum Ast {
/// Dump the S expression for a given source file
DumpExpression (DumpExpression),
DumpExpression(DumpExpression),
/// Show what parts of a source file gets captured by an S expression
ShowCaptures(ShowCaptures),
}
@@ -52,7 +52,7 @@ pub enum Command {
#[command(subcommand)]
Ast(Ast),
/// spawn a language server for use with a text editor
Lsp(Lsp)
Lsp(Lsp),
}
impl Lsp {

View File

@@ -60,12 +60,18 @@ impl LanguageServer for Backend {
}
async fn did_open(&self, params: DidOpenTextDocumentParams) {
self.body.lock().await.insert(params.text_document.uri, params.text_document.text);
self.body
.lock()
.await
.insert(params.text_document.uri, params.text_document.text);
}
async fn did_change(&self, params: DidChangeTextDocumentParams) {
if let Some(body) = params.content_changes.into_iter().next() {
self.body.lock().await.insert(params.text_document.uri, body.text);
self.body
.lock()
.await
.insert(params.text_document.uri, body.text);
}
}

View File

@@ -13,8 +13,8 @@ mod args;
mod embed;
mod lsp;
mod mutation;
mod state;
mod sources;
mod state;
#[tokio::main]
async fn main() -> Result<()> {
@@ -26,17 +26,30 @@ async fn main() -> Result<()> {
println!("{}", dump_expression(&source_file.path)?);
}
args::Ast::ShowCaptures(show_captures) => {
println!("{:?}", show_captures)
let source = std::fs::read_to_string(&show_captures.path).unwrap();
let source_bytes = source.as_bytes();
let extension = show_captures.path.extension().unwrap().to_str().unwrap();
let langfn = state::Refactor::get_lang(extension).unwrap();
let mut parser = tree_sitter::Parser::new();
parser.set_language(&langfn).unwrap();
let tree = parser.parse(source_bytes, None).unwrap();
let root_node = tree.root_node();
let cooked = mutation::query(
root_node,
&show_captures.expression,
&langfn,
source_bytes,
);
println!("{:#?}", cooked);
}
}
return Ok(());
},
}
args::Command::Lsp(lsp) => lsp,
};
let (model_id, revision) = args.resolve_model_and_revision();
let embed = embed::Embed::new(args.gpu, &model_id, &revision)?;
let mut dict = HashMap::default();
let dimensions = embed.hidden_size;
@@ -80,7 +93,10 @@ async fn main() -> Result<()> {
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
current_lang_index
.add(&embed.embed(&mutations.description)?, mutations_collection.len())
.add(
&embed.embed(&mutations.description)?,
mutations_collection.len(),
)
.map_err(E::msg)?;
mutations_collection.push(mutations);
}

View File

@@ -72,8 +72,6 @@ pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
substitute.push(substitutor);
}
let expression = format!("({expression}) @root");
mutations.push(Mutation {
expression,
substitute,
@@ -127,7 +125,7 @@ pub fn apply(
}
#[derive(Debug)]
struct QueryCooked {
pub struct QueryCooked {
captures: HashMap<String, String>,
end: usize,
start: usize,
@@ -152,18 +150,20 @@ fn split_at_indices<'a>(c: &'a [u8], idx: &[usize]) -> SplitMap<'a> {
SplitMap { values, indices }
}
fn query<'a>(
pub fn query<'a>(
node: Node<'a>,
expr: &'a str,
lang: &Language,
source_bytes: &[u8],
) -> Vec<QueryCooked> {
let query = Query::new(lang, expr).unwrap();
let expr = format!("({expr}) @root");
let query = Query::new(lang, &expr).unwrap();
let mut qc = QueryCursor::new();
let mut query_matches = qc.matches(&query, node, source_bytes);
let capture_names = query.capture_names();
// println!("names: {capture_names:#?}");
let mut cooked = vec![];
@@ -171,19 +171,37 @@ fn query<'a>(
let mut capture_cooked = HashMap::new();
let mut start = 0;
let mut end = 0;
for cap in matcha.captures {
let Some(name) = capture_names.get(cap.index as usize) else {
continue;
};
if *name == "root" {
start = cap.node.start_byte();
end = cap.node.end_byte();
if matcha.captures.is_empty() {
continue;
}
// println!("match {:#?}", matcha.id());
for (ix, name) in capture_names.iter().enumerate() {
let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap());
let mut start_pos = None;
let mut end_pos = None;
// println!("matches for {name}");
for node in nodes {
if start_pos.is_none() {
start_pos.replace(node.start_byte());
}
end_pos.replace(node.end_byte());
// println!("hit {node:#?}");
}
if start_pos.or(end_pos).is_none() {
continue;
}
capture_cooked.insert(
name.to_string(),
cap.node.utf8_text(source_bytes).unwrap().to_string(),
);
if *name == "root" {
start = start_pos.unwrap();
end = end_pos.unwrap();
continue;
}
let range = start_pos.unwrap()..end_pos.unwrap();
// println!("match range for {name}: {:#?}", range);
let text_bytes = &source_bytes[range];
let text = std::str::from_utf8(text_bytes).unwrap();
// println!("text: {text}");
capture_cooked.insert(name.to_string(), text.to_string());
}
cooked.push(QueryCooked {
start,

View File

@@ -1,23 +1,32 @@
use std::{fs, io, path::{Path, PathBuf}, collections::HashMap};
use std::{
collections::HashMap,
fs, io,
path::{Path, PathBuf},
};
pub fn rule_files<P: AsRef<Path>>(path: P) -> io::Result<HashMap<String, Vec<PathBuf>>> {
let per_language_dirs: Vec<_> = fs::read_dir(path)?
.filter_map(|res| res.ok())
.map(|direntry| direntry.path())
.filter(|dir| dir.is_dir()).collect();
.filter_map(|res| res.ok())
.map(|direntry| direntry.path())
.filter(|dir| dir.is_dir())
.collect();
let mut basename_to_paths = HashMap::new();
for language_dir in per_language_dirs {
let Some(dirname) = language_dir.file_stem().and_then(|v|v.to_str()).map(|v| v.to_string()) else {
let Some(dirname) = language_dir
.file_stem()
.and_then(|v| v.to_str())
.map(|v| v.to_string())
else {
continue;
};
let rule_file_paths: Vec<_> = fs::read_dir(&language_dir)?
.filter_map(|res| res.ok())
.map(|entry| entry.path())
.filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl"))
.map(|path| path.to_path_buf())
.collect();
.filter_map(|res| res.ok())
.map(|entry| entry.path())
.filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl"))
.map(|path| path.to_path_buf())
.collect();
basename_to_paths.insert(dirname, rule_file_paths);
}
Ok(basename_to_paths)

View File

@@ -4,8 +4,8 @@ use derive_more::Error;
use hora::core::ann_index::ANNIndex;
use hora::index::hnsw_idx::HNSWIndex;
use std::collections::HashMap;
use tree_sitter::Parser;
use std::path::Path;
use tree_sitter::Parser;
#[derive(Debug, Display, Error)]
pub enum Error {
@@ -23,7 +23,7 @@ pub struct Refactor {
}
impl Refactor {
fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
pub fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
Ok(match s {
"go" => tree_sitter_go::LANGUAGE,
"c" | "h" => tree_sitter_c::LANGUAGE,