feat: v2 parses all kdl definitions via globs

This commit is contained in:
Himadri Bhattacharjee
2025-06-20 08:10:33 +05:30
parent 7ca19deff6
commit 90e3983f34
4 changed files with 44 additions and 33 deletions

View File

@@ -90,20 +90,39 @@ async fn main() -> Result<()> {
}
// v2 stuff
let mutations = v2::mutation::from_path("snippets/v2/go/mutations.kdl")?;
let mut v2_dict = HashMap::new();
let dimension = 384;
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
let paths = glob::glob("./snippets/v2/*/*.kdl")?;
let mut v2_dict = HashMap::new();
let mut v2_mutations_collection = vec![];
for (i, path) in paths.enumerate() {
let path = path?;
let parent = path
.components()
.rev()
.nth(1)
.unwrap()
.as_os_str()
.to_string_lossy()
.to_string();
let mut v2_index = HNSWIndex::<f32, usize>::new(dimension, &params);
v2_index
.add(&embed.embed(&mutations.description)?, 0)
.map_err(E::msg)?;
v2_index
.build(hora::core::metrics::Metric::Euclidean)
.map_err(E::msg)?;
let v2_mutations_collection = vec![mutations];
v2_dict.insert("go".to_string(), v2_index);
let mutations = v2::mutation::from_path(path)?;
let current_lang_index = v2_dict.entry(parent).or_insert_with(|| {
let dimension = 384;
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
HNSWIndex::<f32, usize>::new(dimension, &params)
});
current_lang_index
.add(&embed.embed(&mutations.description)?, i)
.map_err(E::msg)?;
v2_mutations_collection.push(mutations);
}
for index in v2_dict.values_mut() {
index
.build(hora::core::metrics::Metric::Euclidean)
.map_err(E::msg)?;
}
let appstate = v1::api::AppState {
dict,

View File

@@ -3,26 +3,17 @@ use std::{collections::HashMap, sync::Mutex};
use hora::index::hnsw_idx::HNSWIndex;
use serde::{Deserialize, Serialize};
use crate::{embed, v2::mutation::MutationCollection};
use super::errors::GetError;
use anyhow::Result;
use actix_web::{Responder, post, web};
use anyhow::Result;
/// Request payload that is populated by serde deserialization.
///
/// Both fields are private to this module, so values can only be produced by
/// deserializing (or by code inside this module).
#[derive(Deserialize)]
pub struct SnippetRequest {
// Free-text description supplied by the caller.
// NOTE(review): presumably the text that gets embedded and matched against
// the snippet index — confirm against the handler that consumes it.
desc: String,
// Optional count; may be absent from the payload.
// NOTE(review): presumably the maximum number of results to return; the
// fallback used when this is `None` is decided by the handler, not here.
top_k: Option<usize>,
}
/// Deserializable, debug-printable record with a snippet body and description.
///
/// NOTE(review): the name suggests this mirrors how a snippet is stored in a
/// file on disk — confirm against the code that loads it.
#[derive(Deserialize, Debug)]
pub struct SnippetOnDisk {
// The snippet text itself.
pub body: String,
// Description associated with `body`.
pub desc: String,
}
/// Holds the application state behind a `Mutex`, so access to the inner
/// `AppState` goes through the lock.
///
/// NOTE(review): presumably this is the value registered as shared actix-web
/// application data — confirm where it is constructed.
pub struct AppStateWrapper {
// The guarded state; callers must lock before reading or mutating it.
pub inner: Mutex<AppState>,
}

View File

@@ -4,9 +4,8 @@ use tree_sitter::Parser;
use serde::{Deserialize, Serialize};
use crate::{embed, v2::mutation};
use super::{errors::GetError, mutation::Mutation};
use super::{errors::GetError, mutation};
use actix_web::{Responder, post, web};
@@ -31,13 +30,13 @@ pub struct Snippet {
body: String,
}
fn get_lang(s: &str) -> tree_sitter::Language {
match s {
fn get_lang(s: &str) -> Result<tree_sitter::Language, GetError> {
Ok(match s {
"go" => tree_sitter_go::LANGUAGE,
"rust" => tree_sitter_rust::LANGUAGE,
_ => unreachable!(),
_ => return Err(GetError::UnknownLang),
}
.into()
.into())
}
#[post("/api/v2/get")]
@@ -49,7 +48,7 @@ pub(crate) async fn get_snippet(
return Err(GetError::MissingSuffix);
};
let langfn = get_lang(lang);
let langfn = get_lang(lang)?;
println!("{prompt:?}");
@@ -64,7 +63,7 @@ pub(crate) async fn get_snippet(
let mut parser = Parser::new();
parser.set_language(&langfn).unwrap();
let source_code = std::fs::read_to_string("./example.go").unwrap();
let source_code = snippet_request.body.as_str();
let source_bytes = source_code.as_bytes();
let tree = parser.parse(&source_code, None).unwrap();
let root_node = tree.root_node();
@@ -76,7 +75,7 @@ pub(crate) async fn get_snippet(
.map(|v| {
mutation::apply(
langfn.clone(),
snippet_request.body.as_bytes(),
source_bytes,
root_node,
&appstate.v2_mutations_collection[*v],
)

View File

@@ -13,6 +13,8 @@ pub enum GetError {
MissingSuffix,
#[display("failed to embed your prompt.")]
EmbedFailed,
#[display("snippets were requested for an unknown language")]
UnknownLang
}
impl error::ResponseError for GetError {
@@ -29,7 +31,7 @@ impl error::ResponseError for GetError {
fn status_code(&self) -> StatusCode {
match *self {
Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR,
Self::MissingSuffix => StatusCode::BAD_REQUEST,
Self::MissingSuffix | Self::UnknownLang => StatusCode::BAD_REQUEST,
Self::Busy => StatusCode::GATEWAY_TIMEOUT,
}
}