1 Commit

Author SHA1 Message Date
Himadri Bhattacharjee
fcf98006c4 feat: parse refactor document using knuffel 2025-09-02 14:03:16 +05:30
7 changed files with 775 additions and 641 deletions

1191
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -16,12 +16,13 @@ kdl = "6.3.4"
serde_json = "1.0.140"
tokenizers = "0.21.4"
tracing = "0.1.41"
tracing-subscriber = "0.3.20"
tracing-subscriber = "0.3.19"
tree-sitter = "0.25.8"
tree-sitter-c = "0.24.1"
tree-sitter-go = "0.23.4"
tree-sitter-rust = "0.24.0"
tokio = { version = "1.45.1", features = ["io-std", "macros", "rt", "rt-multi-thread"] }
tower-lsp = "0.20.0"
tree-sitter-javascript = "0.25.0"
tree-sitter-javascript = "0.23.1"
tree-sitter-cpp = "0.23.4"
knuffel = "3.2.0"

View File

@@ -13,7 +13,6 @@
packages = with pkgs; [
stdenv.cc.cc
pkg-config
bacon
];
libraries = with pkgs; [

View File

@@ -38,21 +38,12 @@ pub struct ShowCaptures {
pub expression: String,
}
// Arguments for the `DryRun` subcommand of `Ast`.
// NOTE(review): plain `//` comments are used on purpose; if `Args` is clap's
// derive, `///` comments here would become CLI help text — confirm before
// converting them.
#[derive(Args, Debug)]
pub struct DryRun {
// Sample source file that is read and parsed; its extension selects the language.
pub path: PathBuf,
// File from which the mutation collection is loaded (`mutation::from_path`).
pub edit_file: PathBuf,
}
// Subcommands that operate on a source file's syntax tree: dumping it,
// showing what a query expression captures, and trying out an edit file.
// NOTE(review): the `///` lines on the variants are presumably surfaced as
// CLI help text by the `Subcommand` derive, so their wording is left untouched.
#[derive(Subcommand, Debug)]
pub enum Ast {
/// Dump the S expression for a given source file
DumpExpression(DumpExpression),
/// Show what parts of a source file gets captured by an S expression
ShowCaptures(ShowCaptures),
/// Test your edit snippets on a sample file
DryRun(DryRun),
}
#[derive(Subcommand, Debug)]

View File

@@ -26,28 +26,22 @@ async fn main() -> Result<()> {
println!("{}", dump_expression(&source_file.path)?);
}
args::Ast::ShowCaptures(show_captures) => {
let source_bytes = std::fs::read(&show_captures.path)?;
let langfn = state::lang_from_file_extension(&show_captures.path)?;
let tree = state::parse_into_tree(&source_bytes, &langfn)?;
let source = std::fs::read_to_string(&show_captures.path).unwrap();
let source_bytes = source.as_bytes();
let extension = show_captures.path.extension().unwrap().to_str().unwrap();
let langfn = state::Refactor::get_lang(extension).unwrap();
let mut parser = tree_sitter::Parser::new();
parser.set_language(&langfn).unwrap();
let tree = parser.parse(source_bytes, None).unwrap();
let root_node = tree.root_node();
let cooked = mutation::query(
root_node,
&show_captures.expression,
&langfn,
&source_bytes,
source_bytes,
);
println!("{:#?}", cooked);
}
args::Ast::DryRun(dry_run) => {
let mutation_collection = mutation::from_path(dry_run.edit_file)?;
let source_bytes = std::fs::read(&dry_run.path)?;
let langfn = state::lang_from_file_extension(&dry_run.path)?;
let tree = state::parse_into_tree(&source_bytes, &langfn)?;
let root_node = tree.root_node();
let cooked =
mutation::apply(langfn, &source_bytes, root_node, &mutation_collection)?;
println!("{cooked}");
}
}
return Ok(());
}

View File

