diff --git a/src/main.rs b/src/main.rs index 2b2ad67..75a9f6c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,8 +12,7 @@ use std::collections::HashMap; use tokio::sync::Mutex; use tower_lsp::lsp_types::*; use tower_lsp::{Client, LanguageServer, LspService, Server}; -use tracing::{error, info}; -use tree_sitter::Parser as TSParser; +use tracing::error; mod embed; mod state; @@ -267,6 +266,7 @@ impl LanguageServer for Backend { params: CodeActionParams, ) -> tower_lsp::jsonrpc::Result> { let uri = params.text_document.uri; + let extension = url_extension(&uri); let body = self.body.lock().await.to_string(); let range = params.range; @@ -278,93 +278,50 @@ impl LanguageServer for Backend { return Ok(None); }; - let Some((prompt, lang)) = desc.rsplit_once(" in ") else { + let (prompt, lang) = if let Some(ext) = extension { + (desc, ext) + } else if let Some((prompt, lang)) = desc.rsplit_once(" in ") { + (prompt, lang.to_string()) + } else { error!("{}", v2::errors::Error::MissingSuffix); return Ok(None); }; - let langfn = match v2::api::get_lang(lang) { - Ok(o) => o, - Err(e) => { - error!("{e}"); - return Ok(None); - } - }; - - info!(prompt = prompt, language = lang, "v2 request"); - - let mut appstate = self - .appstate - .inner - .lock() - .map_err(|_| v2::errors::Error::Busy) - .expect("booo"); - let target = appstate - .embed - .embed(prompt) - .map_err(|_| v2::errors::Error::EmbedFailed) - .expect("booo"); - let mut parser = TSParser::new(); - parser - .set_language(&langfn) - .map_err(|_| v2::errors::Error::UnknownLang) - .expect("boo"); - - let source_code = new_text; - let source_bytes = source_code.as_bytes(); - let tree = parser - .parse(source_code, None) - .ok_or(v2::errors::Error::SnippetParsing) - .expect("boo"); - let root_node = tree.root_node(); - - // search for k nearest neighbors - let closest: Vec = appstate.v2.dict[lang] - .search(&target, 1) - .iter() - .filter_map(|&index| { - let applied = v2::mutation::apply( - langfn.clone(), - source_bytes, - root_node, - &appstate.v2.mutations_collection[index], - ); - match applied { - Ok(v) => Some(v), - Err(e) => { - error!( - collection_index = index, - "failed to apply mutations from collection {}", e - ); - None - } + let closest_matches = + match v2::api::closest_mutation(&lang, prompt, &body, 1, &self.appstate) { + Ok(v) => v, + Err(e) => { + error!("{}", e); + return Ok(None); } - // TODO: change the expect to a log - }) - .collect(); - - let closest = closest[0].clone(); + }; + let Some(closest) = closest_matches.into_iter().next() else { + return Ok(None); + }; let text_edit = TextEdit { range, new_text: closest, }; let changes: HashMap = [(uri, vec![text_edit])].into_iter().collect(); - let edit = WorkspaceEdit { + let edit = Some(WorkspaceEdit { changes: Some(changes), document_changes: None, change_annotations: None, - }; + }); let actions = vec![CodeActionOrCommand::CodeAction(CodeAction { title: "ask silos".to_string(), - kind: None, - diagnostics: None, - edit: Some(edit), - command: None, - is_preferred: None, - disabled: None, - data: None, + edit, + ..Default::default() })]; Ok(Some(actions)) } } + +pub fn url_extension(u: &Url) -> Option { + let file_path = u.to_file_path().ok()?; + + let extension = file_path.extension()?; + let extension = extension.to_str()?; + Some(extension.to_string()) +} diff --git a/src/v2/api.rs b/src/v2/api.rs index cad0139..480a4af 100644 --- a/src/v2/api.rs +++ b/src/v2/api.rs @@ -52,6 +52,23 @@ pub(crate) async fn get_snippet( return Err(Error::MissingSuffix); }; + let closest = closest_mutation( + lang, + prompt, + snippet_request.body.as_str(), + snippet_request.top_k.unwrap_or(1), + &data, + )?; + Ok(web::Json(closest)) +} + +pub fn closest_mutation( + lang: &str, + prompt: &str, + body: &str, + top_k: usize, + data: &web::Data, +) -> Result, Error> { let langfn = get_lang(lang)?; info!(prompt = prompt, language = lang, "v2 request"); @@ -66,7 +83,7 @@ pub(crate) async fn get_snippet( .set_language(&langfn) .map_err(|_| Error::UnknownLang)?; - let source_code = snippet_request.body.as_str(); + let source_code = body; let source_bytes = source_code.as_bytes(); let tree = parser .parse(source_code, None) @@ -74,8 +91,8 @@ pub(crate) async fn get_snippet( let root_node = tree.root_node(); // search for k nearest neighbors - let closest: Vec = appstate.v2.dict[lang] - .search(&target, snippet_request.top_k.unwrap_or(1)) + let collected = appstate.v2.dict[lang] + .search(&target, top_k) .iter() .filter_map(|&index| { let applied = mutation::apply( @@ -97,5 +114,5 @@ pub(crate) async fn get_snippet( // TODO: change the expect to a log }) .collect(); - Ok(web::Json(closest)) + Ok(collected) }