refactor: move common closest mutation search into a function
This commit is contained in:
101
src/main.rs
101
src/main.rs
@@ -12,8 +12,7 @@ use std::collections::HashMap;
|
||||
use tokio::sync::Mutex;
|
||||
use tower_lsp::lsp_types::*;
|
||||
use tower_lsp::{Client, LanguageServer, LspService, Server};
|
||||
use tracing::{error, info};
|
||||
use tree_sitter::Parser as TSParser;
|
||||
use tracing::error;
|
||||
|
||||
mod embed;
|
||||
mod state;
|
||||
@@ -267,6 +266,7 @@ impl LanguageServer for Backend {
|
||||
params: CodeActionParams,
|
||||
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
|
||||
let uri = params.text_document.uri;
|
||||
let extension = url_extension(&uri);
|
||||
let body = self.body.lock().await.to_string();
|
||||
|
||||
let range = params.range;
|
||||
@@ -278,93 +278,50 @@ impl LanguageServer for Backend {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let Some((prompt, lang)) = desc.rsplit_once(" in ") else {
|
||||
let (prompt, lang) = if let Some(ext) = extension {
|
||||
(desc, ext)
|
||||
} else if let Some((prompt, lang)) = desc.rsplit_once(" in ") {
|
||||
(prompt, lang.to_string())
|
||||
} else {
|
||||
error!("{}", v2::errors::Error::MissingSuffix);
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let langfn = match v2::api::get_lang(lang) {
|
||||
Ok(o) => o,
|
||||
Err(e) => {
|
||||
error!("{e}");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
info!(prompt = prompt, language = lang, "v2 request");
|
||||
|
||||
let mut appstate = self
|
||||
.appstate
|
||||
.inner
|
||||
.lock()
|
||||
.map_err(|_| v2::errors::Error::Busy)
|
||||
.expect("booo");
|
||||
let target = appstate
|
||||
.embed
|
||||
.embed(prompt)
|
||||
.map_err(|_| v2::errors::Error::EmbedFailed)
|
||||
.expect("booo");
|
||||
let mut parser = TSParser::new();
|
||||
parser
|
||||
.set_language(&langfn)
|
||||
.map_err(|_| v2::errors::Error::UnknownLang)
|
||||
.expect("boo");
|
||||
|
||||
let source_code = new_text;
|
||||
let source_bytes = source_code.as_bytes();
|
||||
let tree = parser
|
||||
.parse(source_code, None)
|
||||
.ok_or(v2::errors::Error::SnippetParsing)
|
||||
.expect("boo");
|
||||
let root_node = tree.root_node();
|
||||
|
||||
// search for k nearest neighbors
|
||||
let closest: Vec<String> = appstate.v2.dict[lang]
|
||||
.search(&target, 1)
|
||||
.iter()
|
||||
.filter_map(|&index| {
|
||||
let applied = v2::mutation::apply(
|
||||
langfn.clone(),
|
||||
source_bytes,
|
||||
root_node,
|
||||
&appstate.v2.mutations_collection[index],
|
||||
);
|
||||
match applied {
|
||||
Ok(v) => Some(v),
|
||||
Err(e) => {
|
||||
error!(
|
||||
collection_index = index,
|
||||
"failed to apply mutations from collection {}", e
|
||||
);
|
||||
None
|
||||
}
|
||||
let closest_matches =
|
||||
match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("{}", e);
|
||||
return Ok(None);
|
||||
}
|
||||
// TODO: change the expect to a log
|
||||
})
|
||||
.collect();
|
||||
|
||||
let closest = closest[0].clone();
|
||||
};
|
||||
|
||||
let Some(closest) = closest_matches.into_iter().next() else {
|
||||
return Ok(None);
|
||||
};
|
||||
let text_edit = TextEdit {
|
||||
range,
|
||||
new_text: closest,
|
||||
};
|
||||
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
|
||||
let edit = WorkspaceEdit {
|
||||
let edit = Some(WorkspaceEdit {
|
||||
changes: Some(changes),
|
||||
document_changes: None,
|
||||
change_annotations: None,
|
||||
};
|
||||
});
|
||||
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
|
||||
title: "ask silos".to_string(),
|
||||
kind: None,
|
||||
diagnostics: None,
|
||||
edit: Some(edit),
|
||||
command: None,
|
||||
is_preferred: None,
|
||||
disabled: None,
|
||||
data: None,
|
||||
edit,
|
||||
..Default::default()
|
||||
})];
|
||||
Ok(Some(actions))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn url_extension(u: &Url) -> Option<String> {
|
||||
let file_path = u.to_file_path().ok()?;
|
||||
|
||||
let extension = file_path.extension()?;
|
||||
let extension = extension.to_str()?;
|
||||
Some(extension.to_string())
|
||||
}
|
||||
|
||||
@@ -52,6 +52,23 @@ pub(crate) async fn get_snippet(
|
||||
return Err(Error::MissingSuffix);
|
||||
};
|
||||
|
||||
let closest = closest_mutation(
|
||||
lang,
|
||||
prompt,
|
||||
snippet_request.body.as_str(),
|
||||
snippet_request.top_k.unwrap_or(1),
|
||||
&data,
|
||||
)?;
|
||||
Ok(web::Json(closest))
|
||||
}
|
||||
|
||||
pub fn closest_mutation(
|
||||
lang: &str,
|
||||
prompt: &str,
|
||||
body: &str,
|
||||
top_k: usize,
|
||||
data: &web::Data<crate::state::StateWrapper>,
|
||||
) -> Result<Vec<String>, Error> {
|
||||
let langfn = get_lang(lang)?;
|
||||
|
||||
info!(prompt = prompt, language = lang, "v2 request");
|
||||
@@ -66,7 +83,7 @@ pub(crate) async fn get_snippet(
|
||||
.set_language(&langfn)
|
||||
.map_err(|_| Error::UnknownLang)?;
|
||||
|
||||
let source_code = snippet_request.body.as_str();
|
||||
let source_code = body;
|
||||
let source_bytes = source_code.as_bytes();
|
||||
let tree = parser
|
||||
.parse(source_code, None)
|
||||
@@ -74,8 +91,8 @@ pub(crate) async fn get_snippet(
|
||||
let root_node = tree.root_node();
|
||||
|
||||
// search for k nearest neighbors
|
||||
let closest: Vec<String> = appstate.v2.dict[lang]
|
||||
.search(&target, snippet_request.top_k.unwrap_or(1))
|
||||
let collected = appstate.v2.dict[lang]
|
||||
.search(&target, top_k)
|
||||
.iter()
|
||||
.filter_map(|&index| {
|
||||
let applied = mutation::apply(
|
||||
@@ -97,5 +114,5 @@ pub(crate) async fn get_snippet(
|
||||
// TODO: change the expect to a log
|
||||
})
|
||||
.collect();
|
||||
Ok(web::Json(closest))
|
||||
Ok(collected)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user