@@ -2,90 +2,34 @@ use std::collections::HashMap;
use std::path::Path;
use tracing::debug;
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator};
use anyhow::Result;
use anyhow::{Result, bail};
use kdl::KdlDocument;
#[derive(Debug)]
pub struct Mutation {
pub expression: String,
pub substitute: Vec<Substitute>,
}
#[derive(knuffel::Decode, Debug)]
pub struct MutationCollection {
#[knuffel(child, unwrap(argument))]
pub description: String,
pub mutations: Vec<Mutation>,
#[knuffel(children)]
mutations: Vec<Mutation>,
}
#[derive(Debug)]
pub enum Substitute {
Literal(String),
Capture(String),
#[derive(knuffel::Decode, Debug)]
struct Mutation {
#[knuffel(child, unwrap(argument))]
expression: String,
#[knuffel(child, unwrap(children))]
substitute: Vec<Substitute>,
}
#[derive(knuffel::Decode, Debug)]
enum Substitute {
Literal(#[knuffel(argument)] String),
Capture(#[knuffel(argument)] String),
}
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
let contents = std::fs::read_to_string(path)?;
let doc: KdlDocument = contents.parse()?;
let mut mutations = vec![];
let mut description = None;
for node in doc.nodes() {
let node_name = node.name().value();
if node_name != "mutation" && node_name != "description" {
bail!(
"document root must only contain `mutation` or `description` nodes: got {node_name}"
);
}
if node_name == "description" {
description.replace(
node.entry(0)
.unwrap()
.value()
.as_string()
.unwrap()
.to_string(),
);
continue;
}
let node = node.children().unwrap();
let Some(expression) = node.get_arg("expression").and_then(|v| v.as_string()) else {
bail!("mutation node must contain an expression");
};
let Some(substitute) = node.get("substitute") else {
bail!("mutation node must contain an substitute");
};
let children = substitute.children().unwrap().nodes();
let mut substitute = vec![];
for child in children {
let attrib = child.entry(0).unwrap().value().as_string().unwrap();
let substitutor = match child.name().value() {
"literal" => Substitute::Literal(attrib.to_string()),
"capture" => Substitute::Capture(attrib.to_string()),
_ => unreachable!(),
};
substitute.push(substitutor);
}
mutations.push(Mutation {
expression: expression.to_string(),
substitute,
})
}
let Some(description) = description else {
bail!("mutation collection contains no `description`");
};
Ok(MutationCollection {
description,
mutations,
})
let contents = std::fs::read_to_string(&path)?;
let val = knuffel::parse(path.as_ref().to_str().unwrap(), &contents)?;
Ok(val)
}
pub fn apply(
@@ -179,23 +123,24 @@ pub fn query<'a>(
let nodes = matcha.nodes_for_capture_index(ix.try_into().unwrap());
let mut start_pos = None;
let mut end_pos = None;
debug!("matches for {name}");
// println!("matches for {name}");
for node in nodes {
start_pos.get_or_insert(node.start_byte());
if start_pos.is_none() {
start_pos.replace(node.start_byte());
}
end_pos.replace(node.end_byte());
debug!("hit {node:#?}");
// println!("hit {node:#?}");
}
let (Some(start_pos), Some(end_pos)) = (start_pos, end_pos) else {
if start_pos.or(end_pos).is_none() {
continue;
};
if *name == "root" {
start = start_pos;
end = end_pos;
}
let text_bytes = &source_bytes[start_pos..end_pos];
if *name == "root" {
start = start_pos.unwrap();
end = end_pos.unwrap();
}
let range = start_pos.unwrap()..end_pos.unwrap();
// println!("match range for {name}: {:#?}", range);
let text_bytes = &source_bytes[range];
let text = std::str::from_utf8(text_bytes).unwrap();
// println!("text: {text}");
capture_cooked.insert(name.to_string(), text.to_string());

View File

@@ -23,6 +23,18 @@ pub struct Refactor {
}
impl Refactor {
/// Resolves a language key (a name or file extension) to its tree-sitter grammar.
///
/// Recognised keys are `go`, `c`/`h`, `cpp`/`hpp`, `js`/`ts` and `rs`.
/// Note that `ts` is served by the JavaScript grammar.
///
/// # Errors
///
/// Returns `Error::UnknownLang` for any key not listed above.
pub fn get_lang(s: &str) -> Result<tree_sitter::Language, Error> {
    // Pick the grammar first, then convert once at the end.
    let grammar = match s {
        "go" => tree_sitter_go::LANGUAGE,
        "c" | "h" => tree_sitter_c::LANGUAGE,
        "cpp" | "hpp" => tree_sitter_cpp::LANGUAGE,
        "js" | "ts" => tree_sitter_javascript::LANGUAGE,
        "rs" => tree_sitter_rust::LANGUAGE,
        _ => return Err(Error::UnknownLang),
    };
    Ok(grammar.into())
}
pub fn search(
&self,
lang: &str,
@@ -30,9 +42,17 @@ impl Refactor {
body: &str,
top_k: usize,
) -> Result<Vec<String>, Error> {
let langfn = lang_from_name(lang)?;
let source_bytes = body.as_bytes();
let tree = parse_into_tree(source_bytes, &langfn)?;
let langfn = Self::get_lang(lang)?;
let mut parser = Parser::new();
parser
.set_language(&langfn)
.map_err(|_| Error::UnknownLang)?;
let source_code = body;
let source_bytes = source_code.as_bytes();
let tree = parser
.parse(source_code, None)
.ok_or(Error::SnippetParsing)?;
let root_node = tree.root_node();
// search for k nearest neighbors
@@ -63,45 +83,24 @@ impl Refactor {
}
}
/// Looks up the tree-sitter grammar for a language key.
///
/// Accepted keys: `go`, `c`/`h`, `cpp`/`hpp`, `js`/`ts` (both handled by the
/// JavaScript grammar) and `rs`.
///
/// # Errors
///
/// Any other key yields `Error::UnknownLang`.
pub fn lang_from_name(s: &str) -> Result<tree_sitter::Language, Error> {
    let language_fn = match s {
        "go" => tree_sitter_go::LANGUAGE,
        "c" | "h" => tree_sitter_c::LANGUAGE,
        "cpp" | "hpp" => tree_sitter_cpp::LANGUAGE,
        "js" | "ts" => tree_sitter_javascript::LANGUAGE,
        "rs" => tree_sitter_rust::LANGUAGE,
        // Unsupported key: bail out before any conversion happens.
        _ => return Err(Error::UnknownLang),
    };
    Ok(language_fn.into())
}
pub fn lang_from_file_extension(path: &Path) -> Result<tree_sitter::Language, Error> {
pub fn dump_expression(path: &Path) -> Result<String, Error> {
let Some(lang) = path.extension() else {
return Err(Error::UnknownLang);
};
let lang = lang.to_str().ok_or(Error::UnknownLang)?;
lang_from_name(lang)
}
// parses `body` written in the language `langfn` into tree sitter AST
pub fn parse_into_tree(
body: &[u8],
langfn: &tree_sitter::Language,
) -> Result<tree_sitter::Tree, Error> {
let langfn = Refactor::get_lang(lang)?;
let mut parser = Parser::new();
parser
.set_language(langfn)
.set_language(&langfn)
.map_err(|_| Error::UnknownLang)?;
let tree = parser.parse(body, None).ok_or(Error::SnippetParsing)?;
Ok(tree)
}
pub fn dump_expression(path: &Path) -> Result<String, Error> {
let source_bytes = std::fs::read(path).map_err(|_| Error::SnippetParsing)?;
let tree = parse_into_tree(&source_bytes, &lang_from_file_extension(path)?)?;
Ok(tree.root_node().to_sexp().to_string())
let source_code = std::fs::read_to_string(path).map_err(|_| Error::SnippetParsing)?;
let source_bytes = source_code.as_bytes();
let tree = parser
.parse(source_bytes, None)
.ok_or(Error::SnippetParsing)?;
let root_node = tree.root_node();
Ok(root_node.to_sexp().to_string())
}
pub struct Generate {