feat: scaffolding for v2 mutation api

This commit is contained in:
Himadri Bhattacharjee
2025-06-19 19:27:36 +05:30
parent 3b6039a547
commit 7ca19deff6
10 changed files with 511 additions and 11 deletions

107
Cargo.lock generated
View File

@@ -569,7 +569,7 @@ dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"unicode-width 0.2.0",
"windows-sys 0.59.0",
]
@@ -1707,7 +1707,7 @@ dependencies = [
"console",
"number_prefix",
"portable-atomic",
"unicode-width",
"unicode-width 0.2.0",
"web-time",
]
@@ -1777,6 +1777,18 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "kdl"
version = "6.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12661358400b02cbbf1fbd05f0a483335490e8a6bd1867620f2eeb78f304a22f"
dependencies = [
"miette",
"num",
"thiserror 1.0.69",
"winnow 0.6.24",
]
[[package]]
name = "language-tags"
version = "0.3.2"
@@ -1898,6 +1910,28 @@ dependencies = [
"stable_deref_trait",
]
[[package]]
name = "miette"
version = "7.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f98efec8807c63c752b5bd61f862c165c115b0a35685bdcfd9238c7aeb592b7"
dependencies = [
"cfg-if",
"miette-derive",
"unicode-width 0.1.14",
]
[[package]]
name = "miette-derive"
version = "7.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db5b29714e950dbb20d5e6f74f9dcec4edbcc1067bb7f8ed198c097b8c1a818b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "mime"
version = "0.3.17"
@@ -2725,6 +2759,7 @@ version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
"indexmap",
"itoa",
"memchr",
"ryu",
@@ -2792,9 +2827,14 @@ dependencies = [
"glob",
"hf-hub",
"hora",
"kdl",
"regex",
"serde",
"serde_json",
"tokenizers",
"tree-sitter",
"tree-sitter-go",
"tree-sitter-rust",
]
[[package]]
@@ -2851,6 +2891,12 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520"
[[package]]
name = "strsim"
version = "0.11.1"
@@ -3145,7 +3191,7 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a"
dependencies = [
"indexmap",
"toml_datetime",
"winnow",
"winnow 0.7.10",
]
[[package]]
@@ -3225,6 +3271,46 @@ dependencies = [
"once_cell",
]
[[package]]
name = "tree-sitter"
version = "0.25.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0"
dependencies = [
"cc",
"regex",
"regex-syntax",
"serde_json",
"streaming-iterator",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-go"
version = "0.23.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b13d476345220dbe600147dd444165c5791bf85ef53e28acbedd46112ee18431"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "tree-sitter-language"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4013970217383f67b18aef68f6fb2e8d409bc5755227092d32efb0422ba24b8"
[[package]]
name = "tree-sitter-rust"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b9b18034c684a2420722be8b2a91c9c44f2546b631c039edf575ccba8c61be1"
dependencies = [
"cc",
"tree-sitter-language",
]
[[package]]
name = "try-lock"
version = "0.2.5"
@@ -3279,6 +3365,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
[[package]]
name = "unicode-width"
version = "0.2.0"
@@ -3792,6 +3884,15 @@ version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
[[package]]
name = "winnow"
version = "0.6.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a"
dependencies = [
"memchr",
]
[[package]]
name = "winnow"
version = "0.7.10"

View File

@@ -15,6 +15,10 @@ glob = "0.3.2"
hf-hub = "0.4.2"
hora = "0.1.1"
kdl = "6.3.4"
regex = "1.11.1"
serde = "1.0.219"
serde_json = "1.0.140"
tokenizers = "0.21.1"
tree-sitter = "0.25.6"
tree-sitter-go = "0.23.4"
tree-sitter-rust = "0.24.0"

View File

@@ -1,3 +1,3 @@
desc "display all the path entries"
body """printf "%s\n" $PATH"""
body #"printf "%s\n" $PATH"#

View File

@@ -0,0 +1,25 @@
description "filepath base to parent's base"
mutation {
expression """
(call_expression
function: (_) @func (#eq? @func "filepath.Base")
arguments: (_) @args
) @root
"""
substitute {
literal "filepath.Base(filepath.Dir(filepath.Clean"
capture "args"
literal "))"
}
}
mutation {
expression """
((interpreted_string_literal_content) @str
(#eq? @str "/home/h/signal/softiee")
) @root
"""
substitute {
literal "/home/softiee/signal/h"
}
}

View File

@@ -7,7 +7,7 @@ use kdl::KdlDocument;
use std::collections::HashMap;
mod embed;
mod v1;
// mod v2;
mod v2;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
@@ -69,8 +69,8 @@ async fn main() -> Result<()> {
HNSWIndex::<f32, String>::new(dimension, &params)
});
let doc_str = std::fs::read_to_string(path)?;
let doc: KdlDocument = doc_str.parse().context("failed to parse KDL")?;
let doc_str = std::fs::read_to_string(&path)?;
let doc: KdlDocument = doc_str.parse().context(format!("failed to parse KDL: {}", path.display()))?;
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
continue;
@@ -88,7 +88,29 @@ async fn main() -> Result<()> {
.build(hora::core::metrics::Metric::Euclidean)
.map_err(E::msg)?;
}
let appstate = v1::api::AppState { dict, embed };
// v2 stuff
let mutations = v2::mutation::from_path("snippets/v2/go/mutations.kdl")?;
let mut v2_dict = HashMap::new();
let dimension = 384;
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
let mut v2_index = HNSWIndex::<f32, usize>::new(dimension, &params);
v2_index
.add(&embed.embed(&mutations.description)?, 0)
.map_err(E::msg)?;
v2_index
.build(hora::core::metrics::Metric::Euclidean)
.map_err(E::msg)?;
let v2_mutations_collection = vec![mutations];
v2_dict.insert("go".to_string(), v2_index);
let appstate = v1::api::AppState {
dict,
embed,
v2_dict,
v2_mutations_collection,
};
let appstate_wrapped = web::Data::new(appstate.build());
@@ -97,6 +119,7 @@ async fn main() -> Result<()> {
.app_data(appstate_wrapped.clone())
.service(v1::api::get_snippet)
.service(v1::api::add_snippet)
.service(v2::api::get_snippet)
})
.bind(("127.0.0.1", port))?
.run()

