1 Commits

Author SHA1 Message Date
Himadri Bhattacharjee
fcf98006c4 feat: parse refactor document using knuffel 2025-09-02 14:03:16 +05:30
7 changed files with 775 additions and 641 deletions

1191
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -16,12 +16,13 @@ kdl = "6.3.4"
serde_json = "1.0.140" serde_json = "1.0.140"
tokenizers = "0.21.4" tokenizers = "0.21.4"
tracing = "0.1.41" tracing = "0.1.41"
tracing-subscriber = "0.3.20" tracing-subscriber = "0.3.19"
tree-sitter = "0.25.8" tree-sitter = "0.25.8"
tree-sitter-c = "0.24.1" tree-sitter-c = "0.24.1"
tree-sitter-go = "0.23.4" tree-sitter-go = "0.23.4"
tree-sitter-rust = "0.24.0" tree-sitter-rust = "0.24.0"
tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] } tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] }
tower-lsp = "0.20.0" tower-lsp = "0.20.0"
tree-sitter-javascript = "0.25.0" tree-sitter-javascript = "0.23.1"
tree-sitter-cpp = "0.23.4" tree-sitter-cpp = "0.23.4"
knuffel = "3.2.0"

View File

@@ -13,7 +13,6 @@
packages = with pkgs; [ packages = with pkgs; [
stdenv.cc.cc stdenv.cc.cc
pkg-config pkg-config
bacon
]; ];
libraries = with pkgs; [ libraries = with pkgs; [

View File

@@ -38,21 +38,12 @@ pub struct ShowCaptures {
pub expression: String, pub expression: String,
} }
#[derive(Args, Debug)]
pub struct DryRun {
pub path: PathBuf,
pub edit_file: PathBuf,
}
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
pub enum Ast { pub enum Ast {
/// Dump the S expression for a given source file /// Dump the S expression for a given source file
DumpExpression(DumpExpression), DumpExpression(DumpExpression),
/// Show what parts of a source file gets captured by an S expression /// Show what parts of a source file gets captured by an S expression
ShowCaptures(ShowCaptures), ShowCaptures(ShowCaptures),
/// Test your edit snippets on a sample file
DryRun(DryRun),
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]

View File

@@ -26,28 +26,22 @@ async fn main() -> Result<()> {
println!("{}", dump_expression(&source_file.path)?); println!("{}", dump_expression(&source_file.path)?);
} }
args::Ast::ShowCaptures(show_captures) => { args::Ast::ShowCaptures(show_captures) => {
let source_bytes = std::fs::read(&show_captures.path)?; let source = std::fs::read_to_string(&show_captures.path).unwrap();
let langfn = state::lang_from_file_extension(&show_captures.path)?; let source_bytes = source.as_bytes();
let tree = state::parse_into_tree(&source_bytes, &langfn)?; let extension = show_captures.path.extension().unwrap().to_str().unwrap();
let langfn = state::Refactor::get_lang(extension).unwrap();
let mut parser = tree_sitter::Parser::new();
parser.set_language(&langfn).unwrap();
let tree = parser.parse(source_bytes, None).unwrap();
let root_node = tree.root_node(); let root_node = tree.root_node();
let cooked = mutation::query( let cooked = mutation::query(
root_node, root_node,
&show_captures.expression, &show_captures.expression,
&langfn, &langfn,
&source_bytes, source_bytes,
); );
println!("{:#?}", cooked); println!("{:#?}", cooked);
} }
args::Ast::DryRun(dry_run) => {
let mutation_collection = mutation::from_path(dry_run.edit_file)?;
let source_bytes = std::fs::read(&dry_run.path)?;
let langfn = state::lang_from_file_extension(&dry_run.path)?;
let tree = state::parse_into_tree(&source_bytes, &langfn)?;
let root_node = tree.root_node();
let cooked =
mutation::apply(langfn, &source_bytes, root_node, &mutation_collection)?;
println!("{cooked}");
}
} }
return Ok(()); return Ok(());
} }

View File

