feat: scaffolding for v2 mutation api
This commit is contained in:
107
Cargo.lock
generated
107
Cargo.lock
generated
@@ -569,7 +569,7 @@ dependencies = [
|
||||
"encode_unicode",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"unicode-width",
|
||||
"unicode-width 0.2.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -1707,7 +1707,7 @@ dependencies = [
|
||||
"console",
|
||||
"number_prefix",
|
||||
"portable-atomic",
|
||||
"unicode-width",
|
||||
"unicode-width 0.2.0",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
@@ -1777,6 +1777,18 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kdl"
|
||||
version = "6.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12661358400b02cbbf1fbd05f0a483335490e8a6bd1867620f2eeb78f304a22f"
|
||||
dependencies = [
|
||||
"miette",
|
||||
"num",
|
||||
"thiserror 1.0.69",
|
||||
"winnow 0.6.24",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "language-tags"
|
||||
version = "0.3.2"
|
||||
@@ -1898,6 +1910,28 @@ dependencies = [
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miette"
|
||||
version = "7.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f98efec8807c63c752b5bd61f862c165c115b0a35685bdcfd9238c7aeb592b7"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"miette-derive",
|
||||
"unicode-width 0.1.14",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miette-derive"
|
||||
version = "7.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db5b29714e950dbb20d5e6f74f9dcec4edbcc1067bb7f8ed198c097b8c1a818b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
@@ -2725,6 +2759,7 @@ version = "1.0.140"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
@@ -2792,9 +2827,14 @@ dependencies = [
|
||||
"glob",
|
||||
"hf-hub",
|
||||
"hora",
|
||||
"kdl",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokenizers",
|
||||
"tree-sitter",
|
||||
"tree-sitter-go",
|
||||
"tree-sitter-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2851,6 +2891,12 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||
|
||||
[[package]]
|
||||
name = "streaming-iterator"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
@@ -3145,7 +3191,7 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"toml_datetime",
|
||||
"winnow",
|
||||
"winnow 0.7.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3225,6 +3271,46 @@ dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter"
|
||||
version = "0.25.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"regex",
|
||||
"regex-syntax",
|
||||
"serde_json",
|
||||
"streaming-iterator",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-go"
|
||||
version = "0.23.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b13d476345220dbe600147dd444165c5791bf85ef53e28acbedd46112ee18431"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-language"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4013970217383f67b18aef68f6fb2e8d409bc5755227092d32efb0422ba24b8"
|
||||
|
||||
[[package]]
|
||||
name = "tree-sitter-rust"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b9b18034c684a2420722be8b2a91c9c44f2546b631c039edf575ccba8c61be1"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"tree-sitter-language",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "try-lock"
|
||||
version = "0.2.5"
|
||||
@@ -3279,6 +3365,12 @@ version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.1.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.0"
|
||||
@@ -3792,6 +3884,15 @@ version = "0.53.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.6.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.7.10"
|
||||
|
||||
@@ -15,6 +15,10 @@ glob = "0.3.2"
|
||||
hf-hub = "0.4.2"
|
||||
hora = "0.1.1"
|
||||
kdl = "6.3.4"
|
||||
regex = "1.11.1"
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.140"
|
||||
tokenizers = "0.21.1"
|
||||
tree-sitter = "0.25.6"
|
||||
tree-sitter-go = "0.23.4"
|
||||
tree-sitter-rust = "0.24.0"
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
desc "display all the path entries"
|
||||
body """printf "%s\n" $PATH"""
|
||||
body #"printf "%s\n" $PATH"#
|
||||
|
||||
|
||||
25
snippets/v2/go/mutations.kdl
Normal file
25
snippets/v2/go/mutations.kdl
Normal file
@@ -0,0 +1,25 @@
|
||||
description "filepath base to parent's base"
|
||||
mutation {
|
||||
expression """
|
||||
(call_expression
|
||||
function: (_) @func (#eq? @func "filepath.Base")
|
||||
arguments: (_) @args
|
||||
) @root
|
||||
"""
|
||||
substitute {
|
||||
literal "filepath.Base(filepath.Dir(filepath.Clean"
|
||||
capture "args"
|
||||
literal "))"
|
||||
}
|
||||
}
|
||||
|
||||
mutation {
|
||||
expression """
|
||||
((interpreted_string_literal_content) @str
|
||||
(#eq? @str "/home/h/signal/softiee")
|
||||
) @root
|
||||
"""
|
||||
substitute {
|
||||
literal "/home/softiee/signal/h"
|
||||
}
|
||||
}
|
||||
31
src/main.rs
31
src/main.rs
@@ -7,7 +7,7 @@ use kdl::KdlDocument;
|
||||
use std::collections::HashMap;
|
||||
mod embed;
|
||||
mod v1;
|
||||
// mod v2;
|
||||
mod v2;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@@ -69,8 +69,8 @@ async fn main() -> Result<()> {
|
||||
HNSWIndex::<f32, String>::new(dimension, ¶ms)
|
||||
});
|
||||
|
||||
let doc_str = std::fs::read_to_string(path)?;
|
||||
let doc: KdlDocument = doc_str.parse().context("failed to parse KDL")?;
|
||||
let doc_str = std::fs::read_to_string(&path)?;
|
||||
let doc: KdlDocument = doc_str.parse().context(format!("failed to parse KDL: {}", path.display()))?;
|
||||
|
||||
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
@@ -88,7 +88,29 @@ async fn main() -> Result<()> {
|
||||
.build(hora::core::metrics::Metric::Euclidean)
|
||||
.map_err(E::msg)?;
|
||||
}
|
||||
let appstate = v1::api::AppState { dict, embed };
|
||||
|
||||
// v2 stuff
|
||||
let mutations = v2::mutation::from_path("snippets/v2/go/mutations.kdl")?;
|
||||
let mut v2_dict = HashMap::new();
|
||||
let dimension = 384;
|
||||
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
||||
|
||||
let mut v2_index = HNSWIndex::<f32, usize>::new(dimension, ¶ms);
|
||||
v2_index
|
||||
.add(&embed.embed(&mutations.description)?, 0)
|
||||
.map_err(E::msg)?;
|
||||
v2_index
|
||||
.build(hora::core::metrics::Metric::Euclidean)
|
||||
.map_err(E::msg)?;
|
||||
let v2_mutations_collection = vec![mutations];
|
||||
v2_dict.insert("go".to_string(), v2_index);
|
||||
|
||||
let appstate = v1::api::AppState {
|
||||
dict,
|
||||
embed,
|
||||
v2_dict,
|
||||
v2_mutations_collection,
|
||||
};
|
||||
|
||||
let appstate_wrapped = web::Data::new(appstate.build());
|
||||
|
||||
@@ -97,6 +119,7 @@ async fn main() -> Result<()> {
|
||||
.app_data(appstate_wrapped.clone())
|
||||
.service(v1::api::get_snippet)
|
||||
.service(v1::api::add_snippet)
|
||||
.service(v2::api::get_snippet)
|
||||
})
|
||||
.bind(("127.0.0.1", port))?
|
||||
.run()
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{collections::HashMap, sync::Mutex};
|
||||
use hora::index::hnsw_idx::HNSWIndex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::embed;
|
||||
use crate::{embed, v2::mutation::MutationCollection};
|
||||
|
||||
use super::errors::GetError;
|
||||
|
||||
@@ -24,12 +24,14 @@ pub struct SnippetOnDisk {
|
||||
}
|
||||
|
||||
pub struct AppStateWrapper {
|
||||
inner: Mutex<AppState>,
|
||||
pub inner: Mutex<AppState>,
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
pub dict: HashMap<String, HNSWIndex<f32, String>>,
|
||||
pub embed: embed::Embed,
|
||||
pub v2_dict: HashMap<String, HNSWIndex<f32, usize>>,
|
||||
pub v2_mutations_collection: Vec<MutationCollection>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -93,7 +95,7 @@ pub(crate) async fn add_snippet(
|
||||
.or_insert_with(|| {
|
||||
let dimension = 384;
|
||||
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
||||
|
||||
|
||||
HNSWIndex::<f32, String>::new(dimension, ¶ms)
|
||||
});
|
||||
index.add(&embedding, snippet.body.clone()).unwrap();
|
||||
|
||||
116
src/v2/api.rs
Normal file
116
src/v2/api.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use crate::v1::api::AppStateWrapper;
|
||||
use hora::core::ann_index::ANNIndex;
|
||||
use tree_sitter::Parser;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{embed, v2::mutation};
|
||||
|
||||
use super::{errors::GetError, mutation::Mutation};
|
||||
|
||||
use actix_web::{Responder, post, web};
|
||||
|
||||
use anyhow::Result;
|
||||
#[derive(Deserialize)]
|
||||
pub struct SnippetRequest {
|
||||
desc: String,
|
||||
body: String,
|
||||
top_k: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SnippetResponse {
|
||||
id: usize,
|
||||
snippet: Snippet,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Snippet {
|
||||
lang: String,
|
||||
desc: String,
|
||||
body: String,
|
||||
}
|
||||
|
||||
fn get_lang(s: &str) -> tree_sitter::Language {
|
||||
match s {
|
||||
"go" => tree_sitter_go::LANGUAGE,
|
||||
"rust" => tree_sitter_rust::LANGUAGE,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
.into()
|
||||
}
|
||||
|
||||
#[post("/api/v2/get")]
|
||||
pub(crate) async fn get_snippet(
|
||||
data: web::Data<AppStateWrapper>,
|
||||
snippet_request: web::Json<SnippetRequest>,
|
||||
) -> Result<impl Responder, GetError> {
|
||||
let Some((prompt, lang)) = snippet_request.desc.rsplit_once(" in ") else {
|
||||
return Err(GetError::MissingSuffix);
|
||||
};
|
||||
|
||||
let langfn = get_lang(lang);
|
||||
|
||||
println!("{prompt:?}");
|
||||
|
||||
let Ok(mut appstate) = data.inner.lock() else {
|
||||
return Err(GetError::Busy);
|
||||
};
|
||||
|
||||
let Ok(target) = appstate.embed.embed(prompt) else {
|
||||
return Err(GetError::EmbedFailed);
|
||||
};
|
||||
|
||||
let mut parser = Parser::new();
|
||||
parser.set_language(&langfn).unwrap();
|
||||
|
||||
let source_code = std::fs::read_to_string("./example.go").unwrap();
|
||||
let source_bytes = source_code.as_bytes();
|
||||
let tree = parser.parse(&source_code, None).unwrap();
|
||||
let root_node = tree.root_node();
|
||||
|
||||
// search for k nearest neighbors
|
||||
let closest: Vec<String> = appstate.v2_dict[lang]
|
||||
.search(&target, snippet_request.top_k.unwrap_or(1))
|
||||
.iter()
|
||||
.map(|v| {
|
||||
mutation::apply(
|
||||
langfn.clone(),
|
||||
snippet_request.body.as_bytes(),
|
||||
root_node,
|
||||
&appstate.v2_mutations_collection[*v],
|
||||
)
|
||||
.expect(&format!("failed to apply mutations from collection {v}"))
|
||||
})
|
||||
.collect();
|
||||
Ok(web::Json(closest))
|
||||
}
|
||||
|
||||
// #[post("/api/v2/add")]
|
||||
// pub(crate) async fn add_snippet(
|
||||
// data: web::Data<AppStateWrapper>,
|
||||
// snippet: web::Json<Snippet>,
|
||||
// ) -> Result<impl Responder, GetError> {
|
||||
// let Ok(mut appstate) = data.inner.lock() else {
|
||||
// return Err(GetError::Busy);
|
||||
// };
|
||||
// let Ok(embedding) = appstate.embed.embed(&snippet.desc) else {
|
||||
// return Err(GetError::EmbedFailed);
|
||||
// };
|
||||
// let index = appstate
|
||||
// .dict
|
||||
// .entry(snippet.lang.clone())
|
||||
// .or_insert_with(|| {
|
||||
// let dimension = 384;
|
||||
// let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
||||
|
||||
// HNSWIndex::<f32, String>::new(dimension, ¶ms)
|
||||
// });
|
||||
// index.add(&embedding, snippet.body.clone()).unwrap();
|
||||
// index.build(hora::core::metrics::Metric::Euclidean).unwrap();
|
||||
|
||||
// Ok(format!(
|
||||
// "{} {} {}",
|
||||
// snippet.body, snippet.lang, snippet.desc
|
||||
// ))
|
||||
// }
|
||||
36
src/v2/errors.rs
Normal file
36
src/v2/errors.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use actix_web::{
|
||||
HttpResponse, error,
|
||||
http::{StatusCode, header::ContentType},
|
||||
};
|
||||
use derive_more::derive::{Display, Error};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug, Display, Error)]
|
||||
pub enum GetError {
|
||||
#[display("the server is busy. come back later.")]
|
||||
Busy,
|
||||
#[display("end your request with ` in somelang`.")]
|
||||
MissingSuffix,
|
||||
#[display("failed to embed your prompt.")]
|
||||
EmbedFailed,
|
||||
}
|
||||
|
||||
impl error::ResponseError for GetError {
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
let message = json!({
|
||||
"message": self.to_string(),
|
||||
})
|
||||
.to_string();
|
||||
HttpResponse::build(self.status_code())
|
||||
.insert_header(ContentType::json())
|
||||
.body(message)
|
||||
}
|
||||
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match *self {
|
||||
Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::MissingSuffix => StatusCode::BAD_REQUEST,
|
||||
Self::Busy => StatusCode::GATEWAY_TIMEOUT,
|
||||
}
|
||||
}
|
||||
}
|
||||
3
src/v2/mod.rs
Normal file
3
src/v2/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub(crate) mod api;
|
||||
pub(crate) mod errors;
|
||||
pub(crate) mod mutation;
|
||||
190
src/v2/mutation.rs
Normal file
190
src/v2/mutation.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator};
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
use kdl::KdlDocument;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Mutation {
|
||||
pub expression: String,
|
||||
pub substitute: Vec<Substitute>,
|
||||
}
|
||||
|
||||
pub struct MutationCollection {
|
||||
pub description: String,
|
||||
pub mutations: Vec<Mutation>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Substitute {
|
||||
Literal(String),
|
||||
Capture(String),
|
||||
}
|
||||
|
||||
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
|
||||
let contents = std::fs::read_to_string(path)?;
|
||||
let doc: KdlDocument = contents.parse()?;
|
||||
let mut mutations = vec![];
|
||||
|
||||
let mut description = None;
|
||||
|
||||
for node in doc.nodes() {
|
||||
let node_name = node.name().value();
|
||||
|
||||
if node_name != "mutation" && node_name != "description" {
|
||||
bail!("document root must only contain `mutation` or `description` nodes: got {node_name}");
|
||||
}
|
||||
|
||||
if node_name == "description" {
|
||||
description.replace(
|
||||
node.entry(0)
|
||||
.unwrap()
|
||||
.value()
|
||||
.as_string()
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let node = node.children().unwrap();
|
||||
let Some(expression) = node.get_arg("expression").and_then(|v| v.as_string()) else {
|
||||
bail!("mutation node must contain an expression");
|
||||
};
|
||||
let Some(substitute) = node.get("substitute") else {
|
||||
bail!("mutation node must contain an substitute");
|
||||
};
|
||||
|
||||
let children = substitute.children().unwrap().nodes();
|
||||
let mut substitute = vec![];
|
||||
for child in children {
|
||||
let attrib = child.entry(0).unwrap().value().as_string().unwrap();
|
||||
let substitutor = match child.name().value() {
|
||||
"literal" => Substitute::Literal(attrib.to_string()),
|
||||
"capture" => Substitute::Capture(attrib.to_string()),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
substitute.push(substitutor);
|
||||
}
|
||||
|
||||
let expression = expression.to_string();
|
||||
|
||||
mutations.push(Mutation {
|
||||
expression,
|
||||
substitute,
|
||||
})
|
||||
}
|
||||
|
||||
let Some(description) = description else {
|
||||
bail!("mutation collection contains no `description`");
|
||||
};
|
||||
|
||||
Ok(MutationCollection {
|
||||
description,
|
||||
mutations,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn apply(
|
||||
lang: Language,
|
||||
source_bytes: &[u8],
|
||||
root_node: Node<'_>,
|
||||
mutations: &MutationCollection,
|
||||
) -> Result<String, anyhow::Error> {
|
||||
let mut split_ats = vec![];
|
||||
let mut query_result_map = HashMap::new();
|
||||
for mutation in &mutations.mutations {
|
||||
let query_result = query(root_node, mutation.expression.as_str(), &lang, source_bytes);
|
||||
eprintln!("{:?}", query_result);
|
||||
split_ats.push(query_result.start);
|
||||
split_ats.push(query_result.end);
|
||||
|
||||
let mut ast_rewrite = String::default();
|
||||
for sub in &mutation.substitute {
|
||||
ast_rewrite.push_str(match sub {
|
||||
Substitute::Literal(attrib) => attrib,
|
||||
Substitute::Capture(attrib) => &query_result.captures[attrib],
|
||||
})
|
||||
}
|
||||
eprintln!("{ast_rewrite:?}");
|
||||
|
||||
query_result_map.insert(query_result.start, ast_rewrite);
|
||||
}
|
||||
split_ats.sort();
|
||||
let splits = split_at_indices(source_bytes, &split_ats);
|
||||
let mut output = String::default();
|
||||
for (i, split) in splits.indices.iter().zip(splits.values) {
|
||||
let split = std::str::from_utf8(split)?;
|
||||
output.push_str(query_result_map.get(i).map(|v| v.as_str()).unwrap_or(split));
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn display_s_expr(node: Node<'_>) {
|
||||
let exp = node.to_sexp();
|
||||
eprintln!("{exp:?}");
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct QueryCooked {
|
||||
captures: HashMap<String, String>,
|
||||
end: usize,
|
||||
start: usize,
|
||||
}
|
||||
|
||||
pub struct SplitMap<'a> {
|
||||
values: Vec<&'a [u8]>,
|
||||
indices: Vec<usize>,
|
||||
}
|
||||
|
||||
fn split_at_indices<'a>(c: &'a [u8], idx: &[usize]) -> SplitMap<'a> {
|
||||
let mut a = 0;
|
||||
let mut values = vec![];
|
||||
let mut indices = vec![a];
|
||||
for &b in idx {
|
||||
values.push(&c[a..b]);
|
||||
a = b;
|
||||
indices.push(a);
|
||||
}
|
||||
values.push(&c[a..]);
|
||||
assert_eq!(values.len(), indices.len());
|
||||
SplitMap { values, indices }
|
||||
}
|
||||
|
||||
fn query<'a>(node: Node<'a>, expr: &'a str, lang: &Language, source_bytes: &[u8]) -> QueryCooked {
|
||||
let query = Query::new(lang, expr).unwrap();
|
||||
|
||||
let mut qc = QueryCursor::new();
|
||||
let mut query_matches = qc.matches(&query, node, source_bytes);
|
||||
|
||||
let capture_names = query.capture_names();
|
||||
let mut capture_cooked = HashMap::new();
|
||||
|
||||
let mut start = 0;
|
||||
let mut end = 0;
|
||||
|
||||
if let Some(matcha) = query_matches.next() {
|
||||
for cap in matcha.captures {
|
||||
let Some(name) = capture_names.get(cap.index as usize) else {
|
||||
continue;
|
||||
};
|
||||
if *name == "root" {
|
||||
start = cap.node.start_byte();
|
||||
end = cap.node.end_byte();
|
||||
continue;
|
||||
}
|
||||
capture_cooked.insert(
|
||||
name.to_string(),
|
||||
cap.node.utf8_text(source_bytes).unwrap().to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
QueryCooked {
|
||||
start,
|
||||
end,
|
||||
captures: capture_cooked,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user