refactor: move common closest mutation search into a function

This commit is contained in:
Himadri Bhattacharjee
2025-07-01 18:58:18 +05:30
parent 4b710e8675
commit 716b9ed3e2
2 changed files with 50 additions and 76 deletions

View File

@@ -12,8 +12,7 @@ use std::collections::HashMap;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower_lsp::lsp_types::*; use tower_lsp::lsp_types::*;
use tower_lsp::{Client, LanguageServer, LspService, Server}; use tower_lsp::{Client, LanguageServer, LspService, Server};
use tracing::{error, info}; use tracing::error;
use tree_sitter::Parser as TSParser;
mod embed; mod embed;
mod state; mod state;
@@ -267,6 +266,7 @@ impl LanguageServer for Backend {
params: CodeActionParams, params: CodeActionParams,
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> { ) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
let uri = params.text_document.uri; let uri = params.text_document.uri;
let extension = url_extension(&uri);
let body = self.body.lock().await.to_string(); let body = self.body.lock().await.to_string();
let range = params.range; let range = params.range;
@@ -278,93 +278,50 @@ impl LanguageServer for Backend {
return Ok(None); return Ok(None);
}; };
let Some((prompt, lang)) = desc.rsplit_once(" in ") else { let (prompt, lang) = if let Some(ext) = extension {
(desc, ext)
} else if let Some((prompt, lang)) = desc.rsplit_once(" in ") {
(prompt, lang.to_string())
} else {
error!("{}", v2::errors::Error::MissingSuffix); error!("{}", v2::errors::Error::MissingSuffix);
return Ok(None); return Ok(None);
}; };
let langfn = match v2::api::get_lang(lang) { let closest_matches =
Ok(o) => o, match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) {
Err(e) => { Ok(v) => v,
error!("{e}"); Err(e) => {
return Ok(None); error!("{}", e);
} return Ok(None);
};
info!(prompt = prompt, language = lang, "v2 request");
let mut appstate = self
.appstate
.inner
.lock()
.map_err(|_| v2::errors::Error::Busy)
.expect("booo");
let target = appstate
.embed
.embed(prompt)
.map_err(|_| v2::errors::Error::EmbedFailed)
.expect("booo");
let mut parser = TSParser::new();
parser
.set_language(&langfn)
.map_err(|_| v2::errors::Error::UnknownLang)
.expect("boo");
let source_code = new_text;
let source_bytes = source_code.as_bytes();
let tree = parser
.parse(source_code, None)
.ok_or(v2::errors::Error::SnippetParsing)
.expect("boo");
let root_node = tree.root_node();
// search for k nearest neighbors
let closest: Vec<String> = appstate.v2.dict[lang]
.search(&target, 1)
.iter()
.filter_map(|&index| {
let applied = v2::mutation::apply(
langfn.clone(),
source_bytes,
root_node,
&appstate.v2.mutations_collection[index],
);
match applied {
Ok(v) => Some(v),
Err(e) => {
error!(
collection_index = index,
"failed to apply mutations from collection {}", e
);
None
}
} }
// TODO: change the expect to a log };
})
.collect();
let closest = closest[0].clone();
let Some(closest) = closest_matches.into_iter().next() else {
return Ok(None);
};
let text_edit = TextEdit { let text_edit = TextEdit {
range, range,
new_text: closest, new_text: closest,
}; };
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect(); let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
let edit = WorkspaceEdit { let edit = Some(WorkspaceEdit {
changes: Some(changes), changes: Some(changes),
document_changes: None, document_changes: None,
change_annotations: None, change_annotations: None,
}; });
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction { let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
title: "ask silos".to_string(), title: "ask silos".to_string(),
kind: None, edit,
diagnostics: None, ..Default::default()
edit: Some(edit),
command: None,
is_preferred: None,
disabled: None,
data: None,
})]; })];
Ok(Some(actions)) Ok(Some(actions))
} }
} }
pub fn url_extension(u: &Url) -> Option<String> {
let file_path = u.to_file_path().ok()?;
let extension = file_path.extension()?;
let extension = extension.to_str()?;
Some(extension.to_string())
}

View File

@@ -52,6 +52,23 @@ pub(crate) async fn get_snippet(
return Err(Error::MissingSuffix); return Err(Error::MissingSuffix);
}; };
let closest = closest_mutation(
lang,
prompt,
snippet_request.body.as_str(),
snippet_request.top_k.unwrap_or(1),
&data,
)?;
Ok(web::Json(closest))
}
pub fn closest_mutation(
lang: &str,
prompt: &str,
body: &str,
top_k: usize,
data: &web::Data<crate::state::StateWrapper>,
) -> Result<Vec<String>, Error> {
let langfn = get_lang(lang)?; let langfn = get_lang(lang)?;
info!(prompt = prompt, language = lang, "v2 request"); info!(prompt = prompt, language = lang, "v2 request");
@@ -66,7 +83,7 @@ pub(crate) async fn get_snippet(
.set_language(&langfn) .set_language(&langfn)
.map_err(|_| Error::UnknownLang)?; .map_err(|_| Error::UnknownLang)?;
let source_code = snippet_request.body.as_str(); let source_code = body;
let source_bytes = source_code.as_bytes(); let source_bytes = source_code.as_bytes();
let tree = parser let tree = parser
.parse(source_code, None) .parse(source_code, None)
@@ -74,8 +91,8 @@ pub(crate) async fn get_snippet(
let root_node = tree.root_node(); let root_node = tree.root_node();
// search for k nearest neighbors // search for k nearest neighbors
let closest: Vec<String> = appstate.v2.dict[lang] let collected = appstate.v2.dict[lang]
.search(&target, snippet_request.top_k.unwrap_or(1)) .search(&target, top_k)
.iter() .iter()
.filter_map(|&index| { .filter_map(|&index| {
let applied = mutation::apply( let applied = mutation::apply(
@@ -97,5 +114,5 @@ pub(crate) async fn get_snippet(
// TODO: change the expect to a log // TODO: change the expect to a log
}) })
.collect(); .collect();
Ok(web::Json(closest)) Ok(collected)
} }