feat: v2 parses all kdl definitions via globs
This commit is contained in:
45
src/main.rs
45
src/main.rs
@@ -90,20 +90,39 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
// v2 stuff
|
||||
let mutations = v2::mutation::from_path("snippets/v2/go/mutations.kdl")?;
|
||||
let mut v2_dict = HashMap::new();
|
||||
let dimension = 384;
|
||||
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
||||
let paths = glob::glob("./snippets/v2/*/*.kdl")?;
|
||||
let mut v2_dict = HashMap::new();
|
||||
let mut v2_mutations_collection = vec![];
|
||||
for (i, path) in paths.enumerate() {
|
||||
let path = path?;
|
||||
let parent = path
|
||||
.components()
|
||||
.rev()
|
||||
.nth(1)
|
||||
.unwrap()
|
||||
.as_os_str()
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
let mut v2_index = HNSWIndex::<f32, usize>::new(dimension, &params);
|
||||
v2_index
|
||||
.add(&embed.embed(&mutations.description)?, 0)
|
||||
.map_err(E::msg)?;
|
||||
v2_index
|
||||
.build(hora::core::metrics::Metric::Euclidean)
|
||||
.map_err(E::msg)?;
|
||||
let v2_mutations_collection = vec![mutations];
|
||||
v2_dict.insert("go".to_string(), v2_index);
|
||||
let mutations = v2::mutation::from_path(path)?;
|
||||
let current_lang_index = v2_dict.entry(parent).or_insert_with(|| {
|
||||
let dimension = 384;
|
||||
let params = hora::index::hnsw_params::HNSWParams::<f32>::default();
|
||||
|
||||
HNSWIndex::<f32, usize>::new(dimension, &params)
|
||||
});
|
||||
|
||||
current_lang_index
|
||||
.add(&embed.embed(&mutations.description)?, i)
|
||||
.map_err(E::msg)?;
|
||||
v2_mutations_collection.push(mutations);
|
||||
}
|
||||
|
||||
for index in v2_dict.values_mut() {
|
||||
index
|
||||
.build(hora::core::metrics::Metric::Euclidean)
|
||||
.map_err(E::msg)?;
|
||||
}
|
||||
|
||||
let appstate = v1::api::AppState {
|
||||
dict,
|
||||
|
||||
@@ -3,26 +3,17 @@ use std::{collections::HashMap, sync::Mutex};
|
||||
|
||||
use hora::index::hnsw_idx::HNSWIndex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{embed, v2::mutation::MutationCollection};
|
||||
|
||||
use super::errors::GetError;
|
||||
|
||||
use anyhow::Result;
|
||||
use actix_web::{Responder, post, web};
|
||||
|
||||
use anyhow::Result;
|
||||
#[derive(Deserialize)]
|
||||
pub struct SnippetRequest {
|
||||
desc: String,
|
||||
top_k: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct SnippetOnDisk {
|
||||
pub body: String,
|
||||
pub desc: String,
|
||||
}
|
||||
|
||||
pub struct AppStateWrapper {
|
||||
pub inner: Mutex<AppState>,
|
||||
}
|
||||
|
||||
@@ -4,9 +4,8 @@ use tree_sitter::Parser;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{embed, v2::mutation};
|
||||
|
||||
use super::{errors::GetError, mutation::Mutation};
|
||||
use super::{errors::GetError, mutation};
|
||||
|
||||
use actix_web::{Responder, post, web};
|
||||
|
||||
@@ -31,13 +30,13 @@ pub struct Snippet {
|
||||
body: String,
|
||||
}
|
||||
|
||||
fn get_lang(s: &str) -> tree_sitter::Language {
|
||||
match s {
|
||||
fn get_lang(s: &str) -> Result<tree_sitter::Language, GetError> {
|
||||
Ok(match s {
|
||||
"go" => tree_sitter_go::LANGUAGE,
|
||||
"rust" => tree_sitter_rust::LANGUAGE,
|
||||
_ => unreachable!(),
|
||||
_ => return Err(GetError::UnknownLang),
|
||||
}
|
||||
.into()
|
||||
.into())
|
||||
}
|
||||
|
||||
#[post("/api/v2/get")]
|
||||
@@ -49,7 +48,7 @@ pub(crate) async fn get_snippet(
|
||||
return Err(GetError::MissingSuffix);
|
||||
};
|
||||
|
||||
let langfn = get_lang(lang);
|
||||
let langfn = get_lang(lang)?;
|
||||
|
||||
println!("{prompt:?}");
|
||||
|
||||
@@ -64,7 +63,7 @@ pub(crate) async fn get_snippet(
|
||||
let mut parser = Parser::new();
|
||||
parser.set_language(&langfn).unwrap();
|
||||
|
||||
let source_code = std::fs::read_to_string("./example.go").unwrap();
|
||||
let source_code = snippet_request.body.as_str();
|
||||
let source_bytes = source_code.as_bytes();
|
||||
let tree = parser.parse(&source_code, None).unwrap();
|
||||
let root_node = tree.root_node();
|
||||
@@ -76,7 +75,7 @@ pub(crate) async fn get_snippet(
|
||||
.map(|v| {
|
||||
mutation::apply(
|
||||
langfn.clone(),
|
||||
snippet_request.body.as_bytes(),
|
||||
source_bytes,
|
||||
root_node,
|
||||
&appstate.v2_mutations_collection[*v],
|
||||
)
|
||||
|
||||
@@ -13,6 +13,8 @@ pub enum GetError {
|
||||
MissingSuffix,
|
||||
#[display("failed to embed your prompt.")]
|
||||
EmbedFailed,
|
||||
#[display("snippets were requested for an unknown language")]
|
||||
UnknownLang
|
||||
}
|
||||
|
||||
impl error::ResponseError for GetError {
|
||||
@@ -29,7 +31,7 @@ impl error::ResponseError for GetError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match *self {
|
||||
Self::EmbedFailed => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::MissingSuffix => StatusCode::BAD_REQUEST,
|
||||
Self::MissingSuffix | Self::UnknownLang => StatusCode::BAD_REQUEST,
|
||||
Self::Busy => StatusCode::GATEWAY_TIMEOUT,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user