refactor: move common closest mutation search into a function

This commit is contained in:
Himadri Bhattacharjee
2025-07-01 18:58:18 +05:30
parent 4b710e8675
commit 716b9ed3e2
2 changed files with 50 additions and 76 deletions

View File

@@ -12,8 +12,7 @@ use std::collections::HashMap;
use tokio::sync::Mutex;
use tower_lsp::lsp_types::*;
use tower_lsp::{Client, LanguageServer, LspService, Server};
use tracing::{error, info};
use tree_sitter::Parser as TSParser;
use tracing::error;
mod embed;
mod state;
@@ -267,6 +266,7 @@ impl LanguageServer for Backend {
params: CodeActionParams,
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
let uri = params.text_document.uri;
let extension = url_extension(&uri);
let body = self.body.lock().await.to_string();
let range = params.range;
@@ -278,93 +278,50 @@ impl LanguageServer for Backend {
return Ok(None);
};
let Some((prompt, lang)) = desc.rsplit_once(" in ") else {
let (prompt, lang) = if let Some(ext) = extension {
(desc, ext)
} else if let Some((prompt, lang)) = desc.rsplit_once(" in ") {
(prompt, lang.to_string())
} else {
error!("{}", v2::errors::Error::MissingSuffix);
return Ok(None);
};
let langfn = match v2::api::get_lang(lang) {
Ok(o) => o,
Err(e) => {
error!("{e}");
return Ok(None);
}
};
info!(prompt = prompt, language = lang, "v2 request");
let mut appstate = self
.appstate
.inner
.lock()
.map_err(|_| v2::errors::Error::Busy)
.expect("booo");
let target = appstate
.embed
.embed(prompt)
.map_err(|_| v2::errors::Error::EmbedFailed)
.expect("booo");
let mut parser = TSParser::new();
parser
.set_language(&langfn)
.map_err(|_| v2::errors::Error::UnknownLang)
.expect("boo");
let source_code = new_text;
let source_bytes = source_code.as_bytes();
let tree = parser
.parse(source_code, None)
.ok_or(v2::errors::Error::SnippetParsing)
.expect("boo");
let root_node = tree.root_node();
// search for k nearest neighbors
let closest: Vec<String> = appstate.v2.dict[lang]
.search(&target, 1)
.iter()
.filter_map(|&index| {
let applied = v2::mutation::apply(
langfn.clone(),
source_bytes,
root_node,
&appstate.v2.mutations_collection[index],
);
match applied {
Ok(v) => Some(v),
Err(e) => {
error!(
collection_index = index,
"failed to apply mutations from collection {}", e
);
None
}
let closest_matches =
match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) {
Ok(v) => v,
Err(e) => {
error!("{}", e);
return Ok(None);
}
// TODO: change the expect to a log
})
.collect();
let closest = closest[0].clone();
};
let Some(closest) = closest_matches.into_iter().next() else {
return Ok(None);
};
let text_edit = TextEdit {
range,
new_text: closest,
};
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
let edit = WorkspaceEdit {
let edit = Some(WorkspaceEdit {
changes: Some(changes),
document_changes: None,
change_annotations: None,
};
});
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
title: "ask silos".to_string(),
kind: None,
diagnostics: None,
edit: Some(edit),
command: None,
is_preferred: None,
disabled: None,
data: None,
edit,
..Default::default()
})];
Ok(Some(actions))
}
}
pub fn url_extension(u: &Url) -> Option<String> {
let file_path = u.to_file_path().ok()?;
let extension = file_path.extension()?;
let extension = extension.to_str()?;
Some(extension.to_string())
}

View File

@@ -52,6 +52,23 @@ pub(crate) async fn get_snippet(
return Err(Error::MissingSuffix);
};
let closest = closest_mutation(
lang,
prompt,
snippet_request.body.as_str(),
snippet_request.top_k.unwrap_or(1),
&data,
)?;
Ok(web::Json(closest))
}
pub fn closest_mutation(
lang: &str,
prompt: &str,
body: &str,
top_k: usize,
data: &web::Data<crate::state::StateWrapper>,
) -> Result<Vec<String>, Error> {
let langfn = get_lang(lang)?;
info!(prompt = prompt, language = lang, "v2 request");
@@ -66,7 +83,7 @@ pub(crate) async fn get_snippet(
.set_language(&langfn)
.map_err(|_| Error::UnknownLang)?;
let source_code = snippet_request.body.as_str();
let source_code = body;
let source_bytes = source_code.as_bytes();
let tree = parser
.parse(source_code, None)
@@ -74,8 +91,8 @@ pub(crate) async fn get_snippet(
let root_node = tree.root_node();
// search for k nearest neighbors
let closest: Vec<String> = appstate.v2.dict[lang]
.search(&target, snippet_request.top_k.unwrap_or(1))
let collected = appstate.v2.dict[lang]
.search(&target, top_k)
.iter()
.filter_map(|&index| {
let applied = mutation::apply(
@@ -97,5 +114,5 @@ pub(crate) async fn get_snippet(
// TODO: change the expect to a log
})
.collect();
Ok(web::Json(closest))
Ok(collected)
}