refactor: move common closest mutation search into a function
This commit is contained in:
101
src/main.rs
101
src/main.rs
@@ -12,8 +12,7 @@ use std::collections::HashMap;
|
|||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tower_lsp::lsp_types::*;
|
use tower_lsp::lsp_types::*;
|
||||||
use tower_lsp::{Client, LanguageServer, LspService, Server};
|
use tower_lsp::{Client, LanguageServer, LspService, Server};
|
||||||
use tracing::{error, info};
|
use tracing::error;
|
||||||
use tree_sitter::Parser as TSParser;
|
|
||||||
|
|
||||||
mod embed;
|
mod embed;
|
||||||
mod state;
|
mod state;
|
||||||
@@ -267,6 +266,7 @@ impl LanguageServer for Backend {
|
|||||||
params: CodeActionParams,
|
params: CodeActionParams,
|
||||||
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
|
) -> tower_lsp::jsonrpc::Result<Option<CodeActionResponse>> {
|
||||||
let uri = params.text_document.uri;
|
let uri = params.text_document.uri;
|
||||||
|
let extension = url_extension(&uri);
|
||||||
let body = self.body.lock().await.to_string();
|
let body = self.body.lock().await.to_string();
|
||||||
|
|
||||||
let range = params.range;
|
let range = params.range;
|
||||||
@@ -278,93 +278,50 @@ impl LanguageServer for Backend {
|
|||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some((prompt, lang)) = desc.rsplit_once(" in ") else {
|
let (prompt, lang) = if let Some(ext) = extension {
|
||||||
|
(desc, ext)
|
||||||
|
} else if let Some((prompt, lang)) = desc.rsplit_once(" in ") {
|
||||||
|
(prompt, lang.to_string())
|
||||||
|
} else {
|
||||||
error!("{}", v2::errors::Error::MissingSuffix);
|
error!("{}", v2::errors::Error::MissingSuffix);
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
let langfn = match v2::api::get_lang(lang) {
|
let closest_matches =
|
||||||
Ok(o) => o,
|
match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) {
|
||||||
Err(e) => {
|
Ok(v) => v,
|
||||||
error!("{e}");
|
Err(e) => {
|
||||||
return Ok(None);
|
error!("{}", e);
|
||||||
}
|
return Ok(None);
|
||||||
};
|
|
||||||
|
|
||||||
info!(prompt = prompt, language = lang, "v2 request");
|
|
||||||
|
|
||||||
let mut appstate = self
|
|
||||||
.appstate
|
|
||||||
.inner
|
|
||||||
.lock()
|
|
||||||
.map_err(|_| v2::errors::Error::Busy)
|
|
||||||
.expect("booo");
|
|
||||||
let target = appstate
|
|
||||||
.embed
|
|
||||||
.embed(prompt)
|
|
||||||
.map_err(|_| v2::errors::Error::EmbedFailed)
|
|
||||||
.expect("booo");
|
|
||||||
let mut parser = TSParser::new();
|
|
||||||
parser
|
|
||||||
.set_language(&langfn)
|
|
||||||
.map_err(|_| v2::errors::Error::UnknownLang)
|
|
||||||
.expect("boo");
|
|
||||||
|
|
||||||
let source_code = new_text;
|
|
||||||
let source_bytes = source_code.as_bytes();
|
|
||||||
let tree = parser
|
|
||||||
.parse(source_code, None)
|
|
||||||
.ok_or(v2::errors::Error::SnippetParsing)
|
|
||||||
.expect("boo");
|
|
||||||
let root_node = tree.root_node();
|
|
||||||
|
|
||||||
// search for k nearest neighbors
|
|
||||||
let closest: Vec<String> = appstate.v2.dict[lang]
|
|
||||||
.search(&target, 1)
|
|
||||||
.iter()
|
|
||||||
.filter_map(|&index| {
|
|
||||||
let applied = v2::mutation::apply(
|
|
||||||
langfn.clone(),
|
|
||||||
source_bytes,
|
|
||||||
root_node,
|
|
||||||
&appstate.v2.mutations_collection[index],
|
|
||||||
);
|
|
||||||
match applied {
|
|
||||||
Ok(v) => Some(v),
|
|
||||||
Err(e) => {
|
|
||||||
error!(
|
|
||||||
collection_index = index,
|
|
||||||
"failed to apply mutations from collection {}", e
|
|
||||||
);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// TODO: change the expect to a log
|
};
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let closest = closest[0].clone();
|
|
||||||
|
|
||||||
|
let Some(closest) = closest_matches.into_iter().next() else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
let text_edit = TextEdit {
|
let text_edit = TextEdit {
|
||||||
range,
|
range,
|
||||||
new_text: closest,
|
new_text: closest,
|
||||||
};
|
};
|
||||||
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
|
let changes: HashMap<Url, _> = [(uri, vec![text_edit])].into_iter().collect();
|
||||||
let edit = WorkspaceEdit {
|
let edit = Some(WorkspaceEdit {
|
||||||
changes: Some(changes),
|
changes: Some(changes),
|
||||||
document_changes: None,
|
document_changes: None,
|
||||||
change_annotations: None,
|
change_annotations: None,
|
||||||
};
|
});
|
||||||
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
|
let actions = vec![CodeActionOrCommand::CodeAction(CodeAction {
|
||||||
title: "ask silos".to_string(),
|
title: "ask silos".to_string(),
|
||||||
kind: None,
|
edit,
|
||||||
diagnostics: None,
|
..Default::default()
|
||||||
edit: Some(edit),
|
|
||||||
command: None,
|
|
||||||
is_preferred: None,
|
|
||||||
disabled: None,
|
|
||||||
data: None,
|
|
||||||
})];
|
})];
|
||||||
Ok(Some(actions))
|
Ok(Some(actions))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn url_extension(u: &Url) -> Option<String> {
|
||||||
|
let file_path = u.to_file_path().ok()?;
|
||||||
|
|
||||||
|
let extension = file_path.extension()?;
|
||||||
|
let extension = extension.to_str()?;
|
||||||
|
Some(extension.to_string())
|
||||||
|
}
|
||||||
|
|||||||
@@ -52,6 +52,23 @@ pub(crate) async fn get_snippet(
|
|||||||
return Err(Error::MissingSuffix);
|
return Err(Error::MissingSuffix);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let closest = closest_mutation(
|
||||||
|
lang,
|
||||||
|
prompt,
|
||||||
|
snippet_request.body.as_str(),
|
||||||
|
snippet_request.top_k.unwrap_or(1),
|
||||||
|
&data,
|
||||||
|
)?;
|
||||||
|
Ok(web::Json(closest))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn closest_mutation(
|
||||||
|
lang: &str,
|
||||||
|
prompt: &str,
|
||||||
|
body: &str,
|
||||||
|
top_k: usize,
|
||||||
|
data: &web::Data<crate::state::StateWrapper>,
|
||||||
|
) -> Result<Vec<String>, Error> {
|
||||||
let langfn = get_lang(lang)?;
|
let langfn = get_lang(lang)?;
|
||||||
|
|
||||||
info!(prompt = prompt, language = lang, "v2 request");
|
info!(prompt = prompt, language = lang, "v2 request");
|
||||||
@@ -66,7 +83,7 @@ pub(crate) async fn get_snippet(
|
|||||||
.set_language(&langfn)
|
.set_language(&langfn)
|
||||||
.map_err(|_| Error::UnknownLang)?;
|
.map_err(|_| Error::UnknownLang)?;
|
||||||
|
|
||||||
let source_code = snippet_request.body.as_str();
|
let source_code = body;
|
||||||
let source_bytes = source_code.as_bytes();
|
let source_bytes = source_code.as_bytes();
|
||||||
let tree = parser
|
let tree = parser
|
||||||
.parse(source_code, None)
|
.parse(source_code, None)
|
||||||
@@ -74,8 +91,8 @@ pub(crate) async fn get_snippet(
|
|||||||
let root_node = tree.root_node();
|
let root_node = tree.root_node();
|
||||||
|
|
||||||
// search for k nearest neighbors
|
// search for k nearest neighbors
|
||||||
let closest: Vec<String> = appstate.v2.dict[lang]
|
let collected = appstate.v2.dict[lang]
|
||||||
.search(&target, snippet_request.top_k.unwrap_or(1))
|
.search(&target, top_k)
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|&index| {
|
.filter_map(|&index| {
|
||||||
let applied = mutation::apply(
|
let applied = mutation::apply(
|
||||||
@@ -97,5 +114,5 @@ pub(crate) async fn get_snippet(
|
|||||||
// TODO: change the expect to a log
|
// TODO: change the expect to a log
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
Ok(web::Json(closest))
|
Ok(collected)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user