View File

@@ -4,7 +4,7 @@ use std::{collections::HashMap, sync::Mutex};
use hora::index::hnsw_idx::HNSWIndex;
use serde::{Deserialize, Serialize};
use crate::embed;
use crate::{embed, v2::mutation::MutationCollection};
use super::errors::GetError;
@@ -24,12 +24,14 @@ pub struct SnippetOnDisk {
}
pub struct AppStateWrapper {
inner: Mutex<AppState>,
pub inner: Mutex<AppState>,
}
pub struct AppState {
pub dict: HashMap<String, HNSWIndex<f32, String>>,
pub embed: embed::Embed,
pub v2_dict: HashMap<String, HNSWIndex<f32, usize>>,
pub v2_mutations_collection: Vec<MutationCollection>,
}
impl AppState {
@@ -93,7 +95,7 @@ pub(crate) async fn add_snippet(
.or_insert_with(|| {
let dimension = 384;
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
HNSWIndex::<f32, String>::new(dimension, &params)
});
index.add(&embedding, snippet.body.clone()).unwrap();

116
src/v2/api.rs Normal file
View File

@@ -0,0 +1,116 @@
use crate::v1::api::AppStateWrapper;
use hora::core::ann_index::ANNIndex;
use tree_sitter::Parser;
use serde::{Deserialize, Serialize};
use crate::{embed, v2::mutation};
use super::{errors::GetError, mutation::Mutation};
use actix_web::{Responder, post, web};
use anyhow::Result;
#[derive(Deserialize)]
pub struct SnippetRequest {
desc: String,
body: String,
top_k: Option<usize>,
}
#[derive(Serialize)]
pub struct SnippetResponse {
id: usize,
snippet: Snippet,
}
#[derive(Serialize, Deserialize)]
pub struct Snippet {
lang: String,
desc: String,
body: String,
}
fn get_lang(s: &str) -> tree_sitter::Language {
match s {
"go" => tree_sitter_go::LANGUAGE,
"rust" => tree_sitter_rust::LANGUAGE,
_ => unreachable!(),
}
.into()
}
#[post("/api/v2/get")]
pub(crate) async fn get_snippet(
data: web::Data<AppStateWrapper>,
snippet_request: web::Json<SnippetRequest>,
) -> Result<impl Responder, GetError> {
let Some((prompt, lang)) = snippet_request.desc.rsplit_once(" in ") else {
return Err(GetError::MissingSuffix);
};
let langfn = get_lang(lang);
println!("{prompt:?}");
let Ok(mut appstate) = data.inner.lock() else {
return Err(GetError::Busy);
};
let Ok(target) = appstate.embed.embed(prompt) else {
return Err(GetError::EmbedFailed);
};
let mut parser = Parser::new();
parser.set_language(&langfn).unwrap();
let source_code = std::fs::read_to_string("./example.go").unwrap();
let source_bytes = source_code.as_bytes();
let tree = parser.parse(&source_code, None).unwrap();
let root_node = tree.root_node();
// search for k nearest neighbors
let closest: Vec<String> = appstate.v2_dict[lang]
.search(&target, snippet_request.top_k.unwrap_or(1))
.iter()
.map(|v| {
mutation::apply(
langfn.clone(),
snippet_request.body.as_bytes(),
root_node,
&appstate.v2_mutations_collection[*v],
)
.expect(&format!("failed to apply mutations from collection {v}"))
})
.collect();
Ok(web::Json(closest))
}
// #[post("/api/v2/add")]
// pub(crate) async fn add_snippet(
// data: web::Data<AppStateWrapper>,
// snippet: web::Json<Snippet>,
// ) -> Result<impl Responder, GetError> {
// let Ok(mut appstate) = data.inner.lock() else {
// return Err(GetError::Busy);
// };
// let Ok(embedding) = appstate.embed.embed(&snippet.desc) else {
// return Err(GetError::EmbedFailed);
// };
// let index = appstate
// .dict
// .entry(snippet.lang.clone())
// .or_insert_with(|| {
// let dimension = 384;
// let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
// HNSWIndex::<f32, String>::new(dimension, &params)
// });
// index.add(&embedding, snippet.body.clone()).unwrap();
// index.build(hora::core::metrics::Metric::Euclidean).unwrap();
// Ok(format!(
// "{} {} {}",
// snippet.body, snippet.lang, snippet.desc
// ))
// }

36
src/v2/errors.rs Normal file
View File

@@ -0,0 +1,36 @@
use actix_web::{
HttpResponse, error,
http::{StatusCode, header::ContentType},
};
use derive_more::derive::{Display, Error};
use serde_json::json;
#[derive(Debug, Display, Error)]
pub enum GetError {
#[display("the server is busy. come back later.")]
Busy,
#[display("end your request with ` in somelang`.")]
MissingSuffix,
#[display("failed to embed your prompt.")]
EmbedFailed,
}
impl error::ResponseError for GetError {
fn error_response(&self) -> HttpResponse {
let message = json!({
"message": self.to_string(),
})
.to_string();
HttpResponse::build(self.status_code())
.insert_header(ContentType::json())
.body(message)
}
fn status_code(&self) -> StatusCode {
match *self {
Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR,
Self::MissingSuffix => StatusCode::BAD_REQUEST,
Self::Busy => StatusCode::GATEWAY_TIMEOUT,
}
}
}

3
src/v2/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub(crate) mod api;
pub(crate) mod errors;
pub(crate) mod mutation;

190
src/v2/mutation.rs Normal file
View File

@@ -0,0 +1,190 @@
use std::collections::HashMap;
use std::path::Path;
use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator};
use anyhow::{Result, bail};
use kdl::KdlDocument;
#[derive(Debug)]
pub struct Mutation {
pub expression: String,
pub substitute: Vec<Substitute>,
}
pub struct MutationCollection {
pub description: String,
pub mutations: Vec<Mutation>,
}
#[derive(Debug)]
pub enum Substitute {
Literal(String),
Capture(String),
}
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<MutationCollection> {
let contents = std::fs::read_to_string(path)?;
let doc: KdlDocument = contents.parse()?;
let mut mutations = vec![];
let mut description = None;
for node in doc.nodes() {
let node_name = node.name().value();
if node_name != "mutation" && node_name != "description" {
bail!("document root must only contain `mutation` or `description` nodes: got {node_name}");
}
if node_name == "description" {
description.replace(
node.entry(0)
.unwrap()
.value()
.as_string()
.unwrap()
.to_string(),
);
continue;
}
let node = node.children().unwrap();
let Some(expression) = node.get_arg("expression").and_then(|v| v.as_string()) else {
bail!("mutation node must contain an expression");
};
let Some(substitute) = node.get("substitute") else {
bail!("mutation node must contain an substitute");
};
let children = substitute.children().unwrap().nodes();
let mut substitute = vec![];
for child in children {
let attrib = child.entry(0).unwrap().value().as_string().unwrap();
let substitutor = match child.name().value() {
"literal" => Substitute::Literal(attrib.to_string()),
"capture" => Substitute::Capture(attrib.to_string()),
_ => unreachable!(),
};
substitute.push(substitutor);
}
let expression = expression.to_string();
mutations.push(Mutation {
expression,
substitute,
})
}
let Some(description) = description else {
bail!("mutation collection contains no `description`");
};
Ok(MutationCollection {
description,
mutations,
})
}
pub fn apply(
lang: Language,
source_bytes: &[u8],
root_node: Node<'_>,
mutations: &MutationCollection,
) -> Result<String, anyhow::Error> {
let mut split_ats = vec![];
let mut query_result_map = HashMap::new();
for mutation in &mutations.mutations {
let query_result = query(root_node, mutation.expression.as_str(), &lang, source_bytes);
eprintln!("{:?}", query_result);
split_ats.push(query_result.start);
split_ats.push(query_result.end);
let mut ast_rewrite = String::default();
for sub in &mutation.substitute {
ast_rewrite.push_str(match sub {
Substitute::Literal(attrib) => attrib,
Substitute::Capture(attrib) => &query_result.captures[attrib],
})
}
eprintln!("{ast_rewrite:?}");
query_result_map.insert(query_result.start, ast_rewrite);
}
split_ats.sort();
let splits = split_at_indices(source_bytes, &split_ats);
let mut output = String::default();
for (i, split) in splits.indices.iter().zip(splits.values) {
let split = std::str::from_utf8(split)?;
output.push_str(query_result_map.get(i).map(|v| v.as_str()).unwrap_or(split));
}
Ok(output)
}
fn display_s_expr(node: Node<'_>) {
let exp = node.to_sexp();
eprintln!("{exp:?}");
}
#[derive(Debug)]
struct QueryCooked {
captures: HashMap<String, String>,
end: usize,
start: usize,
}
pub struct SplitMap<'a> {
values: Vec<&'a [u8]>,
indices: Vec<usize>,
}
fn split_at_indices<'a>(c: &'a [u8], idx: &[usize]) -> SplitMap<'a> {
let mut a = 0;
let mut values = vec![];
let mut indices = vec![a];
for &b in idx {
values.push(&c[a..b]);
a = b;
indices.push(a);
}
values.push(&c[a..]);
assert_eq!(values.len(), indices.len());
SplitMap { values, indices }
}
fn query<'a>(node: Node<'a>, expr: &'a str, lang: &Language, source_bytes: &[u8]) -> QueryCooked {
let query = Query::new(lang, expr).unwrap();
let mut qc = QueryCursor::new();
let mut query_matches = qc.matches(&query, node, source_bytes);
let capture_names = query.capture_names();
let mut capture_cooked = HashMap::new();
let mut start = 0;
let mut end = 0;
if let Some(matcha) = query_matches.next() {
for cap in matcha.captures {
let Some(name) = capture_names.get(cap.index as usize) else {
continue;
};
if *name == "root" {
start = cap.node.start_byte();
end = cap.node.end_byte();
continue;
}
capture_cooked.insert(
name.to_string(),
cap.node.utf8_text(source_bytes).unwrap().to_string(),
);
}
}
QueryCooked {
start,
end,
captures: capture_cooked,
}
}