diff --git a/src/main.rs b/src/main.rs index 754a243..80ae714 100644 --- a/src/main.rs +++ b/src/main.rs @@ -90,20 +90,39 @@ async fn main() -> Result<()> { } // v2 stuff - let mutations = v2::mutation::from_path("snippets/v2/go/mutations.kdl")?; - let mut v2_dict = HashMap::new(); - let dimension = 384; - let params = hora::index::hnsw_params::HNSWParams::::default(); + let paths = glob::glob("./snippets/v2/*/*.kdl")?; + let mut v2_dict = HashMap::new(); + let mut v2_mutations_collection = vec![]; + for (i, path) in paths.enumerate() { + let path = path?; + let parent = path + .components() + .rev() + .nth(1) + .unwrap() + .as_os_str() + .to_string_lossy() + .to_string(); - let mut v2_index = HNSWIndex::::new(dimension, ¶ms); - v2_index - .add(&embed.embed(&mutations.description)?, 0) - .map_err(E::msg)?; - v2_index - .build(hora::core::metrics::Metric::Euclidean) - .map_err(E::msg)?; - let v2_mutations_collection = vec![mutations]; - v2_dict.insert("go".to_string(), v2_index); + let mutations = v2::mutation::from_path(path)?; + let current_lang_index = v2_dict.entry(parent).or_insert_with(|| { + let dimension = 384; + let params = hora::index::hnsw_params::HNSWParams::::default(); + + HNSWIndex::::new(dimension, ¶ms) + }); + + current_lang_index + .add(&embed.embed(&mutations.description)?, i) + .map_err(E::msg)?; + v2_mutations_collection.push(mutations); + } + + for index in v2_dict.values_mut() { + index + .build(hora::core::metrics::Metric::Euclidean) + .map_err(E::msg)?; + } let appstate = v1::api::AppState { dict, diff --git a/src/v1/api.rs b/src/v1/api.rs index 845eb08..6e9609e 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -3,26 +3,17 @@ use std::{collections::HashMap, sync::Mutex}; use hora::index::hnsw_idx::HNSWIndex; use serde::{Deserialize, Serialize}; - use crate::{embed, v2::mutation::MutationCollection}; - use super::errors::GetError; - +use anyhow::Result; use actix_web::{Responder, post, web}; -use anyhow::Result; #[derive(Deserialize)] pub struct SnippetRequest { desc: String, top_k: Option, } -#[derive(Deserialize, Debug)] -pub struct SnippetOnDisk { - pub body: String, - pub desc: String, -} - pub struct AppStateWrapper { pub inner: Mutex, } diff --git a/src/v2/api.rs b/src/v2/api.rs index 4a6cc43..77e6d8a 100644 --- a/src/v2/api.rs +++ b/src/v2/api.rs @@ -4,9 +4,8 @@ use tree_sitter::Parser; use serde::{Deserialize, Serialize}; -use crate::{embed, v2::mutation}; -use super::{errors::GetError, mutation::Mutation}; +use super::{errors::GetError, mutation}; use actix_web::{Responder, post, web}; @@ -31,13 +30,13 @@ pub struct Snippet { body: String, } -fn get_lang(s: &str) -> tree_sitter::Language { - match s { +fn get_lang(s: &str) -> Result { + Ok(match s { "go" => tree_sitter_go::LANGUAGE, "rust" => tree_sitter_rust::LANGUAGE, - _ => unreachable!(), + _ => return Err(GetError::UnknownLang), } - .into() + .into()) } #[post("/api/v2/get")] @@ -49,7 +48,7 @@ pub(crate) async fn get_snippet( return Err(GetError::MissingSuffix); }; - let langfn = get_lang(lang); + let langfn = get_lang(lang)?; println!("{prompt:?}"); @@ -64,7 +63,7 @@ pub(crate) async fn get_snippet( let mut parser = Parser::new(); parser.set_language(&langfn).unwrap(); - let source_code = std::fs::read_to_string("./example.go").unwrap(); + let source_code = snippet_request.body.as_str(); let source_bytes = source_code.as_bytes(); let tree = parser.parse(&source_code, None).unwrap(); let root_node = tree.root_node(); @@ -76,7 +75,7 @@ pub(crate) async fn get_snippet( .map(|v| { mutation::apply( langfn.clone(), - snippet_request.body.as_bytes(), + source_bytes, root_node, &appstate.v2_mutations_collection[*v], ) diff --git a/src/v2/errors.rs b/src/v2/errors.rs index ea349fa..66e8aab 100644 --- a/src/v2/errors.rs +++ b/src/v2/errors.rs @@ -13,6 +13,8 @@ pub enum GetError { MissingSuffix, #[display("failed to embed your prompt.")] EmbedFailed, + #[display("snippets were requested for an unknown language")] + UnknownLang } impl error::ResponseError for GetError { @@ -29,7 +31,7 @@ impl error::ResponseError for GetError { fn status_code(&self) -> StatusCode { match *self { Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR, - Self::MissingSuffix => StatusCode::BAD_REQUEST, + Self::MissingSuffix | Self::UnknownLang => StatusCode::BAD_REQUEST, Self::Busy => StatusCode::GATEWAY_TIMEOUT, } }