diff --git a/Cargo.lock b/Cargo.lock index f56714b..ed22777 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -569,7 +569,7 @@ dependencies = [ "encode_unicode", "libc", "once_cell", - "unicode-width", + "unicode-width 0.2.0", "windows-sys 0.59.0", ] @@ -1707,7 +1707,7 @@ dependencies = [ "console", "number_prefix", "portable-atomic", - "unicode-width", + "unicode-width 0.2.0", "web-time", ] @@ -1777,6 +1777,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "kdl" +version = "6.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12661358400b02cbbf1fbd05f0a483335490e8a6bd1867620f2eeb78f304a22f" +dependencies = [ + "miette", + "num", + "thiserror 1.0.69", + "winnow 0.6.24", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -1898,6 +1910,28 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "miette" +version = "7.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f98efec8807c63c752b5bd61f862c165c115b0a35685bdcfd9238c7aeb592b7" +dependencies = [ + "cfg-if", + "miette-derive", + "unicode-width 0.1.14", +] + +[[package]] +name = "miette-derive" +version = "7.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db5b29714e950dbb20d5e6f74f9dcec4edbcc1067bb7f8ed198c097b8c1a818b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "mime" version = "0.3.17" @@ -2725,6 +2759,7 @@ version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ + "indexmap", "itoa", "memchr", "ryu", @@ -2792,9 +2827,14 @@ dependencies = [ "glob", "hf-hub", "hora", + "kdl", + "regex", "serde", "serde_json", "tokenizers", + "tree-sitter", + "tree-sitter-go", + "tree-sitter-rust", ] [[package]] @@ -2851,6 +2891,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" + [[package]] name = "strsim" version = "0.11.1" @@ -3145,7 +3191,7 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ "indexmap", "toml_datetime", - "winnow", + "winnow 0.7.10", ] [[package]] @@ -3225,6 +3271,46 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tree-sitter" +version = "0.25.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7cf18d43cbf0bfca51f657132cc616a5097edc4424d538bae6fa60142eaf9f0" +dependencies = [ + "cc", + "regex", + "regex-syntax", + "serde_json", + "streaming-iterator", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-go" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b13d476345220dbe600147dd444165c5791bf85ef53e28acbedd46112ee18431" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-language" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4013970217383f67b18aef68f6fb2e8d409bc5755227092d32efb0422ba24b8" + +[[package]] +name = "tree-sitter-rust" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b9b18034c684a2420722be8b2a91c9c44f2546b631c039edf575ccba8c61be1" +dependencies = [ + "cc", + "tree-sitter-language", +] + [[package]] name = "try-lock" version = "0.2.5" @@ -3279,6 +3365,12 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "unicode-width" version = "0.2.0" @@ -3792,6 +3884,15 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +[[package]] +name = "winnow" +version = "0.6.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +dependencies = [ + "memchr", +] + [[package]] name = "winnow" version = "0.7.10" diff --git a/Cargo.toml b/Cargo.toml index ae9b84f..d790d82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,10 @@ glob = "0.3.2" hf-hub = "0.4.2" hora = "0.1.1" kdl = "6.3.4" +regex = "1.11.1" serde = "1.0.219" serde_json = "1.0.140" tokenizers = "0.21.1" +tree-sitter = "0.25.6" +tree-sitter-go = "0.23.4" +tree-sitter-rust = "0.24.0" diff --git a/snippets/v1/fish/path-entries.kdl b/snippets/v1/fish/path-entries.kdl index 5f3e6f4..b47cb53 100644 --- a/snippets/v1/fish/path-entries.kdl +++ b/snippets/v1/fish/path-entries.kdl @@ -1,3 +1,3 @@ desc "display all the path entries" -body """printf "%s\n" $PATH""" +body #"printf "%s\n" $PATH"# diff --git a/snippets/v2/go/mutations.kdl b/snippets/v2/go/mutations.kdl new file mode 100644 index 0000000..2e399e1 --- /dev/null +++ b/snippets/v2/go/mutations.kdl @@ -0,0 +1,25 @@ +description "filepath base to parent's base" +mutation { + expression """ + (call_expression + function: (_) @func (#eq? @func "filepath.Base") + arguments: (_) @args + ) @root + """ + substitute { + literal "filepath.Base(filepath.Dir(filepath.Clean" + capture "args" + literal "))" + } +} + +mutation { + expression """ + ((interpreted_string_literal_content) @str + (#eq? @str "/home/h/signal/softiee") + ) @root + """ + substitute { + literal "/home/softiee/signal/h" + } +} diff --git a/src/main.rs b/src/main.rs index 2963dea..754a243 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use kdl::KdlDocument; use std::collections::HashMap; mod embed; mod v1; -// mod v2; +mod v2; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -69,8 +69,8 @@ async fn main() -> Result<()> { HNSWIndex::::new(dimension, ¶ms) }); - let doc_str = std::fs::read_to_string(path)?; - let doc: KdlDocument = doc_str.parse().context("failed to parse KDL")?; + let doc_str = std::fs::read_to_string(&path)?; + let doc: KdlDocument = doc_str.parse().context(format!("failed to parse KDL: {}", path.display()))?; let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else { continue; @@ -88,7 +88,29 @@ async fn main() -> Result<()> { .build(hora::core::metrics::Metric::Euclidean) .map_err(E::msg)?; } - let appstate = v1::api::AppState { dict, embed }; + + // v2 stuff + let mutations = v2::mutation::from_path("snippets/v2/go/mutations.kdl")?; + let mut v2_dict = HashMap::new(); + let dimension = 384; + let params = hora::index::hnsw_params::HNSWParams::::default(); + + let mut v2_index = HNSWIndex::::new(dimension, ¶ms); + v2_index + .add(&embed.embed(&mutations.description)?, 0) + .map_err(E::msg)?; + v2_index + .build(hora::core::metrics::Metric::Euclidean) + .map_err(E::msg)?; + let v2_mutations_collection = vec![mutations]; + v2_dict.insert("go".to_string(), v2_index); + + let appstate = v1::api::AppState { + dict, + embed, + v2_dict, + v2_mutations_collection, + }; let appstate_wrapped = web::Data::new(appstate.build()); @@ -97,6 +119,7 @@ async fn main() -> Result<()> { .app_data(appstate_wrapped.clone()) .service(v1::api::get_snippet) .service(v1::api::add_snippet) + .service(v2::api::get_snippet) }) .bind(("127.0.0.1", port))? .run() diff --git a/src/v1/api.rs b/src/v1/api.rs index f401049..845eb08 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -4,7 +4,7 @@ use std::{collections::HashMap, sync::Mutex}; use hora::index::hnsw_idx::HNSWIndex; use serde::{Deserialize, Serialize}; -use crate::embed; +use crate::{embed, v2::mutation::MutationCollection}; use super::errors::GetError; @@ -24,12 +24,14 @@ pub struct SnippetOnDisk { } pub struct AppStateWrapper { - inner: Mutex, + pub inner: Mutex, } pub struct AppState { pub dict: HashMap>, pub embed: embed::Embed, + pub v2_dict: HashMap>, + pub v2_mutations_collection: Vec, } impl AppState { @@ -93,7 +95,7 @@ pub(crate) async fn add_snippet( .or_insert_with(|| { let dimension = 384; let params = hora::index::hnsw_params::HNSWParams::::default(); - + HNSWIndex::::new(dimension, ¶ms) }); index.add(&embedding, snippet.body.clone()).unwrap(); diff --git a/src/v2/api.rs b/src/v2/api.rs new file mode 100644 index 0000000..4a6cc43 --- /dev/null +++ b/src/v2/api.rs @@ -0,0 +1,116 @@ +use crate::v1::api::AppStateWrapper; +use hora::core::ann_index::ANNIndex; +use tree_sitter::Parser; + +use serde::{Deserialize, Serialize}; + +use crate::{embed, v2::mutation}; + +use super::{errors::GetError, mutation::Mutation}; + +use actix_web::{Responder, post, web}; + +use anyhow::Result; +#[derive(Deserialize)] +pub struct SnippetRequest { + desc: String, + body: String, + top_k: Option, +} + +#[derive(Serialize)] +pub struct SnippetResponse { + id: usize, + snippet: Snippet, +} + +#[derive(Serialize, Deserialize)] +pub struct Snippet { + lang: String, + desc: String, + body: String, +} + +fn get_lang(s: &str) -> tree_sitter::Language { + match s { + "go" => tree_sitter_go::LANGUAGE, + "rust" => tree_sitter_rust::LANGUAGE, + _ => unreachable!(), + } + .into() +} + +#[post("/api/v2/get")] +pub(crate) async fn get_snippet( + data: web::Data, + snippet_request: web::Json, +) -> Result { + let Some((prompt, lang)) = snippet_request.desc.rsplit_once(" in ") else { + return Err(GetError::MissingSuffix); + }; + + let langfn = get_lang(lang); + + println!("{prompt:?}"); + + let Ok(mut appstate) = data.inner.lock() else { + return Err(GetError::Busy); + }; + + let Ok(target) = appstate.embed.embed(prompt) else { + return Err(GetError::EmbedFailed); + }; + + let mut parser = Parser::new(); + parser.set_language(&langfn).unwrap(); + + let source_code = std::fs::read_to_string("./example.go").unwrap(); + let source_bytes = source_code.as_bytes(); + let tree = parser.parse(&source_code, None).unwrap(); + let root_node = tree.root_node(); + + // search for k nearest neighbors + let closest: Vec = appstate.v2_dict[lang] + .search(&target, snippet_request.top_k.unwrap_or(1)) + .iter() + .map(|v| { + mutation::apply( + langfn.clone(), + snippet_request.body.as_bytes(), + root_node, + &appstate.v2_mutations_collection[*v], + ) + .expect(&format!("failed to apply mutations from collection {v}")) + }) + .collect(); + Ok(web::Json(closest)) +} + +// #[post("/api/v2/add")] +// pub(crate) async fn add_snippet( +// data: web::Data, +// snippet: web::Json, +// ) -> Result { +// let Ok(mut appstate) = data.inner.lock() else { +// return Err(GetError::Busy); +// }; +// let Ok(embedding) = appstate.embed.embed(&snippet.desc) else { +// return Err(GetError::EmbedFailed); +// }; +// let index = appstate +// .dict +// .entry(snippet.lang.clone()) +// .or_insert_with(|| { +// let dimension = 384; +// let params = hora::index::hnsw_params::HNSWParams::::default(); + +// HNSWIndex::::new(dimension, ¶ms) +// }); +// index.add(&embedding, snippet.body.clone()).unwrap(); +// index.build(hora::core::metrics::Metric::Euclidean).unwrap(); + +// Ok(format!( +// "{} {} {}", +// snippet.body, snippet.lang, snippet.desc +// )) +// } diff --git a/src/v2/errors.rs b/src/v2/errors.rs new file mode 100644 index 0000000..ea349fa --- /dev/null +++ b/src/v2/errors.rs @@ -0,0 +1,36 @@ +use actix_web::{ + HttpResponse, error, + http::{StatusCode, header::ContentType}, +}; +use derive_more::derive::{Display, Error}; +use serde_json::json; + +#[derive(Debug, Display, Error)] +pub enum GetError { + #[display("the server is busy. come back later.")] + Busy, + #[display("end your request with ` in somelang`.")] + MissingSuffix, + #[display("failed to embed your prompt.")] + EmbedFailed, +} + +impl error::ResponseError for GetError { + fn error_response(&self) -> HttpResponse { + let message = json!({ + "message": self.to_string(), + }) + .to_string(); + HttpResponse::build(self.status_code()) + .insert_header(ContentType::json()) + .body(message) + } + + fn status_code(&self) -> StatusCode { + match *self { + Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR, + Self::MissingSuffix => StatusCode::BAD_REQUEST, + Self::Busy => StatusCode::GATEWAY_TIMEOUT, + } + } +} diff --git a/src/v2/mod.rs b/src/v2/mod.rs new file mode 100644 index 0000000..8b9d27b --- /dev/null +++ b/src/v2/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod api; +pub(crate) mod errors; +pub(crate) mod mutation; diff --git a/src/v2/mutation.rs b/src/v2/mutation.rs new file mode 100644 index 0000000..3830eaf --- /dev/null +++ b/src/v2/mutation.rs @@ -0,0 +1,190 @@ +use std::collections::HashMap; +use std::path::Path; +use tree_sitter::{Language, Node, Query, QueryCursor, StreamingIterator}; + +use anyhow::{Result, bail}; +use kdl::KdlDocument; + +#[derive(Debug)] +pub struct Mutation { + pub expression: String, + pub substitute: Vec, +} + +pub struct MutationCollection { + pub description: String, + pub mutations: Vec, +} + +#[derive(Debug)] +pub enum Substitute { + Literal(String), + Capture(String), +} + +pub fn from_path>(path: P) -> Result { + let contents = std::fs::read_to_string(path)?; + let doc: KdlDocument = contents.parse()?; + let mut mutations = vec![]; + + let mut description = None; + + for node in doc.nodes() { + let node_name = node.name().value(); + + if node_name != "mutation" && node_name != "description" { + bail!("document root must only contain `mutation` or `description` nodes: got {node_name}"); + } + + if node_name == "description" { + description.replace( + node.entry(0) + .unwrap() + .value() + .as_string() + .unwrap() + .to_string(), + ); + continue; + } + + let node = node.children().unwrap(); + let Some(expression) = node.get_arg("expression").and_then(|v| v.as_string()) else { + bail!("mutation node must contain an expression"); + }; + let Some(substitute) = node.get("substitute") else { + bail!("mutation node must contain an substitute"); + }; + + let children = substitute.children().unwrap().nodes(); + let mut substitute = vec![]; + for child in children { + let attrib = child.entry(0).unwrap().value().as_string().unwrap(); + let substitutor = match child.name().value() { + "literal" => Substitute::Literal(attrib.to_string()), + "capture" => Substitute::Capture(attrib.to_string()), + _ => unreachable!(), + }; + + substitute.push(substitutor); + } + + let expression = expression.to_string(); + + mutations.push(Mutation { + expression, + substitute, + }) + } + + let Some(description) = description else { + bail!("mutation collection contains no `description`"); + }; + + Ok(MutationCollection { + description, + mutations, + }) +} + +pub fn apply( + lang: Language, + source_bytes: &[u8], + root_node: Node<'_>, + mutations: &MutationCollection, +) -> Result { + let mut split_ats = vec![]; + let mut query_result_map = HashMap::new(); + for mutation in &mutations.mutations { + let query_result = query(root_node, mutation.expression.as_str(), &lang, source_bytes); + eprintln!("{:?}", query_result); + split_ats.push(query_result.start); + split_ats.push(query_result.end); + + let mut ast_rewrite = String::default(); + for sub in &mutation.substitute { + ast_rewrite.push_str(match sub { + Substitute::Literal(attrib) => attrib, + Substitute::Capture(attrib) => &query_result.captures[attrib], + }) + } + eprintln!("{ast_rewrite:?}"); + + query_result_map.insert(query_result.start, ast_rewrite); + } + split_ats.sort(); + let splits = split_at_indices(source_bytes, &split_ats); + let mut output = String::default(); + for (i, split) in splits.indices.iter().zip(splits.values) { + let split = std::str::from_utf8(split)?; + output.push_str(query_result_map.get(i).map(|v| v.as_str()).unwrap_or(split)); + } + Ok(output) +} + +fn display_s_expr(node: Node<'_>) { + let exp = node.to_sexp(); + eprintln!("{exp:?}"); +} + +#[derive(Debug)] +struct QueryCooked { + captures: HashMap, + end: usize, + start: usize, +} + +pub struct SplitMap<'a> { + values: Vec<&'a [u8]>, + indices: Vec, +} + +fn split_at_indices<'a>(c: &'a [u8], idx: &[usize]) -> SplitMap<'a> { + let mut a = 0; + let mut values = vec![]; + let mut indices = vec![a]; + for &b in idx { + values.push(&c[a..b]); + a = b; + indices.push(a); + } + values.push(&c[a..]); + assert_eq!(values.len(), indices.len()); + SplitMap { values, indices } +} + +fn query<'a>(node: Node<'a>, expr: &'a str, lang: &Language, source_bytes: &[u8]) -> QueryCooked { + let query = Query::new(lang, expr).unwrap(); + + let mut qc = QueryCursor::new(); + let mut query_matches = qc.matches(&query, node, source_bytes); + + let capture_names = query.capture_names(); + let mut capture_cooked = HashMap::new(); + + let mut start = 0; + let mut end = 0; + + if let Some(matcha) = query_matches.next() { + for cap in matcha.captures { + let Some(name) = capture_names.get(cap.index as usize) else { + continue; + }; + if *name == "root" { + start = cap.node.start_byte(); + end = cap.node.end_byte(); + continue; + } + capture_cooked.insert( + name.to_string(), + cap.node.utf8_text(source_bytes).unwrap().to_string(), + ); + } + } + + QueryCooked { + start, + end, + captures: capture_cooked, + } +}