@@ -2,90 +2,34 @@ use std::collections::HashMap;
use std::path::Path; use std::path::Path;
use tracing::debug; use tracing::debug;
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator}; use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator};
use anyhow::Result;
use anyhow::{Result, bail}; #[derive(knuffel::Decode, Debug)]
use kdl::KdlDocument;
#[derive(Debug)]
pub struct Mutation {
pub expression: String,
pub substitute: Vec<Substitute>,
}
pub struct MutationCollection { pub struct MutationCollection {
#[knuffel(child, unwrap(argument))]
pub description: String, pub description: String,
pub mutations: Vec<Mutation>, #[knuffel(children)]
mutations: Vec<Mutation>,
} }
#[derive(Debug)] #[derive(knuffel::Decode, Debug)]
pub enum Substitute { struct Mutation {
Literal(String), #[knuffel(child, unwrap(argument))]
Capture(String), expression: String,
#[knuffel(child, unwrap(children))]
substitute: Vec<Substitute>,
}
#[derive(knuffel::Decode, Debug)]
enum Substitute {
Literal(#[knuffel(argument)] String),
Capture(#[knuffel(argument)] String),
} }
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> { pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
let contents = std::fs::read_to_string(path)?; let contents = std::fs::read_to_string(&path)?;
let doc: KdlDocument = contents.parse()?; let val = knuffel::parse(path.as_ref().to_str().unwrap(), &contents)?;
let mut mutations = vec![]; Ok(val)
let mut description = None;
for node in doc.nodes() {
let node_name = node.name().value();
if node_name != "mutation" && node_name != "description" {
bail!(
"document root must only contain `mutation` or `description` nodes: got {node_name}"
);
}
if node_name == "description" {
description.replace(
node.entry(0)
.unwrap()
.value()
.as_string()
.unwrap()
.to_string(),
);
continue;
}
let node = node.children().unwrap();
let Some(expression) = node.get_arg("expression").and_then(|v| v.as_string()) else {
bail!("mutation node must contain an expression");
};
let Some(substitute) = node.get("substitute") else {
bail!("mutation node must contain an substitute");
};
let children = substitute.children().unwrap().nodes();
let mut substitute = vec![];
for child in children {
let attrib = child.entry(0).unwrap().value().as_string().unwrap();
let substitutor = match child.name().value() {
"literal" => Substitute::Literal(attrib.to_string()),
"capture" => Substitute::Capture(attrib.to_string()),
_ => unreachable!(),
};
substitute.push(substitutor);
}
mutations.push(Mutation {
expression: expression.to_string(),
substitute,
})
}
let Some(description) = description else {
bail!("mutation collection contains no `description`");
};
Ok(MutationCollection {
description,
mutations,
})
} }
pub fn apply( pub fn apply(
@@ -179,23 +123,24 @@ pub fn query<'a>(
let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap()); let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap());
let mut start_pos = None; let mut start_pos = None;
let mut end_pos = None; let mut end_pos = None;
debug!("matches for {name}"); // println!("matches for {name}");
for node in nodes { for node in nodes {
start_pos.get_or_insert(node.start_byte()); if start_pos.is_none() {
start_pos.replace(node.start_byte());
}
end_pos.replace(node.end_byte()); end_pos.replace(node.end_byte());
debug!("hit {node:#?}"); // println!("hit {node:#?}");
} }
if start_pos.or(end_pos).is_none() {
let (Some(start_pos), Some(end_pos)) = (start_pos, end_pos) else {
continue; continue;
};
if *name == "root" {
start = start_pos;
end = end_pos;
} }
if *name == "root" {
let text_bytes = &source_bytes[start_pos..end_pos]; start = start_pos.unwrap();
end = end_pos.unwrap();
}
let range = start_pos.unwrap()..end_pos.unwrap();
// println!("match range for {name}: {:#?}", range);
let text_bytes = &source_bytes[range];
let text = std::str::from_utf8(text_bytes).unwrap(); let text = std::str::from_utf8(text_bytes).unwrap();
// println!("text: {text}"); // println!("text: {text}");
capture_cooked.insert(name.to_string(), text.to_string()); capture_cooked.insert(name.to_string(), text.to_string());

View File

@@ -23,6 +23,18 @@ pub struct Refactor {
} }
impl Refactor { impl Refactor {
pub fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
Ok(match s {
"go" => tree_sitter_go::LANGUAGE,
"c" | "h" => tree_sitter_c::LANGUAGE,
"cpp" | "hpp" => tree_sitter_cpp::LANGUAGE,
"js" | "ts" => tree_sitter_javascript::LANGUAGE,
"rs" => tree_sitter_rust::LANGUAGE,
_ => return Err(Error::UnknownLang),
}
.into())
}
pub fn search( pub fn search(
&self, &self,
lang: &str, lang: &str,
@@ -30,9 +42,17 @@ impl Refactor {
body: &str, body: &str,
top_k: usize, top_k: usize,
) -> Result<Vec<String>, Error> { ) -> Result<Vec<String>, Error> {
let langfn = lang_from_name(lang)?; let langfn = Self::get_lang(lang)?;
let source_bytes = body.as_bytes(); let mut parser = Parser::new();
let tree = parse_into_tree(source_bytes, &langfn)?; parser
.set_language(&langfn)
.map_err(|_| Error::UnknownLang)?;
let source_code = body;
let source_bytes = source_code.as_bytes();
let tree = parser
.parse(source_code, None)
.ok_or(Error::SnippetParsing)?;
let root_node = tree.root_node(); let root_node = tree.root_node();
// search for k nearest neighbors // search for k nearest neighbors
@@ -63,45 +83,24 @@ impl Refactor {
} }
} }
pub fn lang_from_name(s: &str) -> Result<tree_sitter::Language, Error> { pub fn dump_expression(path: &Path) -> Result<String, Error> {
Ok(match s {
"go" => tree_sitter_go::LANGUAGE,
"c" | "h" => tree_sitter_c::LANGUAGE,
"cpp" | "hpp" => tree_sitter_cpp::LANGUAGE,
"js" | "ts" => tree_sitter_javascript::LANGUAGE,
"rs" => tree_sitter_rust::LANGUAGE,
_ => return Err(Error::UnknownLang),
}
.into())
}
pub fn lang_from_file_extension(path: &Path) -> Result<tree_sitter::Language, Error> {
let Some(lang) = path.extension() else { let Some(lang) = path.extension() else {
return Err(Error::UnknownLang); return Err(Error::UnknownLang);
}; };
let lang = lang.to_str().ok_or(Error::UnknownLang)?; let lang = lang.to_str().ok_or(Error::UnknownLang)?;
lang_from_name(lang) let langfn = Refactor::get_lang(lang)?;
}
// parses `body` written in the language `langfn` into tree sitter AST
pub fn parse_into_tree(
body: &[u8],
langfn: &tree_sitter::Language,
) -> Result<tree_sitter::Tree, Error> {
let mut parser = Parser::new(); let mut parser = Parser::new();
parser parser
.set_language(langfn) .set_language(&langfn)
.map_err(|_| Error::UnknownLang)?; .map_err(|_| Error::UnknownLang)?;
let tree = parser.parse(body, None).ok_or(Error::SnippetParsing)?;
Ok(tree)
}
pub fn dump_expression(path: &Path) -> Result<String, Error> { let source_code = std::fs::read_to_string(path).map_err(|_| Error::SnippetParsing)?;
let source_bytes = std::fs::read(path).map_err(|_| Error::SnippetParsing)?; let source_bytes = source_code.as_bytes();
let tree = parser
let tree = parse_into_tree(&source_bytes, &lang_from_file_extension(path)?)?; .parse(source_bytes, None)
.ok_or(Error::SnippetParsing)?;
Ok(tree.root_node().to_sexp().to_string()) let root_node = tree.root_node();
Ok(root_node.to_sexp().to_string())
} }
pub struct Generate { pub struct Generate {